
Commit 3853003

Revert "CUDA: faster tile FA (Pascal/AMD), headsize 256 (ggml-org#15769)"
This reverts commit 79bc429.
1 parent fc6f98d · commit 3853003


7 files changed: +769 −604 lines

Lines changed: 371 additions & 0 deletions
@@ -0,0 +1,371 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f16.cuh"

#define FATTN_KQ_STRIDE_TILE_F16 64

template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
#if !defined(GGML_USE_HIP)
__launch_bounds__(nwarps*WARP_SIZE, 2)
#endif // !defined(GGML_USE_HIP)
static __global__ void flash_attn_tile_ext_f16(
        const char * __restrict__ Q,
        const char * __restrict__ K,
        const char * __restrict__ V,
        const char * __restrict__ mask,
        const char * __restrict__ sinks,
        const int * __restrict__ KV_max,
        float * __restrict__ dst,
        float2 * __restrict__ dst_meta,
        const float scale,
        const float max_bias,
        const float m0,
        const float m1,
        const uint32_t n_head_log2,
        const float logit_softcap,
        const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
        const int32_t nb01, const int32_t nb02, const int32_t nb03,
        const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
        const int32_t nb11, const int32_t nb12, const int64_t nb13,
        const int32_t nb21, const int32_t nb22, const int64_t nb23,
        const int32_t ne31, const int32_t ne32, const int32_t ne33,
        const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)

    // Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
    NO_DEVICE_CODE;
    return;
#endif // FP16_MMA_AVAILABLE
    if (use_logit_softcap && !(D == 128 || D == 256)) {
        NO_DEVICE_CODE;
        return;
    }

    //In this kernel Q, K, V are matrices while i, j, k are matrix indices.

    const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.

    const int sequence = blockIdx.z / ne02;
    const int head = blockIdx.z - sequence*ne02;
    const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
    const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
    const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
    const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
    const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
    const float * sinksf = (const float *) (sinks);

    const int stride_KV2 = nb11 / sizeof(half2);

    const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
    const half slopeh = __float2half(slopef);

    static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");

    __shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
    half2 * KQ2 = (half2 *) KQ;

    __shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.

    half kqmax[ncols/nwarps];
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        kqmax[j0/nwarps] = -HALF_MAX_HALF;
    }
    half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};

    half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};

    // Convert Q to half2 and store in registers:
    __shared__ half2 Q_h2[ncols][D/2];
#pragma unroll
    for (int j0 = 0; j0 < ncols; j0 += nwarps) {
        const int j = j0 + threadIdx.y;

#pragma unroll
        for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
            const int i = i0 + threadIdx.x;

            const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
            Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
        }
    }

    __syncthreads();

    const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
    for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
        // Calculate KQ tile and keep track of new maximum KQ values:

        half kqmax_new[ncols/nwarps];
#pragma unroll
        for (int j = 0; j < ncols/nwarps; ++j) {
            kqmax_new[j] = kqmax[j];
        }

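        // Load one FATTN_KQ_STRIDE_TILE_F16 x (D/2) half2 tile of K into shared memory
        // (KV_tmp is padded in its second dimension to avoid shared-memory bank conflicts):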
#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
            const int i_KQ = i_KQ_0 + threadIdx.y;

#pragma unroll
            for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
                const int k_KQ = k_KQ_0 + threadIdx.x;

                KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
            }
        }

        __syncthreads();

        half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};

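        // Compute the KQ tile: each thread accumulates half2 partial dot products of
        // K rows (from shared memory) with Q columns (from Q_h2) over the head dimension: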
#pragma unroll
        for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
            half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
            half2 Q_k[ncols/nwarps];

#pragma unroll
            for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
                const int i_KQ = i_KQ_0 + threadIdx.x;

                K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
            }
#pragma unroll
            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                const int j_KQ = j_KQ_0 + threadIdx.y;

                Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
            }

#pragma unroll
            for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
#pragma unroll
                for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                    sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
                }
            }
        }

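        // Reduce the half2 partial sums, optionally apply the logit softcap, add the
        // ALiBi-scaled mask, and update the per-row running maximum: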
#pragma unroll
        for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
            const int i_KQ = i_KQ_0 + threadIdx.x;

#pragma unroll
            for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
                const int j_KQ = j_KQ_0 + threadIdx.y;

                half sum;
                if (use_logit_softcap) {
                    const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
                    sum = logit_softcap * tanhf(tmp.x + tmp.y);
                } else {
                    sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
                }
                sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);

                kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);

                KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
            }
        }

        __syncthreads();

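        // Online softmax update: after the warp-wide max reduction, the previously
        // accumulated sums and VKQ values are rescaled by exp(kqmax_old - kqmax_new) so
        // that contributions computed against the old maximum stay consistent with the new one: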
#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            const int j = j0 + threadIdx.y;

            kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
            kqmax[j0/nwarps] = kqmax_new[j0/nwarps];

#pragma unroll
            for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;

                const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
                const half2 val = h2exp(diff);
                kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
                KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
            }

#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
            }
        }

        __syncthreads();

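        // Load the corresponding V tile into shared memory, then accumulate the weighted
        // sum VKQ += softmax(KQ) * V, processing two KV positions per iteration: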
#pragma unroll
        for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
            const int k = k0 + threadIdx.y;

#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;

                KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
            }
        }

        __syncthreads();

#pragma unroll
        for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
            half2 V_k[(D/2)/WARP_SIZE][2];
            half2 KQ_k[ncols/nwarps];

#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                const int i = i0 + threadIdx.x;

                V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
                V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
            }
#pragma unroll
            for (int j0 = 0; j0 < ncols; j0 += nwarps) {
                const int j = j0 + threadIdx.y;

                KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
            }

#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
#pragma unroll
                for (int j0 = 0; j0 < ncols; j0 += nwarps) {
                    VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
                    VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
                }
            }
        }

        __syncthreads();
    }

    //Attention sink: adjust running max and sum once per head
    if (sinksf && blockIdx.y == 0) {
        const half sink = __float2half(sinksf[head]);

#pragma unroll
        for (int j0 = 0; j0 < ncols; j0 += nwarps) {
            half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
            kqmax_new_j = warp_reduce_max(kqmax_new_j);

            const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
            kqmax[j0/nwarps] = kqmax_new_j;

            const half val = hexp(sink - kqmax[j0/nwarps]);
            kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
            if (threadIdx.x == 0) {
                kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val);
            }

#pragma unroll
            for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
                VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
            }
        }
    }

    float2 * dst2 = (float2 *) dst;

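    // Write back the results. With a single block along gridDim.y the accumulator is
    // normalized by kqsum here; otherwise unnormalized partial results together with the
    // per-row (max, sum) metadata are written out for a later combination step: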
#pragma unroll
    for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
        const int j_VKQ = j_VKQ_0 + threadIdx.y;

        if (ic0 + j_VKQ >= ne01) {
            return;
        }

        half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
        kqsum_j = warp_reduce_sum((float)kqsum_j);

        const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;

#pragma unroll
        for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
            const int i0 = i00 + threadIdx.x;

            half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
            if (gridDim.y == 1) {
                dst_val /= __half2half2(kqsum_j);
            }
            dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
        }

        if (gridDim.y != 1 && threadIdx.x == 0) {
            dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
        }
    }
#else
    GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
        max_bias, m0, m1, n_head_log2, logit_softcap,
        ne00, ne01, ne02, ne03,
        nb01, nb02, nb03,
        ne10, ne11, ne12, ne13,
        nb11, nb12, nb13,
        nb21, nb22, nb23,
        ne31, ne32, ne33,
        nb31, nb32, nb33);
    NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}

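// Host-side launcher: instantiates the tile kernel for head size 64 or 128 and
// dispatches it through launch_fattn.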
template <int cols_per_block, bool use_logit_softcap>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];
    switch (Q->ne[0]) {
        case 64: {
            constexpr int D = 64;
            constexpr int nwarps = 8;
            constexpr size_t nbytes_shared = 0;
            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
            launch_fattn<D, cols_per_block, 1>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
        } break;
        case 128: {
            constexpr int D = 128;
            constexpr int nwarps = 8;
            constexpr size_t nbytes_shared = 0;
            fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
            launch_fattn<D, cols_per_block, 1>
                (ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
        } break;
        default: {
            GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
        } break;
    }
}

void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV = dst;
    const ggml_tensor * Q = dst->src[0];

    const int32_t precision = KQV->op_params[3];
    GGML_ASSERT(precision == GGML_PREC_DEFAULT);

    float logit_softcap;
    memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));

    if (Q->ne[1] <= 16) {
        constexpr int cols_per_block = 16;
        if (logit_softcap == 0.0f) {
            constexpr bool use_logit_softcap = false;
            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
        } else {
            constexpr bool use_logit_softcap = true;
            launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
        }
        return;
    }

    constexpr int cols_per_block = 32;
    if (logit_softcap == 0.0f) {
        constexpr bool use_logit_softcap = false;
        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
    } else {
        constexpr bool use_logit_softcap = true;
        launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
    }
}
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
#include "common.cuh"

void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
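For reference, the rescaling in the kernel above follows the standard streaming-softmax recurrence: whenever a new tile raises the running maximum, the accumulated sum and the output accumulator are multiplied by exp(m_old - m_new) before new terms are added. Below is a minimal scalar sketch of that recurrence, mirroring the kqmax/kqsum/VKQ updates; it is an illustration only, not part of this commit, and the function and variable names are hypothetical.

#include <cmath>
#include <vector>

// Streaming softmax-weighted sum: computes sum_i exp(x[i])*v[i] / sum_i exp(x[i])
// without materializing all exponentials at once.
float streaming_softmax_dot(const std::vector<float> & x, const std::vector<float> & v) {
    float m = -INFINITY; // running maximum   (analogous to kqmax)
    float s = 0.0f;      // running exp sum   (analogous to kqsum)
    float o = 0.0f;      // running weighted accumulator (analogous to VKQ)
    for (size_t i = 0; i < x.size(); ++i) {
        const float m_new = fmaxf(m, x[i]);
        const float scale = expf(m - m_new); // rescale previous contributions
        const float p     = expf(x[i] - m_new);
        s = s*scale + p;
        o = o*scale + p*v[i];
        m = m_new;
    }
    return o / s; // final normalization, as done in the kernel when gridDim.y == 1
}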
