Commit cfb36fa

am17an authored and Nexesenex committed
CUDA: add softmax broadcast (ggml-org#14475)
* CUDA: add softmax broadcast
* Pass by const ref
* Review: Use blockDims for indexing, remove designated initializers
* Add TODO for noncontigous input/output
1 parent db830e8 commit cfb36fa
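
In short: the mask is no longer addressed by a flat rowx % nrows_y. Each block now carries its (i01, i02, i03) coordinates and the mask row is picked per dimension, wrapping dims 2 and 3 when the mask has fewer heads or batches than src0. A minimal sketch of that index mapping, assuming a plain C++ setting (the helper name is mine; the field names mirror the diff below):

#include <cstdint>

// Illustrative only: softmax row (i01, i02, i03) reads mask row (i11, i12, i13);
// ne12/ne13 are the mask's extents in dims 2/3 and may be 1 (full broadcast) or
// smaller than the corresponding src0 extents.
struct mask_coords { int64_t i11, i12, i13; };

static mask_coords broadcast_mask_coords(int64_t i01, int64_t i02, int64_t i03,
                                         int64_t ne12, int64_t ne13) {
    return { i01,          // dim 1 (rows) maps one-to-one
             i02 % ne12,   // dim 2 (heads) wraps around -> broadcast
             i03 % ne13 }; // dim 3 (batch) wraps around -> broadcast
}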

File tree

1 file changed (+102, -29 lines)


ggml/src/ggml-cuda/softmax.cu

Lines changed: 102 additions & 29 deletions
@@ -13,6 +13,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
     return __half2float(val);
 }

+struct soft_max_params {
+
+    int64_t nheads;
+    uint32_t n_head_log2;
+    int64_t ncols;
+    int64_t nrows_x;
+    int64_t nrows_y;
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    int64_t nb11;
+    int64_t nb12;
+    int64_t nb13;
+
+    int64_t ne12;
+    int64_t ne13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+};
+
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
 #ifdef __clang__
@@ -21,24 +44,32 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
-        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+        const float * x, const T * mask, float * dst, const soft_max_params p) {
+    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;

     const int tid = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int64_t i03 = blockIdx.z;
+    const int64_t i02 = blockIdx.y;
+    const int64_t i01 = blockIdx.x;
+
+    //TODO: noncontigous inputs/outputs
+    const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
+
+    const int64_t i11 = i01;
+    const int64_t i12 = i02 % p.ne12;
+    const int64_t i13 = i03 % p.ne13;

     x += int64_t(rowx)*ncols;
-    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
     dst += int64_t(rowx)*ncols;

     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;

     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;

-    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);

     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
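
The mask pointer above is advanced by byte strides (nb11/nb12/nb13) divided by sizeof(T), i.e. the strides stored in the ggml tensor are converted to an element offset, which assumes the mask's innermost dimension is contiguous. A rough host-side rendering of the same arithmetic (hypothetical helper, not part of the commit):

#include <cstdint>

// Illustrative only: element offset of mask row (i11, i12, i13) computed from
// byte strides, as the kernel does. The (mask != nullptr) factor in the kernel
// simply zeroes the offset when no mask is given.
template <typename T>
static const T * mask_row_ptr(const T * mask, int64_t i11, int64_t i12, int64_t i13,
                              int64_t nb11, int64_t nb12, int64_t nb13) {
    if (mask == nullptr) {
        return nullptr;
    }
    return mask + (i11*nb11 + i12*nb12 + i13*nb13) / (int64_t) sizeof(T);
}
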
@@ -60,7 +91,9 @@ static __global__ void soft_max_f32(

         // const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);

-        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+        // const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+
+        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);

         vals[col] = val;
         max_val = max(max_val, val);
@@ -156,63 +189,60 @@ static __global__ void soft_max_back_f32(
 }

 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
+    const int64_t ncols_x = params.ncols;
+
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
     const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");

-    const uint32_t n_head = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
                 soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 64:
                 soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 128:
                 soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 256:
                 soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 512:
                 soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 1024:
                 soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 2048:
                 soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 4096:
                 soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             default:
                 soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
         }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
     }
 }

@@ -240,10 +270,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional

-    const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];

+    const int64_t ne00 = src0->ne[0];
+
     float scale = 1.0f;
     float max_bias = 0.0f;

@@ -252,14 +283,56 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);

+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
+    const uint32_t n_head = src0->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+
+    soft_max_params params = {};
+    params.nheads = src0->ne[2];
+    params.n_head_log2 = n_head_log2;
+    params.ncols = ne00;
+    params.nrows_x = nrows_x;
+    params.nrows_y = nrows_y;
+    params.ne00 = src0->ne[0];
+    params.ne01 = src0->ne[1];
+    params.ne02 = src0->ne[2];
+    params.ne03 = src0->ne[3];
+    params.nb11 = nb11;
+    params.nb12 = nb12;
+    params.nb13 = nb13;
+    params.ne12 = ne12;
+    params.ne13 = ne13;
+    params.scale = scale;
+    params.max_bias = max_bias;
+    params.m0 = m0;
+    params.m1 = m1;
+
     if (use_f16) {
+
         // const half * src1_dd = (const half *)src1_d;

-        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
-    } else {
+        // soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+
+        // } else {
+
         // const float * src1_dd = (const float *)src1_d;

-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        // soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
+    } else {
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
+
     }
 }
