@@ -1,8 +1,15 @@
 from abc import ABC, abstractmethod
-from typing import Final, Generic, Protocol, TypeVar
+from typing import Final, Generic, Optional, Protocol, TypeVar, Union
 
+import torch
 from transformers import PretrainedConfig
 
+import vllm.envs as envs
+from vllm.attention.selector import (backend_name_to_enum,
+                                     get_global_forced_attn_backend)
+from vllm.platforms import _Backend, current_platform
+from vllm.utils import print_warning_once
+
 _C = TypeVar("_C", bound=PretrainedConfig)
 
 
@@ -60,3 +67,77 @@ def get_vision_encoder_info(
 
     msg = f"Unsupported vision config: {type(vision_config)}"
     raise NotImplementedError(msg)
+
+
+def get_vit_attn_backend(support_fa: bool = False) -> _Backend:
+    """
+    Get the available attention backend for Vision Transformer.
+    """
+    # TODO(Isotr0py): Remove `support_fa` once FA is supported for all ViT
+    # attention implementations.
+    selected_backend: Optional[_Backend] = get_global_forced_attn_backend()
+    if selected_backend is None:
+        backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
+        if backend_by_env_var is not None:
+            selected_backend = backend_name_to_enum(backend_by_env_var)
+    if selected_backend is None:
+        # For Volta and Turing GPUs (compute capability < 8.0), use xformers
+        # instead of FlashAttention.
+        device_available = current_platform.has_device_capability(80)
+        if device_available and support_fa:
+            from transformers.utils import is_flash_attn_2_available
+            if is_flash_attn_2_available():
+                selected_backend = _Backend.FLASH_ATTN
+            else:
+                print_warning_once(
+                    "Current `vllm-flash-attn` has a bug inside vision module, "
+                    "so we use xformers backend instead. You can run "
+                    "`pip install flash-attn` to use flash-attention backend.")
+                selected_backend = _Backend.XFORMERS
+        elif current_platform.is_cpu() or current_platform.is_rocm():
+            # ROCm doesn't support xformers.
+            selected_backend = _Backend.TORCH_SDPA
+        else:
+            selected_backend = _Backend.XFORMERS
+    return selected_backend
+
+
+def resolve_visual_encoder_outputs(
+    encoder_outputs: Union[torch.Tensor, list[torch.Tensor]],
+    feature_sample_layers: Optional[list[int]],
+    post_layer_norm: Optional[torch.nn.LayerNorm],
+    max_possible_layers: int,
+) -> torch.Tensor:
+    """Given the outputs of a visual encoder module that may correspond to
+    the output of the last layer, or a list of hidden states to be stacked,
+    handle post normalization and resolve it into a single output tensor.
+
+    Args:
+        encoder_outputs: Output of encoder's last layer or all hidden states.
+        feature_sample_layers: Optional layer indices to grab from the encoder
+            outputs; if provided, encoder outputs must be a list.
+        post_layer_norm: Post norm to apply to the output of the encoder.
+        max_possible_layers: Total layers in the fully loaded visual encoder.
+
+    """
+    if feature_sample_layers is None:
+        if post_layer_norm is not None:
+            return post_layer_norm(encoder_outputs)
+        return encoder_outputs
+
+    # Get the hidden states corresponding to the layer indices.
+    # Negative values are relative to the full visual encoder,
+    # so offset them depending on how many layers were loaded.
+    # NOTE: this assumes that encoder_outputs contains a list
+    # of hidden states in the same order as the encoder layers
+    # that produced them.
+    offset = max_possible_layers - len(encoder_outputs)
+    hs_pool = [
+        encoder_outputs[layer_idx]
+        if layer_idx >= 0 else encoder_outputs[layer_idx + offset]
+        for layer_idx in feature_sample_layers
+    ]
+
+    # Apply post-norm on the final hidden state if we are using it.
+    uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1)
+    if post_layer_norm is not None and uses_last_layer:
+        hs_pool[-1] = post_layer_norm(encoder_outputs[-1])
+    return torch.cat(hs_pool, dim=-1)
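
A quick sanity check for reviewers: a minimal sketch of how a caller might consume `get_vit_attn_backend`. The import path `vllm.model_executor.models.vision` is an assumption about where this diff lands, and the snippet is illustrative rather than part of the change:

```python
# Illustrative only; assumes vLLM is installed and this diff lives in
# vllm.model_executor.models.vision (the module path is an assumption).
from vllm.model_executor.models.vision import get_vit_attn_backend
from vllm.platforms import _Backend

# Resolution order in the helper: globally forced backend, then the
# VLLM_ATTENTION_BACKEND env var, then platform/device auto-detection.
backend = get_vit_attn_backend(support_fa=True)

if backend == _Backend.FLASH_ATTN:
    print("flash-attn path (SM80+ GPU with flash-attn available)")
elif backend == _Backend.TORCH_SDPA:
    print("torch SDPA path (CPU or ROCm)")
else:
    print("xformers fallback")
```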
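Similarly, a small sketch of `resolve_visual_encoder_outputs` with dummy hidden states, to make the offset and concatenation behavior concrete. Shapes and layer counts are made up for illustration, under the same assumed module path as above:

```python
# Illustrative only: sample two layers from a fully loaded 4-layer encoder.
import torch

from vllm.model_executor.models.vision import resolve_visual_encoder_outputs

# One hidden state per encoder layer: (batch=1, tokens=3, hidden=8).
encoder_outputs = [torch.randn(1, 3, 8) for _ in range(4)]

feats = resolve_visual_encoder_outputs(
    encoder_outputs,
    feature_sample_layers=[-2, -1],  # penultimate and final layer
    post_layer_norm=None,
    max_possible_layers=4,  # all 4 layers loaded, so offset == 0
)
print(feats.shape)  # torch.Size([1, 3, 16]): features concatenated on dim=-1
```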