
Commit 90ffa1a

refactor get_vit_attention_backend()

Signed-off-by: wwl2755 <[email protected]>
1 parent cec027c · commit 90ffa1a

11 files changed: +96 −57 lines

vllm/attention/layer.py — 21 additions, 13 deletions

@@ -23,6 +23,7 @@
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import direct_register_custom_op

@@ -349,13 +350,11 @@ def __init__(
             f"divisible by num_kv_heads ({self.num_kv_heads})"
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads

-        dtype = torch.get_default_dtype()
-        attn_backend = get_attn_backend(head_size,
-                                        dtype,
-                                        kv_cache_dtype=None,
-                                        block_size=16,
-                                        is_attention_free=False)
-        backend = backend_name_to_enum(attn_backend.get_name())
+        # dtype = torch.get_default_dtype()
+
+        # Determine the attention backend
+        backend, use_upstream_fa = get_vit_attn_backend(head_size=head_size)
+
         if current_platform.is_rocm():
             # currently, only torch_sdpa is supported on rocm
             self.attn_backend = _Backend.TORCH_SDPA

@@ -375,6 +374,20 @@ def __init__(
                 and not check_xformers_availability()):
             self.attn_backend = _Backend.TORCH_SDPA

+        if self.attn_backend in {
+                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
+        }:
+            if use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func
+                self._flash_attn_varlen_func = flash_attn_varlen_func
+
+        logger.info_once(
+            f"MultiHeadAttention attn_backend: {self.attn_backend}, "
+            f"use_upstream_fa: {use_upstream_fa}")
+
     def forward(
         self,
         query: torch.Tensor,

@@ -399,11 +412,6 @@ def forward(
                 _Backend.FLASH_ATTN,
                 _Backend.FLASH_ATTN_VLLM_V1,
         }:
-            if self.head_size % 32 != 0:
-                # import from upstream flash_attn
-                from flash_attn import flash_attn_varlen_func
-            else:
-                from vllm.vllm_flash_attn import flash_attn_varlen_func

             cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
                                         step=q_len,

@@ -414,7 +422,7 @@ def forward(
                                         dtype=torch.int32,
                                         device=key.device)

-            out = flash_attn_varlen_func(
+            out = self._flash_attn_varlen_func(
                 query.flatten(0, 1),
                 key.flatten(0, 1),
                 value.flatten(0, 1),
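The net effect in MultiHeadAttention is that the varlen flash-attention entry point is resolved once at construction time instead of being re-imported on every forward() call, with the old head_size % 32 check now living inside get_vit_attn_backend(). A minimal standalone sketch of that resolution pattern follows; the helper name resolve_flash_attn_fn is illustrative and not part of this commit:

def resolve_flash_attn_fn(use_upstream_fa: bool):
    # Bind the varlen flash-attention kernel once; callers reuse the returned
    # function instead of importing inside the hot path.
    if use_upstream_fa:
        from flash_attn import flash_attn_varlen_func  # upstream flash-attn package
    else:
        from vllm.vllm_flash_attn import flash_attn_varlen_func  # vLLM's bundled FA
    return flash_attn_varlen_func

This mirrors what the diff stores in self._flash_attn_varlen_func.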

vllm/model_executor/models/ernie45_vl.py — 7 additions, 3 deletions

@@ -172,7 +172,8 @@ def __init__(
                                    prefix=f"{prefix}.proj")

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head)
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA

@@ -235,7 +236,10 @@ def forward(
             if self.attn_backend == _Backend.ROCM_AITER_FA:
                 from aiter import flash_attn_varlen_func
             else:
-                from flash_attn import flash_attn_varlen_func
+                if self.use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -459,7 +463,7 @@ def __init__(
         ), "vit's config.hidden must be equal to config.embed_dim"
         self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim)

     @property
     def dtype(self) -> torch.dtype:
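The same two-step pattern recurs in the other vision encoders below (glm4_1v, keye, qwen2_5_vl, qwen2_vl, siglip2navit): unpack the (backend, use_upstream_fa) tuple in __init__, then branch on use_upstream_fa when importing the kernel in forward(). A hedged, self-contained sketch of that caller pattern, with an illustrative class name and simplified arguments:

import torch

from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend


class ToyVisionAttention(torch.nn.Module):
    """Illustrative only: shows how a ViT attention module consumes the new API."""

    def __init__(self, head_dim: int):
        super().__init__()
        # New signature: head size in, (backend, use_upstream_fa) out.
        self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
            head_size=head_dim)

    def _attn_fn(self):
        # Same branching order as the model diffs in this commit.
        if self.attn_backend == _Backend.ROCM_AITER_FA:
            from aiter import flash_attn_varlen_func
        elif self.use_upstream_fa:
            from flash_attn import flash_attn_varlen_func
        else:
            from vllm.vllm_flash_attn import flash_attn_varlen_func
        return flash_attn_varlen_func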

vllm/model_executor/models/glm4_1v.py — 7 additions, 3 deletions

@@ -260,7 +260,8 @@ def __init__(
         )

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head)
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN,
                 _Backend.TORCH_SDPA,

@@ -323,7 +324,10 @@ def forward(
             if self.attn_backend == _Backend.FLASH_ATTN:
                 # from vllm_flash_attn.flash_attn_interface import (
                 # flash_attn_varlen_func)
-                from flash_attn import flash_attn_varlen_func
+                if self.use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -728,7 +732,7 @@ def __init__(
         self.post_layernorm = RMSNorm(vision_config.hidden_size,
                                       eps=vision_config.rms_norm_eps)

-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim)

     @property
     def dtype(self) -> torch.dtype:

vllm/model_executor/models/keye.py — 6 additions, 2 deletions

@@ -374,7 +374,8 @@ def __init__(
         )

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
+            head_size=self.head_dim)
         if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
             raise RuntimeError(
                 f"Keye-VL does not support {self.attn_backend} backend now.")

@@ -428,7 +429,10 @@ def forward(
         )

         if self.attn_backend == _Backend.FLASH_ATTN:
-            from flash_attn import flash_attn_varlen_func
+            if self.use_upstream_fa:
+                from flash_attn import flash_attn_varlen_func
+            else:
+                from vllm.vllm_flash_attn import flash_attn_varlen_func

            q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

vllm/model_executor/models/qwen2_5_vl.py — 7 additions, 3 deletions

@@ -299,7 +299,8 @@ def __init__(
                                            disable_tp=use_data_parallel)

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head)
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA

@@ -360,7 +361,10 @@ def forward(
             if self.attn_backend == _Backend.ROCM_AITER_FA:
                 from aiter import flash_attn_varlen_func
             else:
-                from flash_attn import flash_attn_varlen_func
+                if self.use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -629,7 +633,7 @@ def __init__(
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim)

     @property
     def dtype(self) -> torch.dtype:

vllm/model_executor/models/qwen2_vl.py — 7 additions, 3 deletions

@@ -314,7 +314,8 @@ def __init__(
                                        prefix=f"{prefix}.proj")

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
+            head_size=self.hidden_size_per_attention_head)
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA

@@ -374,7 +375,10 @@ def forward(
             if self.attn_backend == _Backend.ROCM_AITER_FA:
                 from aiter import flash_attn_varlen_func
             else:
-                from flash_attn import flash_attn_varlen_func
+                if self.use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func

             q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -628,7 +632,7 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
         )
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim)

     @property
     def dtype(self) -> torch.dtype:

vllm/model_executor/models/siglip2navit.py — 6 additions, 2 deletions

@@ -236,7 +236,8 @@ def __init__(
         self.use_rope = config.use_rope

         # Detect attention implementation.
-        self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
+        self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
+            head_size=self.head_dim)
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
                 _Backend.ROCM_AITER_FA

@@ -280,7 +281,10 @@ def forward(
             if self.attn_backend == _Backend.ROCM_AITER_FA:
                 from aiter import flash_attn_varlen_func
             else:
-                from flash_attn import flash_attn_varlen_func
+                if self.use_upstream_fa:
+                    from flash_attn import flash_attn_varlen_func
+                else:
+                    from vllm.vllm_flash_attn import flash_attn_varlen_func
             attn_output = flash_attn_varlen_func(
                 queries, keys, values, cu_seqlens, cu_seqlens, max_seqlen,
                 max_seqlen).reshape(seq_length, -1)

vllm/model_executor/models/vision.py — 6 additions, 5 deletions

@@ -68,17 +68,18 @@ def get_vision_encoder_info(
     raise NotImplementedError(msg)


-def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
+def get_vit_attn_backend(head_size: int) -> tuple[_Backend, bool]:
     """
     Get the available attention backend for Vision Transformer.
+
+    Returns:
+        Tuple of (backend, use_upstream_fa)
     """
-    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
-
     selected_backend: Optional[_Backend] = get_env_variable_attn_backend()
     if selected_backend is not None:
-        return selected_backend
+        return selected_backend, False

-    return current_platform.get_vit_attn_backend(support_fa)
+    return current_platform.get_vit_attn_backend(head_size)


 def resolve_visual_encoder_outputs(
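After this change the helper takes the head size and returns a tuple rather than a bare backend enum; an explicit environment override always reports use_upstream_fa=False. A rough usage sketch, where the concrete return values depend on the platform and on which flash-attention packages are installed:

from vllm.model_executor.models.vision import get_vit_attn_backend

backend, use_upstream_fa = get_vit_attn_backend(head_size=64)
# e.g. (_Backend.FLASH_ATTN, False) on an Ampere GPU, since 64 % 32 == 0
backend, use_upstream_fa = get_vit_attn_backend(head_size=72)
# e.g. (_Backend.FLASH_ATTN, True) if upstream flash-attn is installed,
# otherwise (_Backend.XFORMERS, False)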

vllm/platforms/cuda.py — 18 additions, 11 deletions

@@ -203,18 +203,25 @@ def get_current_memory_usage(cls,
         return torch.cuda.max_memory_allocated(device)

     @classmethod
-    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
-        if cls.has_device_capability(80) and support_fa:
-            from transformers.utils import is_flash_attn_2_available
-            if is_flash_attn_2_available():
-                return _Backend.FLASH_ATTN
-            logger.warning_once(
-                "Current `vllm-flash-attn` has a bug inside vision "
-                "module, so we use xformers backend instead. You can "
-                "run `pip install flash-attn` to use flash-attention "
-                "backend.")
+    def get_vit_attn_backend(cls, head_size: int) -> tuple[_Backend, bool]:
+        if cls.has_device_capability(80):
+            if head_size % 32 == 0:
+                # Use vllm-flash-attn
+                return _Backend.FLASH_ATTN, False
+            if head_size % 32 != 0:
+                from transformers.utils import is_flash_attn_2_available
+                if is_flash_attn_2_available():
+                    # Use upstream FA
+                    return _Backend.FLASH_ATTN, True
+                else:
+                    # Fallback to XFORMERS
+                    logger.warning_once(
+                        "Using xformers for ViT attention backend. "
+                        "To use flash attention for ViT"
+                        "please install flash_attn")
+                    return _Backend.XFORMERS, False
         # Fallback for Volta/Turing GPUs or FA not supported
-        return _Backend.XFORMERS
+        return _Backend.XFORMERS, False

     @classmethod
     def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
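The CUDA selection logic now keys off head_size rather than a support_fa flag. The decision tree, restated as a small standalone function (sm80 stands in for cls.has_device_capability(80) and fa2_available for transformers' is_flash_attn_2_available(); this mirror is illustrative, not the vLLM code):

def pick_vit_backend(sm80: bool, head_size: int, fa2_available: bool):
    """Mirror of the new CUDA ViT backend selection (illustrative)."""
    if sm80:
        if head_size % 32 == 0:
            return "FLASH_ATTN", False   # vllm-flash-attn handles this head size
        if fa2_available:
            return "FLASH_ATTN", True    # fall back to upstream flash-attn
        return "XFORMERS", False         # no usable flash-attn build
    return "XFORMERS", False             # pre-Ampere GPUs

print(pick_vit_backend(True, 64, False))   # ('FLASH_ATTN', False)
print(pick_vit_backend(True, 72, True))    # ('FLASH_ATTN', True)
print(pick_vit_backend(False, 64, True))   # ('XFORMERS', False)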

vllm/platforms/interface.py — 2 additions, 2 deletions

@@ -190,8 +190,8 @@ def device_id_to_physical_device_id(cls, device_id: int):
         return device_id

     @classmethod
-    def get_vit_attn_backend(cls, support_fa: bool = False) -> _Backend:
-        return _Backend.TORCH_SDPA
+    def get_vit_attn_backend(cls, head_size: int) -> tuple[_Backend, bool]:
+        return _Backend.TORCH_SDPA, False

     @classmethod
     def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
