@@ -23,6 +23,7 @@
 from vllm.config import ModelConfig, SpeechToTextConfig, VllmConfig
 from vllm.inputs.data import PromptType
 from vllm.logger import init_logger
+from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models import SupportsPP
 from vllm.model_executor.models.module_mapping import MultiModelKeys
@@ -327,6 +328,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         super().__init__()
         self.tokenizer = cached_tokenizer_from_config(vllm_config.model_config)
 
+        # update quant config so that ignored module and target module
+        # names match the vLLM model names
+        if hasattr(vllm_config, "quant_config"):
+            vllm_config.quant_config = self.maybe_update_quant_config(
+                vllm_config.quant_config)
+
         config = vllm_config.model_config.hf_config
         self.config = config
         self.downsample_factor = self.config.audio_config.downsample_factor
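Why this hook is needed: with load_format mistral, a compressed-tensors quantization config names modules in the checkpoint's own convention (e.g. `layers.0.attention.wq`), whereas the vLLM module tree built by this constructor uses names such as `language_model.model.layers.0.self_attn.q_proj`. The hook rewrites the former into the latter via `maybe_update_quant_config`, defined in the next hunk, so that the ignore and target lists line up with the modules that will actually be constructed.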
|
@@ -558,6 +565,72 @@ def llm_weights_generator():
 
         return loaded_weights
 
+    def maybe_update_quant_config(
+            self, quant_config: QuantizationConfig) -> QuantizationConfig:
+        """
+        Update quant config so that ignored module and target module names
+        match the vLLM model names.
+        Right now this is specific to the compressed-tensors format and
+        load_format mistral.
+        """
+        remapping_rules = [
+            (r"output", r"language_model.lm_head"),
+            (r"layers\.(\d+)\.attention\.wo",
+             r"language_model.model.layers.\1.self_attn.out_proj"),
+            (r"layers\.(\d+)\.attention\.w(.*)",
+             r"language_model.model.layers.\1.self_attn.\2_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w1",
+             r"language_model.model.layers.\1.mlp.gate_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w2",
+             r"language_model.model.layers.\1.mlp.down_proj"),
+            (r"layers\.(\d+)\.feed_forward\.w3",
+             r"language_model.model.layers.\1.mlp.up_proj"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.wo",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.out_proj"
+             ),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.attention.w(.*)",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.self_attn.\2_proj"
+             ),
+            (r"mm_whisper_embeddings\.whisper_encoder\.transformer\.layers\.(\d+)\.feed_forward.w(\d+)",
+             r"whisper_encoder.whisper_encoder.layers.\1.layers.mlp.fc\2"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.0",
+             r"whisper_encoder.whisper_encoder.conv1"),
+            (r"mm_whisper_embeddings\.whisper_encoder\.conv_layers\.1",
+             r"whisper_encoder.whisper_encoder.conv2"),
+            (r"mm_whisper_embeddings\.audio_language_projection\.0",
+             r"audio_language_adapter.w_in"),
+            (r"mm_whisper_embeddings\.audio_language_projection\.2",
+             r"audio_language_adapter.w_out"),
+        ]
+
+        # Update ignore list
+        if hasattr(quant_config, "ignore"):
+            mistral_ignore = []
+            for name in quant_config.ignore:
+                mistral_name = name
+                for pattern, repl in remapping_rules:
+                    if re.fullmatch(pattern, name):
+                        mistral_name = re.sub(pattern, repl, name)
+                mistral_ignore.append(mistral_name)
+            quant_config.ignore = mistral_ignore
+
+        # Update target list
+        if hasattr(quant_config, "config_groups"):
+            config_groups = quant_config.config_groups
+            for group_name in config_groups:
+                if "targets" in config_groups[group_name]:
+                    targets = []
+                    for name in config_groups[group_name]["targets"]:
+                        mistral_name = name
+                        for pattern, repl in remapping_rules:
+                            if re.fullmatch(pattern, name):
+                                mistral_name = re.sub(pattern, repl, name)
+                        targets.append(mistral_name)
+                    config_groups[group_name]["targets"] = targets
+            quant_config.config_groups = config_groups
+
+        return quant_config
+
 
 class AudioLanguageAdapter(nn.Module):
 
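To make the remapping concrete, here is a minimal, self-contained sketch. It is not vLLM API: `remap` is a hypothetical helper, the `SimpleNamespace` merely stands in for a compressed-tensors `QuantizationConfig` (assumed to expose an `ignore` list and a `config_groups` dict, matching the attributes the method probes with `hasattr`), and only a subset of the rules is included.

import re
from types import SimpleNamespace

# Subset of the remapping rules from the diff above.
remapping_rules = [
    (r"output", r"language_model.lm_head"),
    (r"layers\.(\d+)\.attention\.w(.*)",
     r"language_model.model.layers.\1.self_attn.\2_proj"),
    (r"layers\.(\d+)\.feed_forward\.w1",
     r"language_model.model.layers.\1.mlp.gate_proj"),
]

def remap(name: str) -> str:
    # Mirror the method's behavior: every rule whose pattern fully
    # matches the *original* name is applied in turn, so with
    # overlapping patterns the last matching rule wins.
    new_name = name
    for pattern, repl in remapping_rules:
        if re.fullmatch(pattern, name):
            new_name = re.sub(pattern, repl, name)
    return new_name

# Hypothetical quant config contents, shaped like compressed-tensors'
# "ignore" list and "config_groups" mapping.
quant_config = SimpleNamespace(
    ignore=["output", "layers.0.attention.wq"],
    config_groups={"group_0": {"targets": ["layers.7.feed_forward.w1"]}},
)

quant_config.ignore = [remap(n) for n in quant_config.ignore]
for group in quant_config.config_groups.values():
    group["targets"] = [remap(n) for n in group["targets"]]

print(quant_config.ignore)
# ['language_model.lm_head', 'language_model.model.layers.0.self_attn.q_proj']
print(quant_config.config_groups["group_0"]["targets"])
# ['language_model.model.layers.7.mlp.gate_proj']

One consequence of the last-match-wins loop: for overlapping patterns such as `attention\.wo` and `attention\.w(.*)`, the rule that appears later in the list determines the final name.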