19 changes: 15 additions & 4 deletions vllm/model_executor/models/llama.py
@@ -626,9 +626,8 @@ def maybe_remap_mistral(
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:

-        def permute(w: torch.Tensor, n_heads: int):
+        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
             attn_in = self.config.head_dim * n_heads
-            attn_out = self.config.hidden_size

             return w.view(n_heads, attn_in // n_heads // 2, 2,
                           attn_out).transpose(1, 2).reshape(attn_in, attn_out)
@@ -637,12 +636,24 @@ def permute(w: torch.Tensor, n_heads: int):
         modules = name.split(".")

         # rotary embeds should be sliced
+        # If using quantized model in mistral format,
+        # quantization scales (qscale_weight) also need to be sliced
         if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_key_value_heads)
+                                    self.config.num_key_value_heads,
+                                    self.config.hidden_size)
+        elif "wk" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads, 1)
         elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_attention_heads)
+                                    self.config.num_attention_heads,
+                                    self.config.hidden_size)
+        elif "wq" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads, 1)

         num_modules = len(modules)
         for i in range(num_modules):
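The permute() helper above performs the usual conversion of Mistral/Llama checkpoint wq/wk rows from the interleaved rotary layout to the half-split layout that vLLM's rotary embedding expects; it is a pure reordering of output channels (rows). Per-channel quantization scales (qscale_weight, one value per output channel) therefore have to be reordered identically, which is what the new calls with attn_out=1 do, while scalar scales (numel() == 1) are left untouched. A minimal self-contained sketch of that invariant, with toy shapes that are not part of the diff:

import torch


def permute(w: torch.Tensor, n_heads: int, attn_out: int) -> torch.Tensor:
    # Same row permutation as in llama.py: regroup the interleaved rotary
    # pairs into two contiguous halves per head. Here attn_in is taken from
    # the tensor shape instead of the model config (sketch only).
    attn_in = w.shape[0]
    return w.view(n_heads, attn_in // n_heads // 2, 2,
                  attn_out).transpose(1, 2).reshape(attn_in, attn_out)


head_dim, n_kv_heads, hidden_size = 4, 2, 8            # toy sizes
wk = torch.randn(head_dim * n_kv_heads, hidden_size)   # [attn_in, hidden]
qscale = torch.rand(head_dim * n_kv_heads, 1)           # one scale per row

wk_p = permute(wk, n_kv_heads, hidden_size)
qscale_p = permute(qscale, n_kv_heads, 1)

# The permutation depends only on n_heads and the row count, so weight rows
# and their per-row scales move to the same positions.
order = permute(torch.arange(wk.shape[0]).unsqueeze(-1),
                n_kv_heads, 1).squeeze(-1)
assert torch.equal(wk_p, wk[order])
assert torch.equal(qscale_p, qscale[order])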
73 changes: 73 additions & 0 deletions vllm/model_executor/models/voxtral.py
@@ -23,6 +23,7 @@
from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
from vllm.inputs.data import PromptType
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import SupportsPP
from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -327,6 +328,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)

        # update the quant config so that ignored module and target module
        # names match the vLLM model names
        if hasattr(vllm_config, "quant_config"):
            vllm_config.quant_config = self.maybe_update_quant_config(
                vllm_config.quant_config)

        config = vllm_config.model_config.hf_config
        self.config = config
        self.downsample_factor = self.config.audio_config.downsample_factor
@@ -558,6 +565,72 @@ def llm_weights_generator():

        return loaded_weights

    def maybe_update_quant_config(
            self, quant_config: QuantizationConfig) -> QuantizationConfig:
        """
        Update the quant config so that ignored module and target module names
        match the vLLM model names.
        Right now this is specific to the compressed-tensors format and
        load_format mistral.
        """
        remapping_rules = [
            (r"output", r"language_model.lm_head"),
            (r"layers\.(\d+)\.attention\.wo",
             r"language_model.model.layers.\1.self_attn.out_proj"),
            (r"layers\.(\d+)\.attention\.w(.*)",
             r"language_model.model.layers.\1.self_attn.\2_proj"),
            (r"layers\.(\d+)\.feed_forward\.w1",
             r"language_model.model.layers.\1.mlp.gate_proj"),
            (r"layers\.(\d+)\.feed_forward\.w2",
             r"language_model.model.layers.\1.mlp.down_proj"),
            (r"layers\.(\d+)\.feed_forward\.w3",
             r"language_model.model.layers.\1.mlp.up_proj"),
            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj"
             ),
            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj"
             ),
            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
             r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
             r"whisper_encoder.whisper_encoder.conv1"),
            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
             r"whisper_encoder.whisper_encoder.conv2"),
            (r"mm_whisper_embeddings\.audio_language_projection\.0",
             r"audio_language_adapter.w_in"),
            (r"mm_whisper_embeddings\.audio_language_projection\.2",
             r"audio_language_adapter.w_out"),
        ]

        # Update ignore list
        if hasattr(quant_config, "ignore"):
            mistral_ignore = []
            for name in quant_config.ignore:
                mistral_name = name
                for pattern, repl in remapping_rules:
                    if re.fullmatch(pattern, name):
                        mistral_name = re.sub(pattern, repl, name)
                mistral_ignore.append(mistral_name)
            quant_config.ignore = mistral_ignore

        # Update target list
        if hasattr(quant_config, "config_groups"):
            config_groups = quant_config.config_groups
            for group_name in config_groups:
                if "targets" in config_groups[group_name]:
                    targets = []
                    for name in config_groups[group_name]["targets"]:
                        mistral_name = name
                        for pattern, repl in remapping_rules:
                            if re.fullmatch(pattern, name):
                                mistral_name = re.sub(pattern, repl, name)
                        targets.append(mistral_name)
                    config_groups[group_name]["targets"] = targets
            quant_config.config_groups = config_groups

        return quant_config


class AudioLanguageAdapter(nn.Module):

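To illustrate how maybe_update_quant_config rewrites names, here is a small standalone sketch using a subset of the remapping rules above; the remap() helper and the sample module names are illustrative only and not part of the diff:

import re

# Subset of the remapping rules from maybe_update_quant_config.
remapping_rules = [
    (r"output", r"language_model.lm_head"),
    (r"layers\.(\d+)\.attention\.w(.*)",
     r"language_model.model.layers.\1.self_attn.\2_proj"),
    (r"layers\.(\d+)\.feed_forward\.w1",
     r"language_model.model.layers.\1.mlp.gate_proj"),
]


def remap(name: str) -> str:
    # Mirror the loops above: a rule only fires when it matches the whole
    # name, and the (possibly rewritten) name is collected either way.
    mistral_name = name
    for pattern, repl in remapping_rules:
        if re.fullmatch(pattern, name):
            mistral_name = re.sub(pattern, repl, name)
    return mistral_name


print(remap("layers.7.attention.wk"))     # language_model.model.layers.7.self_attn.k_proj
print(remap("layers.7.feed_forward.w1"))  # language_model.model.layers.7.mlp.gate_proj
print(remap("output"))                    # language_model.lm_head
print(remap("tok_embeddings"))            # unchanged: no rule matches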