4 changes: 2 additions & 2 deletions tests/entrypoints/openai/test_vision.py
@@ -34,11 +34,11 @@
],
[
"The image shows a Venn diagram with three over",
"The image shows a Venn diagram with three intersect",
"This image shows a Venn diagram with three over",
],
[
"This image displays a gradient of colors ranging from",
"The image displays a gradient of colors ranging from",
"This image displays a gradient of colors forming a spectrum",
],
]

49 changes: 35 additions & 14 deletions tests/kernels/attention/test_mha_attn.py
@@ -36,31 +36,52 @@ def test_mha_attn_platform(device: str):
torch.set_default_dtype(torch.float16)

if device == "cpu":
with patch("vllm.attention.selector.current_platform",
CpuPlatform()), \
patch("vllm.platforms.current_platform", CpuPlatform()):
with patch("vllm.attention.layer.current_platform", CpuPlatform()), \
patch("vllm.model_executor.models.vision.current_platform",
CpuPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA_VLLM_V1
assert attn.attn_backend == _Backend.TORCH_SDPA
elif device == "hip":
with patch("vllm.attention.selector.current_platform",
RocmPlatform()), \
patch("vllm.platforms.current_platform", RocmPlatform()), \
patch("vllm.attention.layer.current_platform", RocmPlatform()):
with patch("vllm.attention.layer.current_platform", RocmPlatform()), \
patch("vllm.model_executor.models.vision.current_platform",
RocmPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.TORCH_SDPA
else:
with patch("vllm.attention.selector.current_platform",
CudaPlatform()), \
patch("vllm.platforms.current_platform", CudaPlatform()):
# Test CUDA with head_size=64 (divisible by 32)
# - should use vLLM's FlashAttention
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
patch("vllm.model_executor.models.vision.current_platform",
CudaPlatform()):
attn = MultiHeadAttention(16, 64, scale=1)
assert attn.attn_backend == _Backend.XFORMERS
assert attn.attn_backend == _Backend.FLASH_ATTN

with patch("vllm.attention.selector.current_platform",
# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA not available
# - should use xformers
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
patch("vllm.model_executor.models.vision.current_platform",
CudaPlatform()), \
patch("vllm.platforms.current_platform", CudaPlatform()):
patch("vllm.attention.layer.check_upstream_fa_availability",
return_value=False):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.XFORMERS

# Test CUDA with head_size=72 (not divisible by 32)
# - with upstream FA available
# - should use upstream FA
with patch("vllm.attention.layer.current_platform", CudaPlatform()), \
patch("vllm.model_executor.models.vision.current_platform",
CudaPlatform()), \
patch("vllm.attention.layer.check_upstream_fa_availability",
return_value=True), \
patch.dict('sys.modules', {'flash_attn': type('MockFlashAttn', (),
{
'flash_attn_varlen_func': lambda *args, **kwargs: None
})()}):
attn = MultiHeadAttention(16, 72, scale=1)
assert attn.attn_backend == _Backend.FLASH_ATTN


def ref_attention(
query: torch.Tensor,
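The last CUDA case above fakes an installed `flash_attn` package by patching `sys.modules`. A minimal standalone sketch of that mocking pattern, assuming only that the code under test imports `flash_attn_varlen_func` from `flash_attn` (the stubbed attribute is a placeholder, not the real `flash_attn` API surface):

```python
import sys
import types
from unittest.mock import patch

# A throwaway object exposing only the attribute the code under test imports.
fake_flash_attn = types.SimpleNamespace(
    flash_attn_varlen_func=lambda *args, **kwargs: None)

with patch.dict(sys.modules, {"flash_attn": fake_flash_attn}):
    # Inside this block, `from flash_attn import flash_attn_varlen_func`
    # resolves against the stub instead of requiring the real package.
    from flash_attn import flash_attn_varlen_func
    assert flash_attn_varlen_func() is None
```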
75 changes: 67 additions & 8 deletions vllm/attention/layer.py
@@ -23,6 +23,7 @@
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend, current_platform
from vllm.utils import direct_register_custom_op

@@ -55,6 +56,14 @@ def check_xformers_availability():
return USE_XFORMERS_OPS


def check_upstream_fa_availability(dtype: torch.dtype):
if dtype in (torch.float16, torch.bfloat16) and current_platform.is_cuda(
) and current_platform.has_device_capability(80):
from transformers.utils import is_flash_attn_2_available
return is_flash_attn_2_available()
return False
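A quick usage sketch of this helper: callers pair it with the auto-selected backend and only upgrade when vLLM's native FLASH_ATTN was not chosen. `maybe_upgrade_to_upstream_fa` is a hypothetical name used for illustration only; the PR inlines this logic in `MultiHeadAttention.__init__` and in each vision encoder below.

```python
import torch

from vllm.attention.layer import check_upstream_fa_availability
from vllm.platforms import _Backend


def maybe_upgrade_to_upstream_fa(backend: _Backend,
                                 dtype: torch.dtype) -> tuple[_Backend, bool]:
    # If the selected backend is not vLLM's native FLASH_ATTN but the upstream
    # flash-attn package is usable (fp16/bf16 on CUDA, compute capability 80+),
    # switch to FLASH_ATTN and remember to import the upstream varlen kernel.
    use_upstream_fa = False
    if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(dtype):
        backend = _Backend.FLASH_ATTN
        use_upstream_fa = True
    return backend, use_upstream_fa
```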


class Attention(nn.Module, AttentionLayerBase):
"""Attention layer.

@@ -349,29 +358,55 @@ def __init__(
f"divisible by num_kv_heads ({self.num_kv_heads})"
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype=None,
block_size=16,
is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name())

# Determine the attention backend
backend = get_vit_attn_backend(head_size=head_size, dtype=dtype)

# Some auto-selected backends can be upgraded
# to upstream flash attention if available.
# If vllm native fa is selected, we use it directly.
use_upstream_fa = False
if backend != _Backend.FLASH_ATTN and check_upstream_fa_availability(
dtype):
backend = _Backend.FLASH_ATTN
use_upstream_fa = True

if current_platform.is_rocm():
# currently, only torch_sdpa is supported on rocm
self.attn_backend = _Backend.TORCH_SDPA
else:

self.attn_backend = backend if backend in {
_Backend.TORCH_SDPA,
_Backend.TORCH_SDPA_VLLM_V1,
_Backend.XFORMERS,
_Backend.PALLAS_VLLM_V1,
_Backend.ROCM_AITER_FA,
} else current_platform.get_vit_attn_backend()
_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1,
} else _Backend.TORCH_SDPA

if (self.attn_backend == _Backend.XFORMERS
and not check_xformers_availability()):
self.attn_backend = _Backend.TORCH_SDPA

if self.attn_backend in {
_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
}:
if use_upstream_fa:
from flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func
self._flash_attn_varlen_func = flash_attn_varlen_func

logger.info_once(
f"MultiHeadAttention attn_backend: {self.attn_backend}, "
f"use_upstream_fa: {use_upstream_fa}")

def forward(
self,
query: torch.Tensor,
@@ -392,7 +427,31 @@ def forward(
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)

if self.attn_backend == _Backend.XFORMERS:
if self.attn_backend in {
_Backend.FLASH_ATTN,
_Backend.FLASH_ATTN_VLLM_V1,
}:

cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=query.device)
cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
step=kv_len,
dtype=torch.int32,
device=key.device)

out = self._flash_attn_varlen_func(
query.flatten(0, 1),
key.flatten(0, 1),
value.flatten(0, 1),
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=q_len,
max_seqlen_k=kv_len,
softmax_scale=self.scale,
)
elif self.attn_backend == _Backend.XFORMERS:
from xformers import ops as xops

out = xops.memory_efficient_attention_forward(query,
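As a sanity check on the new FLASH_ATTN path in `MultiHeadAttention.forward`: for a batch of equal-length sequences, the cumulative sequence lengths are just an `arange` stepped by the sequence length, and the batch/sequence dims are flattened into a packed layout before calling the varlen kernel. A minimal CPU-only sketch:

```python
import torch

bsz, q_len, num_heads, head_dim = 2, 3, 16, 64
query = torch.randn(bsz, q_len, num_heads, head_dim)

# Boundaries of each sequence in the packed layout: [0, q_len, 2*q_len, ...].
cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32)
assert cu_seqlens_q.tolist() == [0, 3, 6]

# The varlen kernels expect packed (total_tokens, num_heads, head_dim) tensors.
packed_q = query.flatten(0, 1)
assert packed_q.shape == (bsz * q_len, num_heads, head_dim)
```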
23 changes: 20 additions & 3 deletions vllm/model_executor/models/ernie45_vl.py
@@ -34,6 +34,7 @@
from einops import rearrange, repeat
from transformers import BatchFeature

from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@@ -170,7 +171,16 @@ def __init__(
prefix=f"{prefix}.proj")

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())

self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True

if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
@@ -233,7 +243,10 @@ def forward(
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
from flash_attn import flash_attn_varlen_func
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -457,7 +470,11 @@ def __init__(
), "vit's config.hidden must be equal to config.embed_dim"
self.ln = nn.LayerNorm(hidden_size, eps=1e-6)

self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN

@property
def dtype(self) -> torch.dtype:
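The same two-step pattern repeats in glm4_1v, keye, and qwen2_5_vl below: select the ViT backend via `get_vit_attn_backend(head_size=..., dtype=...)`, optionally upgrade to upstream FA, then import the matching `flash_attn_varlen_func`. A hedged sketch of that shared import logic (`_resolve_varlen_func` is a hypothetical name for illustration only):

```python
from vllm.platforms import _Backend


def _resolve_varlen_func(attn_backend: _Backend, use_upstream_fa: bool):
    # Mirrors the repeated branch in the vision encoders touched by this PR.
    if attn_backend == _Backend.ROCM_AITER_FA:
        from aiter import flash_attn_varlen_func  # ROCm AITER kernels
    elif use_upstream_fa:
        from flash_attn import flash_attn_varlen_func  # upstream flash-attn
    else:
        from vllm.vllm_flash_attn import flash_attn_varlen_func  # vLLM's bundled FA
    return flash_attn_varlen_func
```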
22 changes: 19 additions & 3 deletions vllm/model_executor/models/glm4_1v.py
@@ -44,6 +44,7 @@
Glm4vVideoProcessor)
from transformers.video_utils import VideoMetadata

from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
parallel_state)
@@ -260,7 +261,15 @@ def __init__(
)

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True

if self.attn_backend not in {
_Backend.FLASH_ATTN,
_Backend.TORCH_SDPA,
@@ -310,7 +319,10 @@ def forward(
if self.attn_backend == _Backend.FLASH_ATTN:
# from vllm_flash_attn.flash_attn_interface import (
# flash_attn_varlen_func)
from flash_attn import flash_attn_varlen_func
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -715,7 +727,11 @@ def __init__(
self.post_layernorm = RMSNorm(vision_config.hidden_size,
eps=vision_config.rms_norm_eps)

self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN

@property
def dtype(self) -> torch.dtype:
17 changes: 15 additions & 2 deletions vllm/model_executor/models/keye.py
@@ -17,6 +17,7 @@
BaseModelOutputWithPooling)
from transformers.utils import torch_int

from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@@ -374,7 +375,16 @@ def __init__(
)

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
self.attn_backend = get_vit_attn_backend(
head_size=self.head_dim, dtype=torch.get_default_dtype())

self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True

if self.attn_backend not in {_Backend.FLASH_ATTN, _Backend.XFORMERS}:
raise RuntimeError(
f"Keye-VL does not support {self.attn_backend} backend now.")
@@ -428,7 +438,10 @@ def forward(
)

if self.attn_backend == _Backend.FLASH_ATTN:
from flash_attn import flash_attn_varlen_func
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

24 changes: 21 additions & 3 deletions vllm/model_executor/models/qwen2_5_vl.py
@@ -38,6 +38,7 @@
from transformers.models.qwen2_5_vl.configuration_qwen2_5_vl import (
Qwen2_5_VLConfig, Qwen2_5_VLVisionConfig)

from vllm.attention.layer import check_upstream_fa_availability
from vllm.config import VllmConfig
from vllm.distributed import parallel_state
from vllm.distributed import utils as dist_utils
@@ -298,7 +299,16 @@ def __init__(
disable_tp=use_data_parallel)

# Detect attention implementation.
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
self.attn_backend = get_vit_attn_backend(
head_size=self.hidden_size_per_attention_head,
dtype=torch.get_default_dtype())
self.use_upstream_fa = False
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN
self.use_upstream_fa = True

if self.attn_backend not in {
_Backend.FLASH_ATTN, _Backend.TORCH_SDPA, _Backend.XFORMERS,
_Backend.ROCM_AITER_FA
@@ -359,7 +369,10 @@ def forward(
if self.attn_backend == _Backend.ROCM_AITER_FA:
from aiter import flash_attn_varlen_func
else:
from flash_attn import flash_attn_varlen_func
if self.use_upstream_fa:
from flash_attn import flash_attn_varlen_func
else:
from vllm.vllm_flash_attn import flash_attn_varlen_func

q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])

@@ -628,7 +641,12 @@ def __init__(
prefix=f"{prefix}.merger",
use_data_parallel=use_data_parallel,
)
self.attn_backend: _Backend = get_vit_attn_backend(support_fa=True)
self.attn_backend = get_vit_attn_backend(
head_size=head_dim, dtype=torch.get_default_dtype())
if self.attn_backend != _Backend.FLASH_ATTN and \
check_upstream_fa_availability(
torch.get_default_dtype()):
self.attn_backend = _Backend.FLASH_ATTN

@property
def dtype(self) -> torch.dtype: