
Commit 2f0a859

add dtype checking and fix tests

Signed-off-by: wwl2755 <[email protected]>
1 parent 90ffa1a · commit 2f0a859

12 files changed: +60 -23 lines changed

tests/kernels/attention/test_mha_attn.py

Lines changed: 25 additions & 6 deletions
@@ -33,19 +33,38 @@ def test_mha_attn_platform(device: str):
     torch.set_default_dtype(torch.float16)
 
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+        with patch("vllm.model_executor.models.vision.current_platform",
+                   CpuPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform", RocmPlatform()):
+        with patch("vllm.model_executor.models.vision.current_platform",
+                   RocmPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        # Test CUDA with head_size=64 (divisible by 32)
+        # - should use vLLM FlashAttention
+        with patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.XFORMERS
-
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+            assert attn.attn_backend == _Backend.FLASH_ATTN
+
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - upstream FA available
+        with patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()), \
+                patch("transformers.utils.is_flash_attn_2_available",
+                      return_value=True):
+            attn = MultiHeadAttention(16, 72, scale=1)
+            assert attn.attn_backend == _Backend.FLASH_ATTN
+
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - upstream FA not available
+        with patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()), \
+                patch("transformers.utils.is_flash_attn_2_available",
+                      return_value=False):
             attn = MultiHeadAttention(16, 72, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS
 
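Not part of the commit, but a minimal sketch of the behavior the new dtype check implies on CUDA: with a non-half default dtype, backend selection should fall back to xFormers even for a FlashAttention-friendly head size. The patch target mirrors the test above; the import locations of MultiHeadAttention, CudaPlatform, and _Backend are assumptions about the surrounding vLLM tree.

from unittest.mock import patch

import torch

from vllm.attention.layer import MultiHeadAttention
from vllm.platforms.cuda import CudaPlatform
from vllm.platforms.interface import _Backend  # assumed import location


def check_cuda_fp32_falls_back_to_xformers():
    # float32 is outside (float16, bfloat16), so per the dtype check added in
    # vllm/platforms/cuda.py below, CUDA selection should report XFORMERS
    # regardless of head size.
    torch.set_default_dtype(torch.float32)
    with patch("vllm.model_executor.models.vision.current_platform",
               CudaPlatform()):
        attn = MultiHeadAttention(16, 64, scale=1)
        assert attn.attn_backend == _Backend.XFORMERS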

vllm/attention/layer.py

Lines changed: 5 additions & 2 deletions
@@ -350,10 +350,13 @@ def __init__(
             f"divisible by num_kv_heads ({self.num_kv_heads})"
         self.num_queries_per_kv = self.num_heads // self.num_kv_heads
 
-        # dtype = torch.get_default_dtype()
+        # During model initialization, the default dtype is set as the model
+        # weight and activation dtype.
+        dtype = torch.get_default_dtype()
 
         # Determine the attention backend
-        backend, use_upstream_fa = get_vit_attn_backend(head_size=head_size)
+        backend, use_upstream_fa = get_vit_attn_backend(head_size=head_size,
+                                                        dtype=dtype)
 
         if current_platform.is_rocm():
             # currently, only torch_sdpa is supported on rocm
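This hunk relies on vLLM setting the process-wide default dtype to the model's weight and activation dtype before layers are constructed. A minimal usage sketch (not from the diff), assuming a standard vLLM install; the explicit set_default_dtype call stands in for what model initialization normally does:

import torch

from vllm.attention.layer import MultiHeadAttention

# Stand-in for vLLM model initialization, which sets the default dtype to the
# model dtype before building layers.
torch.set_default_dtype(torch.bfloat16)

# The layer now reads torch.get_default_dtype() itself and forwards it to
# get_vit_attn_backend(); callers do not pass a dtype explicitly.
attn = MultiHeadAttention(16, 72, scale=1)
print(attn.attn_backend)  # depends on platform, head size, and dtype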

vllm/model_executor/models/ernie45_vl.py

Lines changed: 4 additions & 2 deletions
@@ -173,7 +173,8 @@ def __init__(
 
         # Detect attention implementation.
         self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
-            head_size=self.hidden_size_per_attention_head)
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA
@@ -463,7 +464,8 @@ def __init__(
         ), "vit's config.hidden must be equal to config.embed_dim"
         self.ln = nn.LayerNorm(hidden_size, eps=1e-6)
 
-        self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim)
+        self.attn_backend, _ = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
 
     @property
     def dtype(self) -> torch.dtype:

vllm/model_executor/models/glm4_1v.py

Lines changed: 4 additions & 2 deletions
@@ -261,7 +261,8 @@ def __init__(
 
         # Detect attention implementation.
         self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
-            head_size=self.hidden_size_per_attention_head)
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN,
                 _Backend.TORCH_SDPA,
@@ -732,7 +733,8 @@ def __init__(
         self.post_layernorm = RMSNorm(vision_config.hidden_size,
                                       eps=vision_config.rms_norm_eps)
 
-        self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim)
+        self.attn_backend, _ = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
 
     @property
     def dtype(self) -> torch.dtype:

vllm/model_executor/models/keye.py

Lines changed: 1 addition & 1 deletion
@@ -375,7 +375,7 @@ def __init__(
 
         # Detect attention implementation.
         self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
-            head_size=self.head_dim)
+            head_size=self.head_dim, dtype=torch.get_default_dtype())
         if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
             raise RuntimeError(
                 f"Keye-VL does not support {self.attn_backend} backend now.")

vllm/model_executor/models/qwen2_5_vl.py

Lines changed: 4 additions & 2 deletions
@@ -300,7 +300,8 @@ def __init__(
 
         # Detect attention implementation.
         self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
-            head_size=self.hidden_size_per_attention_head)
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA
@@ -633,7 +634,8 @@ def __init__(
             prefix=f"{prefix}.merger",
             use_data_parallel=use_data_parallel,
         )
-        self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim)
+        self.attn_backend, _ = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
 
     @property
     def dtype(self) -> torch.dtype:

vllm/model_executor/models/qwen2_vl.py

Lines changed: 4 additions & 2 deletions
@@ -315,7 +315,8 @@ def __init__(
 
         # Detect attention implementation.
         self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
-            head_size=self.hidden_size_per_attention_head)
+            head_size=self.hidden_size_per_attention_head,
+            dtype=torch.get_default_dtype())
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
                 _Backend.ROCM_AITER_FA
@@ -632,7 +633,8 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.merger",
         )
-        self.attn_backend, _ = get_vit_attn_backend(head_size=head_dim)
+        self.attn_backend, _ = get_vit_attn_backend(
+            head_size=head_dim, dtype=torch.get_default_dtype())
 
     @property
     def dtype(self) -> torch.dtype:

vllm/model_executor/models/siglip2navit.py

Lines changed: 1 addition & 1 deletion
@@ -237,7 +237,7 @@ def __init__(
 
         # Detect attention implementation.
         self.attn_backend, self.use_upstream_fa = get_vit_attn_backend(
-            head_size=self.head_dim)
+            head_size=self.head_dim, dtype=torch.get_default_dtype())
         if self.attn_backend not in {
                 _Backend.FLASH_ATTN, _Backend.TORCH_SDPA,
                 _Backend.ROCM_AITER_FA

vllm/model_executor/models/vision.py

Lines changed: 3 additions & 2 deletions
@@ -68,7 +68,8 @@ def get_vision_encoder_info(
     raise NotImplementedError(msg)
 
 
-def get_vit_attn_backend(head_size: int) -> tuple[_Backend, bool]:
+def get_vit_attn_backend(head_size: int,
+                         dtype: torch.dtype) -> tuple[_Backend, bool]:
     """
     Get the available attention backend for Vision Transformer.
     """
@@ -79,7 +80,7 @@ def get_vit_attn_backend(head_size: int) -> tuple[_Backend, bool]:
     if selected_backend is not None:
         return selected_backend, False
 
-    return current_platform.get_vit_attn_backend(head_size)
+    return current_platform.get_vit_attn_backend(head_size, dtype)
 
 
 def resolve_visual_encoder_outputs(
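A call sketch for the updated helper; head_size=64 and torch.float16 are arbitrary example values, and the tuple layout follows the call sites above (a backend plus a use_upstream_fa flag):

import torch

from vllm.model_executor.models.vision import get_vit_attn_backend

# Callers now pass the dtype explicitly; the model files above forward
# torch.get_default_dtype(). The boolean indicates whether upstream flash-attn
# should be used instead of vllm-flash-attn.
backend, use_upstream_fa = get_vit_attn_backend(head_size=64,
                                                dtype=torch.float16)
print(backend, use_upstream_fa)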

vllm/platforms/cuda.py

Lines changed: 5 additions & 1 deletion
@@ -203,7 +203,11 @@ def get_current_memory_usage(cls,
         return torch.cuda.max_memory_allocated(device)
 
     @classmethod
-    def get_vit_attn_backend(cls, head_size: int) -> tuple[_Backend, bool]:
+    def get_vit_attn_backend(cls, head_size: int,
+                             dtype: torch.dtype) -> tuple[_Backend, bool]:
+        if dtype not in (torch.float16, torch.bfloat16):
+            return _Backend.XFORMERS, False
+
         if cls.has_device_capability(80):
             if head_size % 32 == 0:
                 # Use vllm-flash-attn
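For reference, a standalone restatement of the CUDA selection order implied by this hunk and by the test expectations above. Only the first branches are visible in this commit, so how the real method probes upstream flash-attn availability is an assumption here:

import torch


def sketch_cuda_vit_backend(head_size: int, dtype: torch.dtype,
                            has_sm80: bool,
                            upstream_fa_available: bool) -> tuple[str, bool]:
    # New in this commit: FlashAttention paths require fp16/bf16.
    if dtype not in (torch.float16, torch.bfloat16):
        return "XFORMERS", False
    if has_sm80:
        if head_size % 32 == 0:
            # vllm-flash-attn handles head sizes divisible by 32 directly.
            return "FLASH_ATTN", False
        if upstream_fa_available:
            # Otherwise fall back to upstream flash-attn (use_upstream_fa=True).
            return "FLASH_ATTN", True
    return "XFORMERS", False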
