@@ -141,9 +141,10 @@ template <ggml_type type, int ncols_dst>
141
141
__launch_bounds__ (calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142
142
static __global__ void mul_mat_vec_q(
143
143
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
144
- const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
145
- const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
146
- const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
144
+ const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
145
+ const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
146
+ const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
147
+ const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
147
148
148
149
constexpr int qk = ggml_cuda_type_traits<type>::qk;
149
150
constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q(
161
162
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
162
163
163
164
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
164
- const int channel_dst = blockIdx .y ;
165
- const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
166
- const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
167
- const int sample_dst = blockIdx .z ;
168
- const int sample_x = sample_dst / sample_ratio;
169
- const int sample_y = sample_dst;
165
+ const uint32_t channel_dst = blockIdx .y ;
166
+ const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv ( channel_dst, channel_ratio) ;
167
+ const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo ( channel_dst, nchannels_y) : channel_dst;
168
+ const uint32_t sample_dst = blockIdx .z ;
169
+ const uint32_t sample_x = fastdiv ( sample_dst, sample_ratio) ;
170
+ const uint32_t sample_y = sample_dst;
170
171
171
172
// partial sum for each thread
172
173
float tmp[ncols_dst][rows_per_cuda_block] = {{0 .0f }};
@@ -247,95 +248,80 @@ static void mul_mat_vec_q_switch_ncols_dst(
247
248
GGML_ASSERT (ncols_x % ggml_blck_size (type) == 0 );
248
249
GGML_ASSERT (ncols_dst <= MMVQ_MAX_BATCH_SIZE);
249
250
250
- const int channel_ratio = nchannels_dst / nchannels_x;
251
- const int sample_ratio = nsamples_dst / nsamples_x;
251
+ const uint3 nchannels_y_fd = ids ? init_fastdiv_values (nchannels_y) : make_uint3 (0 , 0 , 0 );
252
+ const uint3 channel_ratio_fd = ids ? make_uint3 (0 , 0 , 0 ) : init_fastdiv_values (nchannels_dst / nchannels_x);
253
+ const uint3 sample_ratio_fd = init_fastdiv_values (nsamples_dst / nsamples_x);
252
254
253
255
const int device = ggml_cuda_get_device ();
254
256
const int warp_size = ggml_cuda_info ().devices [device].warp_size ;
255
257
const mmvq_parameter_table_id table_id = get_device_table_id (ggml_cuda_info ().devices [device].cc );
256
258
257
259
GGML_ASSERT (!ids || ncols_dst == 1 );
258
260
switch (ncols_dst) {
259
- case 1 :
260
- {
261
+ case 1 : {
261
262
constexpr int c_ncols_dst = 1 ;
262
263
std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
263
264
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
264
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
265
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
266
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
267
- break ;
268
- }
269
- case 2 :
270
- {
265
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
266
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
267
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
268
+ } break ;
269
+ case 2 : {
271
270
constexpr int c_ncols_dst = 2 ;
272
271
std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
273
272
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
274
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
275
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
276
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
277
- break ;
278
- }
279
- case 3 :
280
- {
273
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
274
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
275
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
276
+ } break ;
277
+ case 3 : {
281
278
constexpr int c_ncols_dst = 3 ;
282
279
std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
283
280
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
284
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
285
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
286
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
287
- break ;
288
- }
289
- case 4 :
290
- {
281
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
282
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
283
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
284
+ } break ;
285
+ case 4 : {
291
286
constexpr int c_ncols_dst = 4 ;
292
287
std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
293
288
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
294
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
295
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
296
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
297
- break ;
298
- }
299
- case 5 :
300
- {
289
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
290
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
291
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
292
+ } break ;
293
+ case 5 : {
301
294
constexpr int c_ncols_dst = 5 ;
302
295
std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
303
296
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
304
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
305
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
306
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
307
- break ;
308
- }
309
- case 6 :
310
- {
297
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
298
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
299
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
300
+ } break ;
301
+ case 6 : {
311
302
constexpr int c_ncols_dst = 6 ;
312
303
std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
313
304
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
314
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
315
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
316
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
317
- break ;
318
- }
319
- case 7 :
320
- {
305
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
306
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
307
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
308
+ } break ;
309
+ case 7 : {
321
310
constexpr int c_ncols_dst = 7 ;
322
311
std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
323
312
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
324
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
325
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
326
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
327
- break ;
328
- }
329
- case 8 :
330
- {
313
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
314
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
315
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
316
+ } break ;
317
+ case 8 : {
331
318
constexpr int c_ncols_dst = 8 ;
332
319
std::pair<dim3 , dim3 > dims = calc_launch_params (c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
333
320
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0 , stream>>>
334
- (vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
335
- channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
336
- sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
337
- break ;
338
- }
321
+ (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
322
+ channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
323
+ sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
324
+ } break ;
339
325
default :
340
326
GGML_ABORT (" fatal error" );
341
327
break ;
0 commit comments