@@ -209,23 +209,26 @@ def get_vit_attn_backend(cls, head_size: int,
             return _Backend.XFORMERS, False

         if cls.has_device_capability(80):
-            if head_size % 32 == 0:
-                # Use vllm-flash-attn
+            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+            from vllm.attention.selector import is_attn_backend_supported
+            is_default_fa_supported = is_attn_backend_supported(
+                FLASH_ATTN_V1, head_size, dtype, allow_import_error=False)
+            from transformers.utils import is_flash_attn_2_available
+            is_upstream_fa_supported = is_flash_attn_2_available()
+            if is_default_fa_supported:
                 return _Backend.FLASH_ATTN, False
-            if head_size % 32 != 0:
-                from transformers.utils import is_flash_attn_2_available
-                if is_flash_attn_2_available():
-                    # Use upstream FA
-                    return _Backend.FLASH_ATTN, True
-                else:
-                    # Fallback to XFORMERS
-                    logger.warning_once(
-                        "Using xformers for ViT attention backend. "
-                        "To use flash attention for ViT"
-                        "please install flash_attn")
-                    return _Backend.XFORMERS, False
-        # Fallback for Volta/Turing GPUs or FA not supported
-        return _Backend.XFORMERS, False
+            elif is_upstream_fa_supported:
+                return _Backend.FLASH_ATTN, True
+            else:
+                # Fallback to XFORMERS
+                logger.warning_once(
+                    "Using xformers for ViT attention backend. "
+                    "To use flash attention for ViT, "
+                    "please install flash_attn")
+                return _Backend.XFORMERS, False
+        else:
+            # Fallback for Volta/Turing GPUs or FA not supported
+            return _Backend.XFORMERS, False

     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
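
For readers who want the new selection ladder in isolation, the sketch below restates the decision logic of this hunk as a standalone function; it is not the vLLM implementation. The function name pick_vit_attn_backend, its three boolean parameters, the local _Backend enum and the plain logging logger are placeholders invented for illustration. In the real code the three conditions come from cls.has_device_capability(80), is_attn_backend_supported(FLASH_ATTN_V1, head_size, dtype, allow_import_error=False) from vllm.attention.selector, and is_flash_attn_2_available() from transformers.utils.

import logging
from enum import Enum, auto

logger = logging.getLogger(__name__)


class _Backend(Enum):
    # Placeholder for vLLM's _Backend enum; only the two members used here.
    FLASH_ATTN = auto()
    XFORMERS = auto()


def pick_vit_attn_backend(has_sm80: bool,
                          default_fa_supported: bool,
                          upstream_fa_available: bool) -> tuple[_Backend, bool]:
    """Return (backend, use_upstream_fa), mirroring the ladder in the hunk."""
    if has_sm80:
        if default_fa_supported:
            # vLLM's own flash-attn backend accepts this head_size/dtype.
            return _Backend.FLASH_ATTN, False
        elif upstream_fa_available:
            # Default backend rejected it, but upstream flash_attn is installed.
            return _Backend.FLASH_ATTN, True
        else:
            # Neither flash-attn path is usable: warn and fall back.
            logger.warning("Using xformers for ViT attention backend. "
                           "To use flash attention for ViT, "
                           "please install flash_attn")
            return _Backend.XFORMERS, False
    else:
        # Pre-Ampere GPUs (Volta/Turing): flash attention is unavailable.
        return _Backend.XFORMERS, False

For example, pick_vit_attn_backend(True, False, True) yields (_Backend.FLASH_ATTN, True): on an SM80+ GPU whose head size/dtype the default vLLM flash-attn backend rejects, the change keeps flash attention via upstream flash_attn (the second element is the use_upstream_fa flag returned to the caller) instead of relying on the old head_size % 32 heuristic.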