Skip to content

Commit bc7497a

Browse files
CUDA: use fastdiv for mmvq + q8_1 quantization
1 parent fb15d64 commit bc7497a

File tree

2 files changed

+40
-53
lines changed

2 files changed

+40
-53
lines changed

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 28 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,10 @@ template <ggml_type type, int ncols_dst>
141141
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142142
static __global__ void mul_mat_vec_q(
143143
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
144-
const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
145-
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
146-
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
144+
const uint32_t ncols_x, const uint32_t nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
145+
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
146+
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
147+
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
147148

148149
constexpr int qk = ggml_cuda_type_traits<type>::qk;
149150
constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q(
161162
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
162163

163164
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
164-
const int channel_dst = blockIdx.y;
165-
const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
166-
const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
167-
const int sample_dst = blockIdx.z;
168-
const int sample_x = sample_dst / sample_ratio;
169-
const int sample_y = sample_dst;
165+
const uint32_t channel_dst = blockIdx.y;
166+
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
167+
const uint32_t channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
168+
const uint32_t sample_dst = blockIdx.z;
169+
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
170+
const uint32_t sample_y = sample_dst;
170171

171172
// partial sum for each thread
172173
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
@@ -247,95 +248,79 @@ static void mul_mat_vec_q_switch_ncols_dst(
247248
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
248249
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
249250

250-
const int channel_ratio = nchannels_dst / nchannels_x;
251-
const int sample_ratio = nsamples_dst / nsamples_x;
251+
const uint3 channel_ratio = init_fastdiv_values(nchannels_dst / nchannels_x);
252+
const uint3 sample_ratio = init_fastdiv_values(nsamples_dst / nsamples_x);
252253

253254
const int device = ggml_cuda_get_device();
254255
const int warp_size = ggml_cuda_info().devices[device].warp_size;
255256
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
256257

257258
GGML_ASSERT(!ids || ncols_dst == 1);
258259
switch (ncols_dst) {
259-
case 1:
260-
{
260+
case 1: {
261261
constexpr int c_ncols_dst = 1;
262262
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
263263
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
264264
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
265265
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
266266
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
267-
break;
268-
}
269-
case 2:
270-
{
267+
} break;
268+
case 2: {
271269
constexpr int c_ncols_dst = 2;
272270
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
273271
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
274272
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
275273
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
276274
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
277-
break;
278-
}
279-
case 3:
280-
{
275+
} break;
276+
case 3: {
281277
constexpr int c_ncols_dst = 3;
282278
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
283279
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
284280
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
285281
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
286282
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
287-
break;
288-
}
289-
case 4:
290-
{
283+
} break;
284+
case 4: {
291285
constexpr int c_ncols_dst = 4;
292286
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
293287
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
294288
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
295289
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
296290
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
297-
break;
298-
}
299-
case 5:
300-
{
291+
} break;
292+
case 5: {
301293
constexpr int c_ncols_dst = 5;
302294
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
303295
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
304296
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
305297
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
306298
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
307-
break;
308-
}
309-
case 6:
310-
{
299+
} break;
300+
case 6: {
311301
constexpr int c_ncols_dst = 6;
312302
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
313303
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
314304
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
315305
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
316306
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
317-
break;
318-
}
319-
case 7:
320-
{
307+
} break;
308+
case 7: {
321309
constexpr int c_ncols_dst = 7;
322310
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
323311
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
324312
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
325313
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
326314
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
327-
break;
328-
}
329-
case 8:
330-
{
315+
} break;
316+
case 8: {
331317
constexpr int c_ncols_dst = 8;
332318
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
333319
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
334320
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
335321
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
336322
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
337-
break;
338-
}
323+
} break;
339324
default:
340325
GGML_ABORT("fatal error");
341326
break;

ggml/src/ggml-cuda/quantize.cu

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
#include "quantize.cuh"
22
#include <cstdint>
33

4+
__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
45
static __global__ void quantize_q8_1(
56
const float * __restrict__ x, void * __restrict__ vy,
67
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
7-
const int64_t ne0, const int ne1, const int ne2) {
8+
const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
89
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
910

1011
if (i0 >= ne0) {
1112
return;
1213
}
1314

15+
const int64_t i3 = fastdiv(blockIdx.z, ne2);
16+
const int64_t i2 = blockIdx.z - i3*ne2.z;
1417
const int64_t i1 = blockIdx.y;
15-
const int64_t i2 = blockIdx.z % ne2;
16-
const int64_t i3 = blockIdx.z / ne2;
1718

1819
const int64_t & i00 = i0;
1920
const int64_t & i01 = i1;
2021
const int64_t & i02 = i2;
2122
const int64_t & i03 = i3;
2223

23-
const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
24+
const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;
2425

2526
block_q8_1 * y = (block_q8_1 *) vy;
2627

@@ -31,10 +32,10 @@ static __global__ void quantize_q8_1(
3132
float amax = fabsf(xi);
3233
float sum = xi;
3334

34-
amax = warp_reduce_max(amax);
35-
sum = warp_reduce_sum(sum);
35+
amax = warp_reduce_max<QK8_1>(amax);
36+
sum = warp_reduce_sum<QK8_1>(sum);
3637

37-
const float d = amax / 127;
38+
const float d = amax / 127.0f;
3839
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
3940

4041
y[ib].qs[iqs] = q;
@@ -43,8 +44,7 @@ static __global__ void quantize_q8_1(
4344
return;
4445
}
4546

46-
reinterpret_cast<half&>(y[ib].ds.x) = d;
47-
reinterpret_cast<half&>(y[ib].ds.y) = sum;
47+
y[ib].ds = make_half2(d, sum);
4848
}
4949

5050
template <mmq_q8_1_ds_layout ds_layout>
@@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda(
152152
GGML_ASSERT(!ids);
153153
GGML_ASSERT(ne0 % QK8_1 == 0);
154154

155+
const uint3 ne2_fastdiv = init_fastdiv_values(ne2);
156+
155157
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
156158
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
157159
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
158-
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
160+
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
159161
GGML_UNUSED(type_src0);
160162
}
161163

0 commit comments

Comments
 (0)