     from flash_attn import flash_attn_varlen_func
     is_vllm_fa = False

-    from vllm.attention.ops.triton_flash_attention import (triton_attention)
+    from vllm.attention.ops.triton_flash_attention import triton_attention

 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
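For orientation: the import adjusted above sits alongside the flash-attention fallback that sets `is_vllm_fa`. Below is a minimal sketch of that pattern, assuming the usual try/except guards; the actual guards and platform checks are outside this hunk and may differ.

# Hypothetical sketch of the surrounding import-fallback logic; the real
# guards in the vLLM source may differ from this.
try:
    # Prefer vLLM's bundled flash-attention fork when it is available.
    from vllm.vllm_flash_attn import flash_attn_varlen_func
    is_vllm_fa = True
except ImportError:
    # Otherwise fall back to the upstream flash-attn package.
    from flash_attn import flash_attn_varlen_func
    is_vllm_fa = False

try:
    # ROCm builds can use the Triton flash-attention kernel instead.
    from vllm.attention.ops.triton_flash_attention import triton_attention
except ImportError:
    triton_attention = None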
@@ -1330,9 +1330,9 @@ def _compute_prefill_context(
                     prefill_metadata.context_chunk_cu_seq_lens[i],
                     prefill_metadata.max_query_len,
                     prefill_metadata.context_chunk_max_seq_lens[i],
-                    False,  # causal
+                    False,  # causal
                     self.scale,
-                    None,  # attn_mask is None unless applying ALiBi mask
+                    None,  # attn_mask is None unless applying ALiBi mask
                 )
             elif is_vllm_fa:
                 attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
@@ -1342,7 +1342,8 @@ def _compute_prefill_context(
                     cu_seqlens_q=prefill_metadata.query_start_loc,
                     cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
                     max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
                     softmax_scale=self.scale,
                     causal=False,  # Context is unmasked
                     return_softmax_lse=True,
@@ -1355,7 +1356,8 @@ def _compute_prefill_context(
                     cu_seqlens_q=prefill_metadata.query_start_loc,
                     cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
                     max_seqlen_q=prefill_metadata.max_query_len,
-                    max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
+                    max_seqlen_k=prefill_metadata.
+                    context_chunk_max_seq_lens[i],
                     softmax_scale=self.scale,
                     causal=False,  # Context is unmasked
                     return_attn_probs=True,
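Both flash-attention branches above ask the kernel to return the per-row softmax log-sum-exp (`return_softmax_lse=True` on the vLLM fork, `return_attn_probs=True` on the upstream API). That value is needed because the cached context is attended to chunk by chunk, and the partial outputs have to be renormalized into a single result. Below is a hedged sketch of that merge, with illustrative names and shapes rather than the helper this file actually uses.

import torch

def merge_attn_outputs(out_a: torch.Tensor, lse_a: torch.Tensor,
                       out_b: torch.Tensor, lse_b: torch.Tensor):
    """Combine two partial softmax-attention results computed over
    disjoint key/value chunks, given each chunk's log-sum-exp.

    Illustrative shapes: out_* is [tokens, heads, head_dim],
    lse_* is [heads, tokens].
    """
    # New normalizer over both chunks.
    lse = torch.logaddexp(lse_a, lse_b)
    # Rescale each partial output by its share of the total softmax mass.
    w_a = torch.exp(lse_a - lse).transpose(0, 1).unsqueeze(-1)  # [tokens, heads, 1]
    w_b = torch.exp(lse_b - lse).transpose(0, 1).unsqueeze(-1)
    return w_a * out_a + w_b * out_b, lse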
@@ -1417,9 +1419,9 @@ def _forward_prefill(
                 prefill_metadata.query_start_loc,
                 prefill_metadata.max_prefill_seq_len,
                 prefill_metadata.max_prefill_seq_len,
-                True,  # causal
+                True,  # causal
                 self.scale,
-                None,  # attn_mask is None unless applying ALiBi mask
+                None,  # attn_mask is None unless applying ALiBi mask
             )
         ## triton flash attention always return 2 objects
         if not has_context:
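One more note on the `causal` flags above: attention over the already-cached context chunks runs with `causal=False`, since every new token may attend to all earlier context, while attention among the new prefill tokens in `_forward_prefill` uses `causal=True`. The small, purely illustrative mask construction below (not code from this file) makes the distinction concrete.

import torch

def prefill_attention_mask(num_new: int, num_context: int) -> torch.Tensor:
    """Boolean [num_new, num_context + num_new] mask: True where a new
    (prefill) query token is allowed to attend to a key token."""
    # Cached context: every new token sees all of it, hence causal=False
    # when the context chunks are attended to.
    context_part = torch.ones(num_new, num_context, dtype=torch.bool)
    # New tokens attending to each other: lower-triangular, hence
    # causal=True in the prefill self-attention call.
    new_part = torch.tril(torch.ones(num_new, num_new, dtype=torch.bool))
    return torch.cat([context_part, new_part], dim=1)

# Example: 3 new tokens with 4 cached context tokens.
print(prefill_attention_mask(3, 4).int())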