try:
    from vllm.vllm_flash_attn import flash_attn_varlen_func
+    is_vllm_fa = True
except ImportError:
    # For rocm use upstream flash attention
    from flash_attn import flash_attn_varlen_func
+    is_vllm_fa = False
+
+from vllm.attention.ops.triton_flash_attention import triton_attention

if TYPE_CHECKING:
    from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                          ModelInputForGPUWithSamplingMetadata)

+is_hip = current_platform.is_rocm()
+

class MLACommonBackend(AttentionBackend):
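The module-level flags introduced above are what the prefill paths later in this diff key off. A minimal sketch of the resulting three-way dispatch, assuming only the flags shown here (the helper name is hypothetical, not part of vLLM):

```python
# Illustrative only: condensed view of the backend selection implied by
# is_vllm_fa, is_hip and VLLM_USE_TRITON_FLASH_ATTN.
def select_prefill_attn(is_hip: bool, use_triton: bool, is_vllm_fa: bool) -> str:
    if is_hip and use_triton:
        return "triton_attention"      # ROCm Triton kernel
    if is_vllm_fa:
        return "vllm_flash_attn"       # flash-attn build shipped with vLLM
    return "upstream_flash_attn"       # upstream flash_attn (e.g. ROCm wheels)
```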
@@ -1046,12 +1052,13 @@ def __init__(
        self.q_proj = q_proj
        self.kv_b_proj = kv_b_proj
        self.o_proj = o_proj
-        self.vllm_flash_attn_version = get_flash_attn_version()
+        self.triton_fa_func = triton_attention

        # Handle the differences between the flash_attn_varlen from flash_attn
        # and the one from vllm_flash_attn. The former is used on RoCM and the
        # latter has an additional parameter to control FA2 vs FA3
        self.flash_attn_varlen_func = flash_attn_varlen_func
+        self.vllm_flash_attn_version = get_flash_attn_version()
        if self.vllm_flash_attn_version is not None:
            self.flash_attn_varlen_func = \
                functools.partial(flash_attn_varlen_func,
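The comment in this hunk is the reason for the `functools.partial` wrapper: the vLLM build of `flash_attn_varlen_func` accepts an extra argument selecting FA2 vs. FA3, while the upstream ROCm build does not. A hedged sketch of that wrapping, with the `fa_version` keyword assumed from context:

```python
import functools

def bind_fa_version(flash_attn_varlen_func, vllm_flash_attn_version):
    # Sketch only: upstream flash_attn (ROCm) has no version keyword, so the
    # function is used as-is; the vLLM build pins FA2/FA3 once via partial so
    # every call site downstream can stay identical.
    if vllm_flash_attn_version is None:
        return flash_attn_varlen_func
    return functools.partial(flash_attn_varlen_func,
                             fa_version=vllm_flash_attn_version)
```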
@@ -1315,18 +1322,48 @@ def _compute_prefill_context(
                                               [0, q.shape[-1] - v.shape[-1]],
                                               value=0)

-            attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
-                q=q,
-                k=k,
-                v=v_padded,
-                cu_seqlens_q=prefill_metadata.query_start_loc,
-                cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
-                max_seqlen_q=prefill_metadata.max_query_len,
-                max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
-                softmax_scale=self.scale,
-                causal=False,  # Context is unmasked
-                return_softmax_lse=True,
-            )
+            if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+                attn_output, attn_softmax_lse = self.triton_fa_func(
+                    q,
+                    k,
+                    v_padded,
+                    None,
+                    prefill_metadata.query_start_loc,
+                    prefill_metadata.context_chunk_cu_seq_lens[i],
+                    prefill_metadata.max_query_len,
+                    prefill_metadata.context_chunk_max_seq_lens[i],
+                    False,  # causal
+                    self.scale,
+                    None,  # attn_mask is None unless applying ALiBi mask
+                )
+            elif is_vllm_fa:
+                attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
+                    q=q,
+                    k=k,
+                    v=v_padded,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_softmax_lse=True,
+                )
+            else:
+                attn_output, attn_softmax_lse, _ = self.flash_attn_varlen_func(
+                    q=q,
+                    k=k,
+                    v=v_padded,
+                    cu_seqlens_q=prefill_metadata.query_start_loc,
+                    cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                    max_seqlen_q=prefill_metadata.max_query_len,
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
+                    softmax_scale=self.scale,
+                    causal=False,  # Context is unmasked
+                    return_attn_probs=True,
+                )

            if output is None:
                output = attn_output
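Every branch in this hunk requests the softmax log-sum-exp alongside the output because the cached context is attended to chunk by chunk: each chunk's partial result has to be merged with the others (and later with the suffix attention) by reweighting with the LSE. vLLM performs the merge with a dedicated kernel; the sketch below is only a plain-PyTorch statement of the math, with tensor layouts assumed from flash-attn conventions:

```python
import torch

def merge_by_lse(out_a: torch.Tensor, lse_a: torch.Tensor,
                 out_b: torch.Tensor, lse_b: torch.Tensor) -> torch.Tensor:
    # Assumed layouts: out_*: [num_tokens, num_heads, head_dim],
    # lse_*: [num_heads, num_tokens] (flash-attn style).
    lse_a = lse_a.transpose(0, 1).unsqueeze(-1)  # -> [num_tokens, num_heads, 1]
    lse_b = lse_b.transpose(0, 1).unsqueeze(-1)
    m = torch.maximum(lse_a, lse_b)              # subtract max for stability
    w_a = torch.exp(lse_a - m)
    w_b = torch.exp(lse_b - m)
    return (out_a * w_a + out_b * w_b) / (w_a + w_b)
```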
@@ -1374,11 +1411,24 @@ def _forward_prefill(
        v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
                                           value=0)

-        if has_context:
-            if not current_platform.is_cuda():
-                raise NotImplementedError(
-                    "Chunked Prefill for MLA is not currently supported on"
-                    "non-cuda platforms")
+        if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN:
+            output = self.triton_fa_func(
+                q,
+                k,
+                v_padded,
+                None,
+                prefill_metadata.query_start_loc,
+                prefill_metadata.query_start_loc,
+                prefill_metadata.max_prefill_seq_len,
+                prefill_metadata.max_prefill_seq_len,
+                True,  # causal
+                self.scale,
+                None,  # attn_mask is None unless applying ALiBi mask
+            )
+            # triton flash attention always returns 2 objects
+            if not has_context:
+                output = output[0]
+        elif is_vllm_fa:
            output = self.flash_attn_varlen_func(
                q=q,
                k=k,
@@ -1389,7 +1439,7 @@ def _forward_prefill(
                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
                softmax_scale=self.scale,
                causal=True,
-                return_softmax_lse=True,
+                return_softmax_lse=has_context,
            )
        else:
            output = self.flash_attn_varlen_func(
@@ -1402,10 +1452,12 @@ def _forward_prefill(
                max_seqlen_k=prefill_metadata.max_prefill_seq_len,
                softmax_scale=self.scale,
                causal=True,
+                return_attn_probs=has_context,
            )

        if has_context:
-            suffix_output, suffix_lse = output
+            # ROCm flash_attn_varlen_func will return 3 objects instead of 2
+            suffix_output, suffix_lse, *rest = output
            context_output, context_lse = self._compute_prefill_context( \
                q, kv_c_and_k_pe_cache, attn_metadata)
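The `*rest` unpacking is what reconciles the three call paths when `has_context` is set: the Triton kernel and vLLM flash-attn (with `return_softmax_lse=has_context`) hand back `(output, lse)`, while upstream ROCm flash-attn with `return_attn_probs=has_context` hands back a third element as well, and only the first two are needed. A small illustrative helper (not part of vLLM) making that explicit:

```python
def unpack_prefill_output(output):
    # Conventions seen in this diff:
    #   vLLM flash-attn, return_softmax_lse=True    -> (out, lse)
    #   Triton kernel used on ROCm                  -> (out, lse)
    #   upstream flash-attn, return_attn_probs=True -> (out, lse, attn_probs)
    out, lse, *_ = output  # mirrors `suffix_output, suffix_lse, *rest = output`
    return out, lse
```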