@@ -33,19 +33,38 @@ def test_mha_attn_platform(device: str):
     torch.set_default_dtype(torch.float16)
 
     if device == "cpu":
-        with patch("vllm.attention.selector.current_platform", CpuPlatform()):
+        with patch("vllm.model_executor.models.vision.current_platform",
+                   CpuPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     elif device == "hip":
-        with patch("vllm.attention.selector.current_platform", RocmPlatform()):
+        with patch("vllm.model_executor.models.vision.current_platform",
+                   RocmPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
             assert attn.attn_backend == _Backend.TORCH_SDPA
     else:
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+        # Test CUDA with head_size=64 (divisible by 32)
+        # - should use vLLM FlashAttention
+        with patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()):
             attn = MultiHeadAttention(16, 64, scale=1)
-            assert attn.attn_backend == _Backend.XFORMERS
-
-        with patch("vllm.attention.selector.current_platform", CudaPlatform()):
+            assert attn.attn_backend == _Backend.FLASH_ATTN
+
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - upstream FA available
+        with patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()), \
+             patch("transformers.utils.is_flash_attn_2_available",
+                   return_value=True):
+            attn = MultiHeadAttention(16, 72, scale=1)
+            assert attn.attn_backend == _Backend.FLASH_ATTN
+
+        # Test CUDA with head_size=72 (not divisible by 32)
+        # - upstream FA not available
+        with patch("vllm.model_executor.models.vision.current_platform",
+                   CudaPlatform()), \
+             patch("transformers.utils.is_flash_attn_2_available",
+                   return_value=False):
             attn = MultiHeadAttention(16, 72, scale=1)
             assert attn.attn_backend == _Backend.XFORMERS
 
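For context, the dispatch order these assertions exercise can be summarized as a small decision function. The sketch below is illustrative only: `select_vision_attn_backend` and its parameters are made-up names assuming the priority order implied by the test (vLLM's FlashAttention for head sizes divisible by 32, upstream flash-attn as a fallback, then xFormers; non-CUDA platforms use torch SDPA), not vLLM's actual API.

```python
# Illustrative sketch only: this helper and its signature are invented to
# summarize the backend selection the test asserts; the real logic lives in
# vllm.model_executor.models.vision and returns _Backend enum members
# rather than strings.
def select_vision_attn_backend(head_size: int,
                               is_cuda: bool,
                               upstream_fa_available: bool) -> str:
    if not is_cuda:
        # CPU and ROCm fall back to PyTorch scaled_dot_product_attention.
        return "TORCH_SDPA"
    if head_size % 32 == 0:
        # vLLM's own FlashAttention kernel handles head sizes divisible by 32.
        return "FLASH_ATTN"
    if upstream_fa_available:
        # Otherwise defer to upstream flash-attn when transformers reports
        # it as installed (is_flash_attn_2_available).
        return "FLASH_ATTN"
    # Last resort on CUDA: xFormers.
    return "XFORMERS"

# Mirrors the four cases asserted in the test above.
assert select_vision_attn_backend(64, is_cuda=False,
                                  upstream_fa_available=False) == "TORCH_SDPA"
assert select_vision_attn_backend(64, is_cuda=True,
                                  upstream_fa_available=False) == "FLASH_ATTN"
assert select_vision_attn_backend(72, is_cuda=True,
                                  upstream_fa_available=True) == "FLASH_ATTN"
assert select_vision_attn_backend(72, is_cuda=True,
                                  upstream_fa_available=False) == "XFORMERS"
```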