@@ -13,6 +13,29 @@ __device__ float __forceinline__ t2f32<half>(half val) {
     return __half2float(val);
 }
 
+struct soft_max_params {
+
+    int64_t nheads;
+    uint32_t n_head_log2;
+    int64_t ncols;
+    int64_t nrows_x;
+    int64_t nrows_y;
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    int64_t nb11;
+    int64_t nb12;
+    int64_t nb13;
+
+    int64_t ne12;
+    int64_t ne13;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+};
+
 // When ncols_template == 0 the bounds for the loops in this function are not known and can't be unrolled.
 // As we want to keep pragma unroll for all other cases we supress the clang transformation warning here.
 #ifdef __clang__
@@ -21,24 +44,32 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 #endif // __clang__
 template <bool use_shared, int ncols_template, int block_size_template, typename T>
 static __global__ void soft_max_f32(
-        const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
-        const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
+        const float * x, const T * mask, float * dst, const soft_max_params p) {
+    const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
 
     const int tid = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
+
+    const int64_t i03 = blockIdx.z;
+    const int64_t i02 = blockIdx.y;
+    const int64_t i01 = blockIdx.x;
+
+    // TODO: noncontigous inputs/outputs
+    const int rowx = blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y;
+
+    const int64_t i11 = i01;
+    const int64_t i12 = i02 % p.ne12;
+    const int64_t i13 = i03 % p.ne13;
 
     x    += int64_t(rowx)*ncols;
-    mask += int64_t(rowy)*ncols * (mask != nullptr);
+    mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
     dst  += int64_t(rowx)*ncols;
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    const float slope = get_alibi_slope(max_bias, rowx/nrows_y, n_head_log2, m0, m1);
+    const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -60,7 +91,9 @@ static __global__ void soft_max_f32(
 
         // const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
 
-        const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+        // const float val = x[col]*scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
+
+        const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -156,63 +189,60 @@ static __global__ void soft_max_back_f32(
 }
 
 template <typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const soft_max_params & params, cudaStream_t stream) {
     int nth = WARP_SIZE;
+    const int64_t ncols_x = params.ncols;
+
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
+    const dim3 block_nums(params.ne01, params.ne02, params.ne03);
     const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    const uint32_t n_head      = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
-
-    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // FIXME: this limit could be raised by ~2-4x on Ampere or newer
     if (nbytes_shared < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
                 soft_max_f32<true, 32, 32><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 64:
                 soft_max_f32<true, 64, 64><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 128:
                 soft_max_f32<true, 128, 128><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 256:
                 soft_max_f32<true, 256, 256><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 512:
                 soft_max_f32<true, 512, 512><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 1024:
                 soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 2048:
                 soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             case 4096:
                 soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
             default:
                 soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>
-                    (x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                    (x, mask, dst, params);
                 break;
         }
     } else {
         const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, dst, params);
     }
 }
 
@@ -240,10 +270,11 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
     const int64_t nrows_y = src0->ne[1];
 
+    const int64_t ne00 = src0->ne[0];
+
     float scale    = 1.0f;
     float max_bias = 0.0f;
 
@@ -252,14 +283,56 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
+    const int64_t nb11 = src1 ? src1->nb[1] : 1;
+    const int64_t nb12 = src1 ? src1->nb[2] : 1;
+    const int64_t nb13 = src1 ? src1->nb[3] : 1;
+
+    const int64_t ne12 = src1 ? src1->ne[2] : 1;
+    const int64_t ne13 = src1 ? src1->ne[3] : 1;
+
+    const uint32_t n_head      = src0->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
+
+    soft_max_params params = {};
+    params.nheads      = src0->ne[2];
+    params.n_head_log2 = n_head_log2;
+    params.ncols       = ne00;
+    params.nrows_x     = nrows_x;
+    params.nrows_y     = nrows_y;
+    params.ne00        = src0->ne[0];
+    params.ne01        = src0->ne[1];
+    params.ne02        = src0->ne[2];
+    params.ne03        = src0->ne[3];
+    params.nb11        = nb11;
+    params.nb12        = nb12;
+    params.nb13        = nb13;
+    params.ne12        = ne12;
+    params.ne13        = ne13;
+    params.scale       = scale;
+    params.max_bias    = max_bias;
+    params.m0          = m0;
+    params.m1          = m1;
+
     if (use_f16) {
+
         // const half * src1_dd = (const half *)src1_d;
 
-        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
-    } else {
+        // soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+
+        // } else {
+
         // const float * src1_dd = (const float *)src1_d;
 
-        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        // soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+
+        soft_max_f32_cuda(src0_d, (const half *) src1_d, dst_d, params, stream);
+    } else {
+        soft_max_f32_cuda(src0_d, (const float *) src1_d, dst_d, params, stream);
+
     }
 }
 
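Note (not part of the patch): a minimal standalone sketch of the mask addressing the reworked kernel uses, assuming nb11/nb12/nb13 are the mask's byte strides in dims 1-3 and ne12/ne13 its sizes in dims 2 and 3. A mask with ne12 == ne13 == 1 is broadcast across all heads and batches, generalizing the old rowy = rowx % nrows_y scheme, which only broadcast along the row dimension. The shape and stride values in main() are hypothetical.

// Illustrative only; mirrors the i11/i12/i13 indexing added to soft_max_f32 above.
#include <cstdint>
#include <cstdio>

template <typename T>
int64_t mask_elem_offset(int64_t i01, int64_t i02, int64_t i03,
                         int64_t ne12, int64_t ne13,
                         int64_t nb11, int64_t nb12, int64_t nb13) {
    const int64_t i11 = i01;        // row index is shared with src0/dst
    const int64_t i12 = i02 % ne12; // broadcast over heads when the mask has ne12 == 1
    const int64_t i13 = i03 % ne13; // broadcast over batch when the mask has ne13 == 1
    return (i11*nb11 + i12*nb12 + i13*nb13) / (int64_t) sizeof(T); // byte strides -> element offset
}

int main() {
    // Hypothetical f32 mask of shape [512, 512, 1, 1] shared by every head and batch:
    // only the row index matters, and rows are 512 floats (2048 bytes) apart.
    const int64_t off = mask_elem_offset<float>(/*i01=*/3, /*i02=*/5, /*i03=*/0,
                                                /*ne12=*/1, /*ne13=*/1,
                                                /*nb11=*/512*4, /*nb12=*/4, /*nb13=*/4);
    printf("%lld\n", (long long) off); // prints 1536, i.e. row 3 of the mask
    return 0;
}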