
Commit 8f7bced

more robust check

Signed-off-by: wwl2755 <[email protected]>

1 parent 2f0a859 commit 8f7bced

File tree

1 file changed: +19 additions, -16 deletions


vllm/platforms/cuda.py (19 additions, 16 deletions)
```diff
@@ -209,23 +209,26 @@ def get_vit_attn_backend(cls, head_size: int,
             return _Backend.XFORMERS, False
 
         if cls.has_device_capability(80):
-            if head_size % 32 == 0:
-                # Use vllm-flash-attn
+            FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501
+            from vllm.attention.selector import is_attn_backend_supported
+            is_default_fa_supported = is_attn_backend_supported(
+                FLASH_ATTN_V1, head_size, dtype, allow_import_error=False)
+            from transformers.utils import is_flash_attn_2_available
+            is_upstream_fa_supported = is_flash_attn_2_available()
+            if is_default_fa_supported:
                 return _Backend.FLASH_ATTN, False
-            if head_size % 32 != 0:
-                from transformers.utils import is_flash_attn_2_available
-                if is_flash_attn_2_available():
-                    # Use upstream FA
-                    return _Backend.FLASH_ATTN, True
-                else:
-                    # Fallback to XFORMERS
-                    logger.warning_once(
-                        "Using xformers for ViT attention backend. "
-                        "To use flash attention for ViT"
-                        "please install flash_attn")
-                    return _Backend.XFORMERS, False
-        # Fallback for Volta/Turing GPUs or FA not supported
-        return _Backend.XFORMERS, False
+            elif is_upstream_fa_supported:
+                return _Backend.FLASH_ATTN, True
+            else:
+                # Fallback to XFORMERS
+                logger.warning_once(
+                    "Using xformers for ViT attention backend. "
+                    "To use flash attention for ViT, "
+                    "please install flash_attn")
+                return _Backend.XFORMERS, False
+        else:
+            # Fallback for Volta/Turing GPUs or FA not supported
+            return _Backend.XFORMERS, False
 
     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
```
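
What changed: rather than inferring vllm-flash-attn support from `head_size % 32`, the selector now probes the backend directly and falls back to upstream flash_attn, then xformers. A minimal sketch of the probe in isolation (the call mirrors the one added above, except the head size and dtype, which are illustrative values, not from the commit):

```python
import torch

from vllm.attention.selector import is_attn_backend_supported

FLASH_ATTN_V1 = "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"  # noqa: E501

# Probe whether the default vLLM flash-attn backend supports this
# head size / dtype combination; head_size=72 is an example of a size
# the old `% 32` heuristic would have rejected without checking.
is_default_fa_supported = is_attn_backend_supported(
    FLASH_ATTN_V1, 72, torch.float16, allow_import_error=False)
print("default FA supported:", is_default_fa_supported)
```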

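For context, a hedged sketch of the caller side (`CudaPlatform` is assumed here to be the platform class defined in vllm/platforms/cuda.py; the second element of the returned tuple flags whether upstream flash_attn should be used instead of vllm-flash-attn):

```python
import torch

from vllm.platforms.cuda import CudaPlatform  # assumed class name

# get_vit_attn_backend returns (backend, use_upstream_fa): the attention
# backend the ViT should use, plus whether to import upstream flash_attn
# instead of vllm-flash-attn.
backend, use_upstream_fa = CudaPlatform.get_vit_attn_backend(
    head_size=72, dtype=torch.float16)
print(backend, use_upstream_fa)
```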