from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op

@@ -349,13 +350,11 @@ def __init__(
            f"divisible by num_kv_heads ({self.num_kv_heads})"
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype=None,
-                                        block_size=16,
-                                        is_attention_free=False)
-        backend = backend_name_to_enum(attn_backend.get_name())
+        # dtype = torch.get_default_dtype()
+
+        # Determine the attention backend
+        backend, use_upstream_fa = get_vit_attn_backend(head_size=head_size)
+
        if current_platform.is_rocm():
            # currently, only torch_sdpa is supported on rocm
            self.attn_backend = _Backend.TORCH_SDPA
@@ -375,6 +374,20 @@ def __init__(
                and not check_xformers_availability()):
            self.attn_backend = _Backend.TORCH_SDPA

+        if self.attn_backend in {
+                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
+        }:
+            if use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+
+        logger.info_once(
+            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
+            f"use_upstream_fa: {use_upstream_fa}")
+
    def forward(
        self,
        query: torch.Tensor,
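Note on the `__init__` change above: the old `forward` path (removed in the next hunk) chose between the upstream `flash_attn` package and vLLM's bundled `vllm_flash_attn` on every call, based on whether `head_size` was a multiple of 32. After this patch the choice is made once at construction time from the `use_upstream_fa` flag returned by `get_vit_attn_backend`, and the selected function is cached as `self._flash_attn_varlen_func`. A condensed, self-contained sketch of that binding logic follows; the helper name `resolve_flash_attn_varlen` is illustrative and not part of the diff:

from typing import Callable, Optional

from vllm.platforms import _Backend


def resolve_flash_attn_varlen(backend: _Backend,
                              use_upstream_fa: bool) -> Optional[Callable]:
    """Illustrative helper mirroring the binding logic added in __init__.

    Returns the flash_attn_varlen_func implementation to cache on the
    attention module, or None for backends that do not use flash-attention.
    """
    if backend not in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
        return None
    if use_upstream_fa:
        # Upstream flash-attn package (e.g. when the head size is not
        # supported by the bundled kernels).
        from flash_attn import flash_attn_varlen_func
    else:
        # vLLM's vendored flash-attention build.
        from vllm.vllm_flash_attn import flash_attn_varlen_func
    return flash_attn_varlen_func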
@@ -399,11 +412,6 @@ def forward(
                _Backend.FLASH_ATTN,
                _Backend.FLASH_ATTN_VLLM_V1,
        }:
-            if self.head_size % 32 != 0:
-                # import from upstream flash_attn
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func

            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                        step=q_len,
@@ -414,7 +422,7 @@ def forward(
                                        dtype=torch.int32,
                                        device=key.device)

-            out = flash_attn_varlen_func(
+            out = self._flash_attn_varlen_func(
                query.flatten(0, 1),
                key.flatten(0, 1),
                value.flatten(0, 1),
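For reference, here is how the bound function ends up being used in `forward`: the (batch, seq, heads, head_dim) tensors are flattened across the batch dimension, and cumulative sequence-length vectors built with `torch.arange` mark where each (equal-length) sequence starts in the flattened batch. A minimal sketch under the standard flash-attn varlen signature; the concrete shapes, the `max_seqlen_*` arguments, and the CUDA/fp16 setup are assumptions for illustration and are not part of the diff:

import torch
from flash_attn import flash_attn_varlen_func  # or vllm.vllm_flash_attn

# Illustrative sizes; the real values come from the incoming tensors.
bsz, q_len, kv_len, num_heads, head_size = 2, 16, 16, 8, 64
query = torch.randn(bsz, q_len, num_heads, head_size,
                    dtype=torch.float16, device="cuda")
key = torch.randn(bsz, kv_len, num_heads, head_size,
                  dtype=torch.float16, device="cuda")
value = torch.randn(bsz, kv_len, num_heads, head_size,
                    dtype=torch.float16, device="cuda")

# Sequence boundaries in the flattened batch:
# [0, q_len, 2*q_len, ..., bsz*q_len].
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len,
                            dtype=torch.int32, device=query.device)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len, step=kv_len,
                            dtype=torch.int32, device=key.device)

out = flash_attn_varlen_func(
    query.flatten(0, 1),  # (bsz * q_len, num_heads, head_size)
    key.flatten(0, 1),    # (bsz * kv_len, num_heads, head_size)
    value.flatten(0, 1),
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=cu_seqlens_k,
    max_seqlen_q=q_len,
    max_seqlen_k=kv_len,
)
# out has shape (bsz * q_len, num_heads, head_size).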