Commit 8251010

DarkLight1337 authored and Isotr0py committed
[Misc] Move some model utils into vision file (vllm-project#11848)
Signed-off-by: DarkLight1337 <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
1 parent d0e64de commit 8251010

File tree: 8 files changed, +94 -92 lines changed
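In effect, this commit moves get_vit_attn_backend out of vllm/model_executor/models/utils.py and resolve_visual_encoder_outputs out of vllm/multimodal/utils.py into vllm/model_executor/models/vision.py. As a minimal sketch of the resulting import surface (inferred from the file layout in the diffs below, not itself part of the commit), callers now pull both helpers from the vision module:

# Sketch: new import location implied by this commit's file layout.
from vllm.model_executor.models.vision import (get_vit_attn_backend,
                                               resolve_visual_encoder_outputs)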

vllm/model_executor/models/clip.py
Lines changed: 2 additions & 3 deletions

@@ -20,11 +20,10 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens,
-                                   resolve_visual_encoder_outputs)
+                                   repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData

-from .vision import VisionEncoderInfo
+from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs


 def get_clip_patch_grid_length(*, image_size: int, patch_size: int) -> int:

vllm/model_executor/models/pixtral.py
Lines changed: 2 additions & 3 deletions

@@ -31,14 +31,13 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.inputs import NestedTensors, PlaceholderRange
 from vllm.multimodal.utils import (cached_get_tokenizer,
-                                   consecutive_placeholder_ranges,
-                                   resolve_visual_encoder_outputs)
+                                   consecutive_placeholder_ranges)
 from vllm.sequence import IntermediateTensors, SequenceData

 from .interfaces import SupportsMultiModal, SupportsPP
 from .utils import (init_vllm_registered_model, maybe_prefix,
                     merge_multimodal_embeddings)
-from .vision import VisionEncoderInfo
+from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs

 try:
     from xformers import ops as xops

vllm/model_executor/models/qwen2_vl.py
Lines changed: 2 additions & 1 deletion

@@ -66,8 +66,9 @@
 from vllm.transformers_utils.config import uses_mrope

 from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP
-from .utils import (AutoWeightsLoader, WeightsMapper, get_vit_attn_backend,
+from .utils import (AutoWeightsLoader, WeightsMapper,
                     init_vllm_registered_model, maybe_prefix)
+from .vision import get_vit_attn_backend

 logger = init_logger(__name__)


vllm/model_executor/models/siglip.py
Lines changed: 2 additions & 3 deletions

@@ -24,11 +24,10 @@
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal.utils import (cached_get_tokenizer,
                                    consecutive_placeholder_ranges,
-                                   repeat_and_pad_placeholder_tokens,
-                                   resolve_visual_encoder_outputs)
+                                   repeat_and_pad_placeholder_tokens)
 from vllm.sequence import SequenceData

-from .vision import VisionEncoderInfo
+from .vision import VisionEncoderInfo, resolve_visual_encoder_outputs


 def get_siglip_patch_grid_length(*, image_size: int, patch_size: int) -> int:

vllm/model_executor/models/utils.py
Lines changed: 1 addition & 36 deletions

@@ -8,16 +8,12 @@
 from torch.func import functional_call
 from transformers import PretrainedConfig

-import vllm.envs as envs
-from vllm.attention.selector import (backend_name_to_enum,
-                                     get_global_forced_attn_backend)
 from vllm.config import VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
-from vllm.platforms import _Backend, current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import is_pin_memory_available, print_warning_once
+from vllm.utils import is_pin_memory_available

 logger = init_logger(__name__)

@@ -612,37 +608,6 @@ def make_empty_intermediate_tensors(
     return make_empty_intermediate_tensors


-def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
-    """
-    Get the available attention backend for Vision Transformer.
-    """
-    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
-    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
-    if selected_backend is None:
-        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
-        if backend_by_env_var is not None:
-            selected_backend = backend_name_to_enum(backend_by_env_var)
-    if selected_backend is None:
-        # For Volta and Turing GPUs, use xformers instead.
-        device_available = current_platform.has_device_capability(80)
-        if device_available and support_fa:
-            from transformers.utils import is_flash_attn_2_available
-            if is_flash_attn_2_available():
-                selected_backend = _Backend.FLASH_ATTN
-            else:
-                print_warning_once(
-                    "Current `vllm-flash-attn` has a bug inside vision module, "
-                    "so we use xformers backend instead. You can run "
-                    "`pip install flash-attn` to use flash-attention backend.")
-                selected_backend = _Backend.XFORMERS
-        elif current_platform.is_cpu() or current_platform.is_rocm():
-            # ROCM doesn't support xformers
-            selected_backend = _Backend.TORCH_SDPA
-        else:
-            selected_backend = _Backend.XFORMERS
-    return selected_backend
-
-
 def maybe_prefix(prefix: str, name: str) -> str:
     """Add a prefix to a name if the prefix is non-empty.

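The helper removed above is re-added verbatim to vllm/model_executor/models/vision.py in the next file. A hedged usage sketch of the relocated function (assumes a vLLM build that already contains this commit; on earlier builds the helper still lives in .utils):

# Backend resolution follows the function body shown above: a globally forced
# backend first, then the VLLM_ATTENTION_BACKEND environment variable, then
# platform checks (FLASH_ATTN on compute capability >= 8.0 when support_fa=True
# and flash-attn is installed, TORCH_SDPA on CPU/ROCm, otherwise XFORMERS).
from vllm.model_executor.models.vision import get_vit_attn_backend

backend = get_vit_attn_backend(support_fa=True)
print(backend)  # e.g. _Backend.FLASH_ATTN, _Backend.XFORMERS or _Backend.TORCH_SDPA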

vllm/model_executor/models/vision.py
Lines changed: 82 additions & 1 deletion

@@ -1,8 +1,15 @@
 from abc import ABC, abstractmethod
-from typing import Final, Generic, Protocol, TypeVar
+from typing import Final, Generic, Optional, Protocol, TypeVar, Union

+import torch
 from transformers import PretrainedConfig

+import vllm.envs as envs
+from vllm.attention.selector import (backend_name_to_enum,
+                                     get_global_forced_attn_backend)
+from vllm.platforms import _Backend, current_platform
+from vllm.utils import print_warning_once
+
 _C = TypeVar("_C", bound=PretrainedConfig)


@@ -60,3 +67,77 @@ def get_vision_encoder_info(

     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
+
+
+def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
+    """
+    Get the available attention backend for Vision Transformer.
+    """
+    # TODO(Isotr0py): Remove `support_fa` after support FA for all ViTs attn.
+    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
+    if selected_backend is None:
+        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+        if backend_by_env_var is not None:
+            selected_backend = backend_name_to_enum(backend_by_env_var)
+    if selected_backend is None:
+        # For Volta and Turing GPUs, use xformers instead.
+        device_available = current_platform.has_device_capability(80)
+        if device_available and support_fa:
+            from transformers.utils import is_flash_attn_2_available
+            if is_flash_attn_2_available():
+                selected_backend = _Backend.FLASH_ATTN
+            else:
+                print_warning_once(
+                    "Current `vllm-flash-attn` has a bug inside vision module, "
+                    "so we use xformers backend instead. You can run "
+                    "`pip install flash-attn` to use flash-attention backend.")
+                selected_backend = _Backend.XFORMERS
+        elif current_platform.is_cpu() or current_platform.is_rocm():
+            # ROCM doesn't support xformers
+            selected_backend = _Backend.TORCH_SDPA
+        else:
+            selected_backend = _Backend.XFORMERS
+    return selected_backend
+
+
+def resolve_visual_encoder_outputs(
+    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
+    feature_sample_layers: Optional[list[int]],
+    post_layer_norm: Optional[torch.nn.LayerNorm],
+    max_possible_layers: int,
+) -> torch.Tensor:
+    """Given the outputs a visual encoder module that may correspond to the
+    output of the last layer, or a list of hidden states to be stacked,
+    handle post normalization and resolve it into a single output tensor.
+
+    Args:
+        encoder_outputs: Output of encoder's last layer or all hidden states.
+        feature_sample_layers: Optional layer indices to grab from the encoder
+            outputs; if provided, encoder outputs must be a list.
+        post_layer_norm: Post norm to apply to the output of the encoder.
+        max_possible_layers: Total layers in the fully loaded visual encoder.
+
+    """
+    if feature_sample_layers is None:
+        if post_layer_norm is not None:
+            return post_layer_norm(encoder_outputs)
+        return encoder_outputs
+
+    # Get the hidden states corresponding to the layer indices.
+    # Negative values are relative to the full visual encoder,
+    # so offset them depending on how many layers were loaded.
+    # NOTE: this assumes that encoder_outputs contains a list
+    # of hidden states in the same order as the encoder layers
+    # that produced them.
+    offset = max_possible_layers - len(encoder_outputs)
+    hs_pool = [
+        encoder_outputs[layer_idx]
+        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
+        for layer_idx in feature_sample_layers
+    ]
+
+    # Apply post-norm on the final hidden state if we are using it
+    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
+    if post_layer_norm is not None and uses_last_layer:
+        hs_pool[-1] = post_layer_norm(encoder_outputs)
+    return torch.cat(hs_pool, dim=-1)
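A small usage sketch of the relocated resolve_visual_encoder_outputs (the tensor shapes and layer indices below are illustrative assumptions, not taken from the commit):

import torch

from vllm.model_executor.models.vision import resolve_visual_encoder_outputs

# Pretend a 4-layer visual encoder returned every hidden state,
# each of shape (batch, num_patches, hidden_dim).
hidden_states = [torch.randn(1, 16, 32) for _ in range(4)]

# Sample the last two layers; negative indices are offset against
# max_possible_layers, mirroring the logic in the function above.
features = resolve_visual_encoder_outputs(
    encoder_outputs=hidden_states,
    feature_sample_layers=[-2, -1],
    post_layer_norm=None,
    max_possible_layers=4,
)
print(features.shape)  # torch.Size([1, 16, 64]); sampled layers are concatenated on the last dim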

vllm/multimodal/inputs.py
Lines changed: 3 additions & 1 deletion

@@ -99,6 +99,8 @@ class MultiModalDataBuiltins(TypedDict, total=False):
 MultiModalDataDict: TypeAlias = Mapping[str, ModalityData[Any]]
 """
 A dictionary containing an entry for each modality type to input.
+
+The built-in modalities are defined by :class:`MultiModalDataBuiltins`.
 """


@@ -485,7 +487,7 @@ def get_items(self, modality: str) -> Sequence[MultiModalKwargsItem]:

 MultiModalPlaceholderDict = Mapping[str, Sequence[PlaceholderRange]]
 """
-A dictionary containing placeholder ranges.
+A dictionary containing placeholder ranges for each modality.
 """


vllm/multimodal/utils.py
Lines changed: 0 additions & 44 deletions

@@ -5,7 +5,6 @@

 import numpy as np
 import numpy.typing as npt
-import torch
 from PIL import Image

 import vllm.envs as envs
@@ -285,49 +284,6 @@ def encode_video_base64(frames: npt.NDArray) -> str:
     return video_io.encode_base64(frames)


-def resolve_visual_encoder_outputs(
-    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
-    feature_sample_layers: Optional[list[int]],
-    post_layer_norm: Optional[torch.nn.LayerNorm],
-    max_possible_layers: int,
-) -> torch.Tensor:
-    """Given the outputs a visual encoder module that may correspond to the
-    output of the last layer, or a list of hidden states to be stacked,
-    handle post normalization and resolve it into a single output tensor.
-
-    Args:
-        encoder_outputs: Output of encoder's last layer or all hidden states.
-        feature_sample_layers: Optional layer indices to grab from the encoder
-            outputs; if provided, encoder outputs must be a list.
-        post_layer_norm: Post norm to apply to the output of the encoder.
-        max_possible_layers: Total layers in the fully loaded visual encoder.
-
-    """
-    if feature_sample_layers is None:
-        if post_layer_norm is not None:
-            return post_layer_norm(encoder_outputs)
-        return encoder_outputs
-
-    # Get the hidden states corresponding to the layer indices.
-    # Negative values are relative to the full visual encoder,
-    # so offset them depending on how many layers were loaded.
-    # NOTE: this assumes that encoder_outputs contains a list
-    # of hidden states in the same order as the encoder layers
-    # that produced them.
-    offset = max_possible_layers - len(encoder_outputs)
-    hs_pool = [
-        encoder_outputs[layer_idx]
-        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
-        for layer_idx in feature_sample_layers
-    ]
-
-    # Apply post-norm on the final hidden state if we are using it
-    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
-    if post_layer_norm is not None and uses_last_layer:
-        hs_pool[-1] = post_layer_norm(encoder_outputs)
-    return torch.cat(hs_pool, dim=-1)
-
-
 # Utilities for input processors
 _T = TypeVar("_T", str, int)

