     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
+from vllm.envs import VLLM_FLASH_ATTN_VERSION
 from vllm.multimodal import MultiModalPlaceholderMap
+from vllm.platforms import current_platform
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad

 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
                                           ModelInputForGPUWithSamplingMetadata)

 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
-                                  flash_attn_with_kvcache)
+                                  flash_attn_with_kvcache,
+                                  is_fa_version_supported)


 class FlashAttentionBackend(AttentionBackend):
@@ -634,6 +637,20 @@ def __init__(
             f"Supported head sizes are: {support_head_sizes}.")
         self.attn_type = attn_type

+        # if hopper default to FA3, otherwise stick to FA2 for now
+        # TODO(lucas): profile FA3 on ampere to see if it makes sense to
+        #  use FA3 as default for both
+        if current_platform.get_device_capability()[0] >= 9:
+            self.fa_version = 3 if is_fa_version_supported(3) else 2
+        else:
+            self.fa_version = 2
+
+        if VLLM_FLASH_ATTN_VERSION is not None:
+            assert VLLM_FLASH_ATTN_VERSION in [2, 3]
+            self.fa_version = VLLM_FLASH_ATTN_VERSION
+
+        assert is_fa_version_supported(self.fa_version)
+
     def forward(
         self,
         layer: AttentionLayer,
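The hunk above does two things: the device-capability check only picks a default (FA3 on Hopper when the build supports it, FA2 everywhere else), and VLLM_FLASH_ATTN_VERSION, when set, overrides that default but must still pass is_fa_version_supported. A minimal standalone sketch of the same decision flow, with device_major, is_supported, and env_override as hypothetical stand-ins for current_platform.get_device_capability()[0], is_fa_version_supported, and VLLM_FLASH_ATTN_VERSION (not part of the vLLM API):

def pick_fa_version(device_major, is_supported, env_override=None):
    """Sketch of the selection logic in __init__ above; illustrative only."""
    # Hopper (SM 9.x) defaults to FA3 when this build supports it;
    # everything else stays on FA2 for now.
    if device_major >= 9:
        version = 3 if is_supported(3) else 2
    else:
        version = 2

    # An explicit override beats the default, but must name a version
    # the build can actually run.
    if env_override is not None:
        assert env_override in (2, 3)
        version = env_override

    assert is_supported(version)
    return version

def fa3_build(v):
    # Pretend both FA2 and FA3 kernels are compiled in.
    return v in (2, 3)

print(pick_fa_version(9, fa3_build))      # 3  (Hopper default)
print(pick_fa_version(8, fa3_build))      # 2  (Ampere default)
print(pick_fa_version(9, fa3_build, 2))   # 2  (explicit override)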
@@ -752,6 +769,7 @@ def forward(
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=prefill_output,
+                    fa_version=self.fa_version,
                 )
             else:
                 # prefix-enabled attention
@@ -765,7 +783,7 @@ def forward(
                     v=value_cache,
                     cu_seqlens_q=prefill_meta.query_start_loc,
                     max_seqlen_q=prefill_meta.max_query_len,
-                    cu_seqlens_k=prefill_meta.seq_start_loc,
+                    seqused_k=prefill_meta.seq_lens_tensor,
                     max_seqlen_k=max_seq_len,
                     softmax_scale=softmax_scale,
                     causal=True,
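This call site (and the varlen decode call in the hunks below) switches from cu_seqlens_k, the [batch + 1] exclusive prefix sum built from seq_start_loc, to seqused_k, the raw per-sequence key counts in seq_lens_tensor. A small sketch of how the two tensors relate, using made-up lengths rather than real attention metadata:

import torch

# Per-sequence KV lengths for a batch of three requests (illustrative only).
seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32)

# seqused_k: one entry per sequence, shape [batch].
seqused_k = seq_lens

# cu_seqlens_k: exclusive prefix sums, shape [batch + 1].
cu_seqlens_k = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
cu_seqlens_k[1:] = torch.cumsum(seq_lens, dim=0)

print(seqused_k.tolist())     # [3, 5, 2]
print(cu_seqlens_k.tolist())  # [0, 3, 8, 10]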
@@ -774,6 +792,7 @@ def forward(
                     block_table=prefill_meta.block_tables,
                     softcap=logits_soft_cap,
                     out=prefill_output,
+                    fa_version=self.fa_version,
                 )

         if decode_meta := attn_metadata.decode_metadata:
@@ -793,7 +812,7 @@ def forward(
                     v=value_cache,
                     cu_seqlens_q=decode_meta.query_start_loc,
                     max_seqlen_q=decode_meta.max_decode_query_len,
-                    cu_seqlens_k=decode_meta.seq_start_loc,
+                    seqused_k=decode_meta.seq_lens_tensor,
                     max_seqlen_k=decode_meta.max_decode_seq_len,
                     softmax_scale=softmax_scale,
                     causal=True,
@@ -802,6 +821,7 @@ def forward(
                     softcap=logits_soft_cap,
                     block_table=decode_meta.block_tables,
                     out=decode_output,
+                    fa_version=self.fa_version,
                 )
             else:
                 # Use flash_attn_with_kvcache for normal decoding.
@@ -822,6 +842,7 @@ def forward(
                     alibi_slopes=alibi_slopes,
                     softcap=logits_soft_cap,
                     out=decode_output.unsqueeze(1),
+                    fa_version=self.fa_version,
                 )
         return output