Commit 9d019ef

qli88 authored and Akshat-Tripathi committed
[core] Perf improvement for DSv3 on AMD GPUs (vllm-project#13718)
Signed-off-by: qli88 <[email protected]>
1 parent 810e7c5 commit 9d019ef

3 files changed (+210, -25 lines)

3 files changed

+210
-25
lines changed

vllm/attention/backends/mla/common.py

Lines changed: 72 additions & 20 deletions
@@ -237,14 +237,20 @@
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
+    is_vllm_fa = True
 except ImportError:
     # For rocm use upstream flash attention
     from flash_attn import flash_attn_varlen_func
+    is_vllm_fa = False
+
+from vllm.attention.ops.triton_flash_attention import triton_attention
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)
 
+is_hip = current_platform.is_rocm()
+
 
 class MLACommonBackend(AttentionBackend):
 
@@ -1046,12 +1052,13 @@ def __init__(
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
-        self.vllm_flash_attn_version = get_flash_attn_version()
+        self.triton_fa_func = triton_attention
 
         # Handle the differences between the flash_attn_varlen from flash_attn
         # and the one from vllm_flash_attn. The former is used on RoCM and the
         # latter has an additional parameter to control FA2 vs FA3
         self.flash_attn_varlen_func = flash_attn_varlen_func
+        self.vllm_flash_attn_version = get_flash_attn_version()
         if self.vllm_flash_attn_version is not None:
             self.flash_attn_varlen_func = \
                 functools.partial(flash_attn_varlen_func,
@@ -1315,18 +1322,48 @@ def _compute_prefill_context(
                                                [0, q.shape[-1] - v.shape[-1]],
                                                value=0)
 
-            attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                max_seqlen_q=prefill_metadata.max_query_len,
-                max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
-                softmax_scale=self.scale,
-                causal=False,  # Context is unmasked
-                return_softmax_lse=True,
-            )
+            if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+                attn_output, attn_softmax_lse = self.triton_fa_func(
+                    q,
+                    k,
+                    v_padded,
+                    None,
+                    prefill_metadata.query_start_loc,
+                    prefill_metadata.context_chunk_cu_seq_lens[i],
+                    prefill_metadata.max_query_len,
+                    prefill_metadata.context_chunk_max_seq_lens[i],
+                    False,  # causal
+                    self.scale,
+                    None,  # attn_mask is None unless applying ALiBi mask
+                )
+            elif is_vllm_fa:
+                attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
+                    q=q,
+                    k=k,
+                    v=v_padded,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_softmax_lse=True,
+                )
+            else:
+                attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
+                    q=q,
+                    k=k,
+                    v=v_padded,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_attn_probs=True,
+                )
 
             if output is None:
                 output = attn_output
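
The three branches above differ mainly in how the (output, log-sum-exp) pair comes back: the Triton kernel and vllm_flash_attn return two values, while upstream flash_attn returns a third object when return_attn_probs=True. A minimal sketch of smoothing over the two flash-attn conventions, assuming fa_varlen is whichever flash_attn_varlen_func was imported at the top of the file (hypothetical helper, not part of the commit):

    # Hypothetical wrapper, not vLLM code: unify the return conventions of
    # vllm_flash_attn (return_softmax_lse=True -> out, lse) and upstream
    # flash_attn (return_attn_probs=True -> out, lse, attn_probs).
    def varlen_attn_with_lse(fa_varlen, is_vllm_fa, **kwargs):
        if is_vllm_fa:
            out, lse = fa_varlen(return_softmax_lse=True, **kwargs)
        else:
            out, lse, *_ = fa_varlen(return_attn_probs=True, **kwargs)
        return out, lse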
@@ -1374,11 +1411,24 @@ def _forward_prefill(
         v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                            value=0)
 
-        if has_context:
-            if not current_platform.is_cuda():
-                raise NotImplementedError(
-                    "Chunked Prefill for MLA is not currently supported on"
-                    "non-cuda platforms")
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+            output = self.triton_fa_func(
+                q,
+                k,
+                v_padded,
+                None,
+                prefill_metadata.query_start_loc,
+                prefill_metadata.query_start_loc,
+                prefill_metadata.max_prefill_seq_len,
+                prefill_metadata.max_prefill_seq_len,
+                True,  # causal
+                self.scale,
+                None,  # attn_mask is None unless applying ALiBi mask
+            )
+            ## triton flash attention always return 2 objects
+            if not has_context:
+                output = output[0]
+        elif is_vllm_fa:
             output = self.flash_attn_varlen_func(
                 q=q,
                 k=k,
@@ -1389,7 +1439,7 @@ def _forward_prefill(
                 max_seqlen_k=prefill_metadata.max_prefill_seq_len,
                 softmax_scale=self.scale,
                 causal=True,
-                return_softmax_lse=True,
+                return_softmax_lse=has_context,
             )
         else:
             output = self.flash_attn_varlen_func(
@@ -1402,10 +1452,12 @@ def _forward_prefill(
                 max_seqlen_k=prefill_metadata.max_prefill_seq_len,
                 softmax_scale=self.scale,
                 causal=True,
+                return_attn_probs=has_context,
             )
 
         if has_context:
-            suffix_output, suffix_lse = output
+            # ROCm flash_attn_varlen_func will return 3 objects instead of 2
+            suffix_output, suffix_lse, *rest = output
             context_output, context_lse = self._compute_prefill_context( \
                 q, kv_c_and_k_pe_cache, attn_metadata)
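
For context on why the log-sum-exp values are carried around at all: the suffix output above covers the new tokens, while _compute_prefill_context covers the same queries attending to the cached context chunks, and the two partial results can be combined exactly from their LSEs. A small PyTorch sketch of that merge, assuming outputs shaped [tokens, heads, head_dim] and LSEs shaped [tokens, heads] (illustrative only, not vLLM's merge kernel or its exact layout):

    import torch

    def merge_partial_attention(out_a, lse_a, out_b, lse_b):
        # Standard online-softmax merge of attention computed over two
        # disjoint key ranges, using the per-query log-sum-exp values.
        lse = torch.logaddexp(lse_a, lse_b)           # [tokens, heads]
        w_a = torch.exp(lse_a - lse).unsqueeze(-1)    # [tokens, heads, 1]
        w_b = torch.exp(lse_b - lse).unsqueeze(-1)
        return w_a * out_a + w_b * out_b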

vllm/attention/ops/triton_decode_attention.py

Lines changed: 10 additions & 5 deletions
@@ -178,7 +178,8 @@ def _decode_att_m_fwd(
     page_size,
     logit_cap,
 ):
-    BLOCK = 64
+    BLOCK = 64 if not is_hip_ else 8
+
     NUM_KV_SPLITS = num_kv_splits
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
@@ -188,7 +189,9 @@ def _decode_att_m_fwd(
     grid = (batch, head_num, NUM_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[-2]
 
-    num_warps = 4 if kv_group_num == 1 else 2
+    num_warps = 4
+    if kv_group_num != 1:
+        num_warps = 1 if is_hip_ else 2
 
     BLOCK_DMODEL = triton.next_power_of_2(Lk)
     BLOCK_DV = triton.next_power_of_2(Lv)
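
Taken together, the two hunks above amount to a small launch-parameter heuristic for the decode kernel; restated as a standalone helper for readability (an illustrative sketch, not code from the file):

    def pick_decode_launch_params(kv_group_num: int, is_hip_: bool):
        """Return (BLOCK, num_warps) as chosen in _decode_att_m_fwd above."""
        block = 8 if is_hip_ else 64       # smaller KV tile per program on ROCm
        num_warps = 4
        if kv_group_num != 1:              # grouped-query / MQA case
            num_warps = 1 if is_hip_ else 2
        return block, num_warps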
@@ -418,14 +421,16 @@ def _decode_grouped_att_m_fwd(
     )
 
     extra_kargs = {}
+    num_stages = 2
     if is_hip_:
-        # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html
+        # https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#mi300x-triton-kernel-performance-optimization
         # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py
         extra_kargs = {
-            "waves_per_eu": 4,
+            "waves_per_eu": 1,
             "matrix_instr_nonkdim": 16,
             "kpack": 2
         }
+        num_stages = 1
 
     _fwd_grouped_kernel_stage1[grid](
         q,
@@ -456,7 +461,7 @@ def _decode_grouped_att_m_fwd(
         PAGE_SIZE=page_size,
         logit_cap=logit_cap,
         num_warps=4,
-        num_stages=2,
+        num_stages=num_stages,
         Lk=Lk,
         Lv=Lv,
         **extra_kargs,
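
The waves_per_eu, matrix_instr_nonkdim and kpack entries are ROCm-specific Triton compiler options that can be passed as keyword arguments at kernel launch, which is how **extra_kargs is consumed above. A self-contained toy example of the same pattern, using a made-up elementwise kernel rather than anything from vLLM, with the hints applied only on a HIP build:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _scale_kernel(x_ptr, y_ptr, n_elements, scale, BLOCK: tl.constexpr):
        pid = tl.program_id(axis=0)
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        tl.store(y_ptr + offsets, x * scale, mask=mask)

    def scale(x: torch.Tensor, s: float) -> torch.Tensor:
        y = torch.empty_like(x)
        n = x.numel()
        extra_kargs = {}
        num_stages = 2
        if torch.version.hip is not None:
            # ROCm-only compiler hints, mirroring the diff above; the tuned
            # values only matter for real workloads, they are harmless here.
            extra_kargs = {"waves_per_eu": 1,
                           "matrix_instr_nonkdim": 16,
                           "kpack": 2}
            num_stages = 1
        grid = (triton.cdiv(n, 1024), )
        _scale_kernel[grid](x, y, n, s, BLOCK=1024,
                            num_warps=4, num_stages=num_stages, **extra_kargs)
        return y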
Lines changed: 128 additions & 0 deletions

@@ -0,0 +1,128 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 8,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    }
+}
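
The new JSON file (the third changed file; its name is not captured in this view) has the shape of vLLM's Triton GEMM tuning tables: each top-level key is a batch-size-like bucket (M) mapped to tile sizes and launch parameters, and callers pick the entry nearest the actual M. A hypothetical lookup helper showing how such a table is typically consumed; the file name and the nearest-key policy are assumptions for illustration, not this commit's code:

    import json

    def pick_tuned_config(path: str, m: int) -> dict:
        # Load the table and return the entry whose integer key is closest to M.
        with open(path) as f:
            table = {int(k): v for k, v in json.load(f).items()}
        return table[min(table, key=lambda k: abs(k - m))]

    # e.g. pick_tuned_config("mi300x_tuning.json", m=200) -> the "256" entry
    # above: BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=128, ...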
