Skip to content

Commit 5143fa8

Browse files
CUDA: fastdiv, launch bounds for mmvq + q8_1 quant (#15802)
* CUDA: fastdiv, launch bounds for mmvq + q8_1 quant
1 parent 3a550b5 commit 5143fa8

File tree

3 files changed

+67
-77
lines changed

3 files changed

+67
-77
lines changed

ggml/src/ggml-cuda/common.cuh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -570,6 +570,8 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
570570
//
571571
// n/d = (mulhi(n, mp) + n) >> L;
572572
static const uint3 init_fastdiv_values(uint32_t d) {
573+
GGML_ASSERT(d != 0);
574+
573575
// compute L = ceil(log2(d));
574576
uint32_t L = 0;
575577
while (L < 32 && (uint32_t{ 1 } << L) < d) {

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 53 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,10 @@ template <ggml_type type, int ncols_dst>
141141
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
142142
static __global__ void mul_mat_vec_q(
143143
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
144-
const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
145-
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
146-
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
144+
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
145+
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
146+
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
147+
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
147148

148149
constexpr int qk = ggml_cuda_type_traits<type>::qk;
149150
constexpr int qi = ggml_cuda_type_traits<type>::qi;
@@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q(
161162
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
162163

163164
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
164-
const int channel_dst = blockIdx.y;
165-
const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
166-
const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
167-
const int sample_dst = blockIdx.z;
168-
const int sample_x = sample_dst / sample_ratio;
169-
const int sample_y = sample_dst;
165+
const uint32_t channel_dst = blockIdx.y;
166+
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
167+
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
168+
const uint32_t sample_dst = blockIdx.z;
169+
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
170+
const uint32_t sample_y = sample_dst;
170171

171172
// partial sum for each thread
172173
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
@@ -247,95 +248,80 @@ static void mul_mat_vec_q_switch_ncols_dst(
247248
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
248249
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
249250

250-
const int channel_ratio = nchannels_dst / nchannels_x;
251-
const int sample_ratio = nsamples_dst / nsamples_x;
251+
const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
252+
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
253+
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
252254

253255
const int device = ggml_cuda_get_device();
254256
const int warp_size = ggml_cuda_info().devices[device].warp_size;
255257
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
256258

257259
GGML_ASSERT(!ids || ncols_dst == 1);
258260
switch (ncols_dst) {
259-
case 1:
260-
{
261+
case 1: {
261262
constexpr int c_ncols_dst = 1;
262263
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
263264
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
264-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
265-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
266-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
267-
break;
268-
}
269-
case 2:
270-
{
265+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
266+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
267+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
268+
} break;
269+
case 2: {
271270
constexpr int c_ncols_dst = 2;
272271
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
273272
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
274-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
275-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
276-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
277-
break;
278-
}
279-
case 3:
280-
{
273+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
274+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
275+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
276+
} break;
277+
case 3: {
281278
constexpr int c_ncols_dst = 3;
282279
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
283280
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
284-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
285-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
286-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
287-
break;
288-
}
289-
case 4:
290-
{
281+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
282+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
283+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
284+
} break;
285+
case 4: {
291286
constexpr int c_ncols_dst = 4;
292287
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
293288
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
294-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
295-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
296-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
297-
break;
298-
}
299-
case 5:
300-
{
289+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
290+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
291+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
292+
} break;
293+
case 5: {
301294
constexpr int c_ncols_dst = 5;
302295
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
303296
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
304-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
305-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
306-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
307-
break;
308-
}
309-
case 6:
310-
{
297+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
298+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
299+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
300+
} break;
301+
case 6: {
311302
constexpr int c_ncols_dst = 6;
312303
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
313304
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
314-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
315-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
316-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
317-
break;
318-
}
319-
case 7:
320-
{
305+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
306+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
307+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
308+
} break;
309+
case 7: {
321310
constexpr int c_ncols_dst = 7;
322311
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
323312
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
324-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
325-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
326-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
327-
break;
328-
}
329-
case 8:
330-
{
313+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
314+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
315+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
316+
} break;
317+
case 8: {
331318
constexpr int c_ncols_dst = 8;
332319
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
333320
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
334-
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
335-
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
336-
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
337-
break;
338-
}
321+
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
322+
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
323+
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
324+
} break;
339325
default:
340326
GGML_ABORT("fatal error");
341327
break;

ggml/src/ggml-cuda/quantize.cu

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,27 @@
11
#include "quantize.cuh"
22
#include <cstdint>
33

4+
__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
45
static __global__ void quantize_q8_1(
56
const float * __restrict__ x, void * __restrict__ vy,
67
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
7-
const int64_t ne0, const int ne1, const int ne2) {
8+
const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
89
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
910

1011
if (i0 >= ne0) {
1112
return;
1213
}
1314

15+
const int64_t i3 = fastdiv(blockIdx.z, ne2);
16+
const int64_t i2 = blockIdx.z - i3*ne2.z;
1417
const int64_t i1 = blockIdx.y;
15-
const int64_t i2 = blockIdx.z % ne2;
16-
const int64_t i3 = blockIdx.z / ne2;
1718

1819
const int64_t & i00 = i0;
1920
const int64_t & i01 = i1;
2021
const int64_t & i02 = i2;
2122
const int64_t & i03 = i3;
2223

23-
const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
24+
const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;
2425

2526
block_q8_1 * y = (block_q8_1 *) vy;
2627

@@ -31,10 +32,10 @@ static __global__ void quantize_q8_1(
3132
float amax = fabsf(xi);
3233
float sum = xi;
3334

34-
amax = warp_reduce_max(amax);
35-
sum = warp_reduce_sum(sum);
35+
amax = warp_reduce_max<QK8_1>(amax);
36+
sum = warp_reduce_sum<QK8_1>(sum);
3637

37-
const float d = amax / 127;
38+
const float d = amax / 127.0f;
3839
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
3940

4041
y[ib].qs[iqs] = q;
@@ -43,8 +44,7 @@ static __global__ void quantize_q8_1(
4344
return;
4445
}
4546

46-
reinterpret_cast<half&>(y[ib].ds.x) = d;
47-
reinterpret_cast<half&>(y[ib].ds.y) = sum;
47+
y[ib].ds = make_half2(d, sum);
4848
}
4949

5050
template <mmq_q8_1_ds_layout ds_layout>
@@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda(
152152
GGML_ASSERT(!ids);
153153
GGML_ASSERT(ne0 % QK8_1 == 0);
154154

155+
const uint3 ne2_fastdiv = init_fastdiv_values(ne2);
156+
155157
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
156158
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
157159
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
158-
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
160+
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
159161
GGML_UNUSED(type_src0);
160162
}
161163

0 commit comments

Comments
 (0)