
Commit 5931b7e

Authored by anmarques, gemini-code-assist[bot], and mgoin
[Models][Quantization] Add quantization configuration update in Voxtral model (#24122)
Signed-off-by: Alexandre Marques <[email protected]>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Michael Goin <[email protected]>
1 parent (cc99baf) · commit 5931b7e

File tree

2 files changed (+88, −4 lines)


vllm/model_executor/models/llama.py

Lines changed: 15 additions & 4 deletions
@@ -626,9 +626,8 @@ def maybe_remap_mistral(
         loaded_weight: torch.Tensor,
     ) -> tuple[str, torch.Tensor]:
 
-        def permute(w: torch.Tensor, n_heads: int):
+        def permute(w: torch.Tensor, n_heads: int, attn_out: int):
             attn_in = self.config.head_dim * n_heads
-            attn_out = self.config.hidden_size
 
             return w.view(n_heads, attn_in // n_heads // 2, 2,
                           attn_out).transpose(1, 2).reshape(attn_in, attn_out)
@@ -637,12 +636,24 @@ def permute(w: torch.Tensor, n_heads: int):
         modules = name.split(".")
 
         # rotary embeds should be sliced
+        # If using quantized model in mistral format,
+        # quantization scales (qscale_weight) also need to be sliced
         if "wk" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_key_value_heads)
+                                    self.config.num_key_value_heads,
+                                    self.config.hidden_size)
+        elif "wk" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_key_value_heads, 1)
         elif "wq" in modules and modules[-1] == "weight":
             loaded_weight = permute(loaded_weight,
-                                    self.config.num_attention_heads)
+                                    self.config.num_attention_heads,
+                                    self.config.hidden_size)
+        elif "wq" in modules and modules[
+                -1] == "qscale_weight" and loaded_weight.numel() > 1:
+            loaded_weight = permute(loaded_weight,
+                                    self.config.num_attention_heads, 1)
 
         num_modules = len(modules)
         for i in range(num_modules):
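
The slicing logic above hinges on permute being a pure row permutation: per-output-channel quantization scales (qscale_weight) share row indices with the weights they scale, so both must be reshuffled identically, while per-tensor scales (numel() == 1) need no permutation at all, which the numel() > 1 guard reflects. Below is a minimal, self-contained sketch (not part of the commit; sizes and tensor names are illustrative) showing that calling the new permute with attn_out=1 moves a scale column through the same row permutation that attn_out=hidden_size applies to the weight:

import torch


def permute(w: torch.Tensor, n_heads: int, attn_out: int) -> torch.Tensor:
    # Same reshuffle as in maybe_remap_mistral: un-interleaves the rotary
    # embedding halves within each head; a pure row permutation.
    attn_in = w.shape[0]
    return w.view(n_heads, attn_in // n_heads // 2, 2,
                  attn_out).transpose(1, 2).reshape(attn_in, attn_out)


n_heads, head_dim, hidden_size = 2, 4, 8  # illustrative sizes
weight = torch.randn(n_heads * head_dim, hidden_size)
scales = torch.rand(n_heads * head_dim, 1)  # one scale per output row

# Rows of the weight and their scales move together, so dequantizing after
# the permutation matches permuting the dequantized weight.
w_p = permute(weight, n_heads, hidden_size)
s_p = permute(scales, n_heads, 1)
assert torch.allclose(w_p * s_p, permute(weight * scales, n_heads, hidden_size))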

vllm/model_executor/models/voxtral.py

Lines changed: 73 additions & 0 deletions
@@ -23,6 +23,7 @@
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsPP
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -327,6 +328,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
 
+        # update quant config so that ignored module and target module names
+        # match the vLLM model names
+        if hasattr(vllm_config, "quant_config"):
+            vllm_config.quant_config = self.maybe_update_quant_config(
+                vllm_config.quant_config)
+
         config = vllm_config.model_config.hf_config
         self.config = config
         self.downsample_factor = self.config.audio_config.downsample_factor
@@ -558,6 +565,72 @@ def llm_weights_generator():
 
         return loaded_weights
 
+    def maybe_update_quant_config(
+            self, quant_config: QuantizationConfig) -> QuantizationConfig:
+        """
+        Update quant config so that ignored module and target module names
+        match the vLLM model names.
+        Right now this is specific to the compressed-tensors format and
+        load_format mistral.
+        """
+        remapping_rules = [
+            (r"output", r"language_model.lm_head"),
+            (r"layers\.(\d+)\.attention\.wo",
+             r"language_model.model.layers.\1.self_attn.out_proj"),
+            (r"layers\.(\d+)\.attention\.w(.*)",
+             r"language_model.model.layers.\1.self_attn.\2_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w1",
+             r"language_model.model.layers.\1.mlp.gate_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w2",
+             r"language_model.model.layers.\1.mlp.down_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w3",
+             r"language_model.model.layers.\1.mlp.up_proj"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj"
+             ),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj"
+             ),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
+             r"whisper_encoder.whisper_encoder.conv1"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
+             r"whisper_encoder.whisper_encoder.conv2"),
+            (r"mm_whisper_embeddings\.audio_language_projection\.0",
+             r"audio_language_adapter.w_in"),
+            (r"mm_whisper_embeddings\.audio_language_projection\.2",
+             r"audio_language_adapter.w_out"),
+        ]
+
+        # Update ignore list
+        if hasattr(quant_config, "ignore"):
+            mistral_ignore = []
+            for name in quant_config.ignore:
+                mistral_name = name
+                for pattern, repl in remapping_rules:
+                    if re.fullmatch(pattern, name):
+                        mistral_name = re.sub(pattern, repl, name)
+                mistral_ignore.append(mistral_name)
+            quant_config.ignore = mistral_ignore
+
+        # Update target list
+        if hasattr(quant_config, "config_groups"):
+            config_groups = quant_config.config_groups
+            for group_name in config_groups:
+                if "targets" in config_groups[group_name]:
+                    targets = []
+                    for name in config_groups[group_name]["targets"]:
+                        mistral_name = name
+                        for pattern, repl in remapping_rules:
+                            if re.fullmatch(pattern, name):
+                                mistral_name = re.sub(pattern, repl, name)
+                        targets.append(mistral_name)
+                    config_groups[group_name]["targets"] = targets
+            quant_config.config_groups = config_groups
+
+        return quant_config
+
 
 class AudioLanguageAdapter(nn.Module):
 
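
To make the remapping concrete, here is a minimal sketch (not part of the commit; the sample module name is hypothetical) of a single rule from maybe_update_quant_config rewriting a mistral-format name into its vLLM counterpart. As in the method, each rule is matched against the original name, so one rule cannot re-rewrite the output of another:

import re

# One of the commit's rules: mistral-format attention projections
# (wq/wk/wv) become vLLM's self_attn.{q,k,v}_proj module names.
remapping_rules = [
    (r"layers\.(\d+)\.attention\.w(.*)",
     r"language_model.model.layers.\1.self_attn.\2_proj"),
]

name = "layers.7.attention.wq"  # hypothetical ignore-list entry
mistral_name = name
for pattern, repl in remapping_rules:
    if re.fullmatch(pattern, name):  # match against the original name
        mistral_name = re.sub(pattern, repl, name)

print(mistral_name)  # language_model.model.layers.7.self_attn.q_proj

Names that match no rule (for example a bare "Linear" target in a compressed-tensors config group) fall through unchanged, since mistral_name starts as a copy of name.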

0 commit comments