Commit 08abfa7

[Bugfix] fix modelopt exclude_modules name mapping (#24178)
Signed-off-by: Tomer Asida <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 2bef2d1 commit 08abfa7

3 files changed: +59 −38 lines changed

vllm/model_executor/layers/mamba/mamba_mixer2.py

Lines changed: 3 additions & 0 deletions
@@ -291,6 +291,7 @@ def __init__(self,
             output_size=self.conv_dim,
             bias=use_conv_bias,
             quant_config=None,
+            prefix=f"{prefix}.conv1d",
         )
         # unsqueeze to fit conv1d weights shape into the linear weights shape.
         # Can't do this in `weight_loader` since it already exists in
@@ -303,6 +304,7 @@ def __init__(self,
             output_size=intermediate_size + self.conv_dim + self.num_heads,
             bias=use_bias,
             quant_config=quant_config,
+            prefix=f"{prefix}.in_proj",
         )
 
         # - because in_proj is a concatenation of 3 weights, we
@@ -402,6 +404,7 @@ def __init__(self,
             bias=use_bias,
             input_is_parallel=True,
             quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
         )
 
         self.norm = Mixer2RMSNormGated(intermediate_size,
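Context for the three added prefix= arguments above: the ModelOpt quantization config decides whether a layer stays unquantized by matching the layer's fully qualified name against exclude_modules, so a sub-layer constructed without its dotted prefix can never be matched. Below is a minimal standalone sketch of that kind of substring matching; layer_is_excluded and the module names are illustrative only, not vLLM's actual is_layer_skipped / is_layer_excluded helpers.

def layer_is_excluded(layer_prefix: str, exclude_modules: list[str]) -> bool:
    """Return True if the layer should be left unquantized (simplified)."""
    return any(excluded in layer_prefix for excluded in exclude_modules)


# Hypothetical exclude list and prefixes, for illustration only.
exclude_modules = ["model.layers.0.mixer.conv1d"]
print(layer_is_excluded("model.layers.0.mixer.conv1d", exclude_modules))  # True
print(layer_is_excluded("", exclude_modules))  # False: an empty prefix never matches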

vllm/model_executor/layers/quantization/modelopt.py

Lines changed: 20 additions & 4 deletions
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 import torch
 from torch.nn import Module
@@ -45,6 +45,9 @@
 from vllm.utils.flashinfer import (flashinfer_scaled_fp4_mm, has_flashinfer,
                                    has_flashinfer_moe)
 
+if TYPE_CHECKING:
+    from vllm.model_executor.models.utils import WeightsMapper
+
 logger = init_logger(__name__)
 
 QUANT_ALGOS = ["FP8", "NVFP4"]
@@ -63,7 +66,7 @@ def __init__(
         super().__init__()
         self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
         self.kv_cache_quant_method = kv_cache_quant_method
-        self.exclude_modules = exclude_modules
+        self.exclude_modules = exclude_modules or []
         if is_checkpoint_fp8_serialized:
             logger.warning("Detected ModelOpt fp8 checkpoint. Please note that"
                            " the format is experimental and could change.")
@@ -84,6 +87,11 @@ def get_min_capability(cls) -> int:
     def get_config_filenames(cls) -> list[str]:
         return ["hf_quant_config.json"]
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.exclude_modules is not None:
+            self.exclude_modules = hf_to_vllm_mapper.apply_list(
+                self.exclude_modules)
+
     @classmethod
     def override_quantization_method(
             cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
@@ -170,7 +178,9 @@ def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if self.is_layer_excluded(prefix):
+            if (is_layer_skipped(prefix, self.exclude_modules,
+                                 self.packed_modules_mapping)
+                    or self.is_layer_excluded(prefix)):
                 return UnquantizedLinearMethod()
             return ModelOptFp8LinearMethod(self)
         elif isinstance(layer, Attention):
@@ -615,6 +625,11 @@ def get_min_capability(cls) -> int:
     def get_config_filenames(cls) -> list[str]:
         return ["hf_quant_config.json"]
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.exclude_modules is not None:
+            self.exclude_modules = hf_to_vllm_mapper.apply_list(
+                self.exclude_modules)
+
     @classmethod
     def override_quantization_method(
             cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
@@ -763,7 +778,8 @@ def get_quant_method(self, layer: torch.nn.Module,
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
         if isinstance(layer, LinearBase):
-            if (is_layer_skipped(prefix, self.exclude_modules)
+            if (is_layer_skipped(prefix, self.exclude_modules,
+                                 self.packed_modules_mapping)
                     or self.is_layer_excluded(prefix, self.exclude_modules)):
                 return UnquantizedLinearMethod()
             return ModelOptNvFp4LinearMethod(self)
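The core of the fix: exclude_modules in hf_quant_config.json is written with Hugging Face checkpoint names (e.g. a "backbone." prefix for NemotronH), while get_quant_method() receives vLLM-side prefixes (e.g. "model."), so the exclusion list has to be renamed before matching. The new apply_vllm_mapper() hook delegates that renaming to the model's WeightsMapper. A simplified standalone sketch of the renaming, assuming only prefix and substring rules; SimpleWeightsMapper is a stand-in, not vLLM's WeightsMapper.

from dataclasses import dataclass, field


@dataclass
class SimpleWeightsMapper:
    """Simplified stand-in for vllm.model_executor.models.utils.WeightsMapper."""
    orig_to_new_prefix: dict[str, str] = field(default_factory=dict)
    orig_to_new_substr: dict[str, str] = field(default_factory=dict)

    def _map_name(self, name: str) -> str:
        # Apply prefix renames first, then substring renames.
        for old, new in self.orig_to_new_prefix.items():
            if name.startswith(old):
                name = new + name[len(old):]
        for old, new in self.orig_to_new_substr.items():
            name = name.replace(old, new)
        return name

    def apply_list(self, names: list[str]) -> list[str]:
        return [self._map_name(n) for n in names]


mapper = SimpleWeightsMapper(orig_to_new_prefix={"backbone": "model"},
                             orig_to_new_substr={"embeddings": "embed_tokens"})

# HF-style names as they might appear in a checkpoint's exclude list (hypothetical).
print(mapper.apply_list(["backbone.embeddings", "backbone.layers.0.mixer.conv1d"]))
# ['model.embed_tokens', 'model.layers.0.mixer.conv1d']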

vllm/model_executor/models/nemotron_h.py

Lines changed: 36 additions & 34 deletions
@@ -44,15 +44,16 @@
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                    SupportsLoRA, SupportsPP,
                                                    SupportsQuant)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.models.utils import (
-    AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
+    AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
+    make_layers, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import NemotronHConfig
@@ -426,38 +427,36 @@ def forward(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        attb_params_mapping = {
-            "q_proj": "q",
-            "k_proj": "k",
-            "v_proj": "v",
-        }
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
 
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if "embeddings" in name:
-                name = name.replace("embeddings", "embed_tokens")
+            if "scale" in name:
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            # load stacked params
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
 
-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-                loaded_weight = loaded_weight.to(torch.float32)
-
-            if "D" in name:
-                loaded_weight = loaded_weight.to(torch.float32)
-
-            if "dt_bias" in name:
-                loaded_weight = loaded_weight.to(torch.float32)
-
-            # load attn params
-            if any(proj in name for proj in ["q_proj", "k_proj", "v_proj"]):
-                weight_name = next(proj
-                                   for proj in ["q_proj", "k_proj", "v_proj"]
-                                   if proj in name)
-                name = name.replace(weight_name, "qkv_proj")
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight,
-                              attb_params_mapping[weight_name])
+                weight_loader(param, loaded_weight, shard_id)
+                break
+
             # load other params
             else:
                 param = params_dict[name]
@@ -471,6 +470,14 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                            IsHybrid, SupportsQuant):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={"backbone": "model"},
+        orig_to_new_substr={
+            "A_log": "A",
+            "embeddings": "embed_tokens"
+        },
+    )
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -622,10 +629,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        # update name in weights before passing to loader
-        updated_weights = []
-        for name, loaded_weight in weights:
-            name = name.replace("backbone", "model")
-            updated_weights.append((name, loaded_weight))
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(updated_weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
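The rewritten load_weights() follows the stacked-params pattern used elsewhere in vLLM: checkpoint tensors named q_proj/k_proj/v_proj load into the fused qkv_proj parameter, with the shard id telling the weight loader which slice to fill, while the renames (backbone -> model, A_log -> A, embeddings -> embed_tokens) are now declared once in hf_to_vllm_mapper and applied by AutoWeightsLoader. A minimal sketch of just the name/shard resolution step; resolve() and the weight name below are illustrative, not vLLM code.

from typing import Optional

stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def resolve(checkpoint_name: str) -> tuple[str, Optional[str]]:
    """Map a checkpoint weight name to (fused param name, shard id)."""
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in checkpoint_name:
            return checkpoint_name.replace(shard_name, param_name), shard_id
    return checkpoint_name, None  # not a stacked weight


print(resolve("model.layers.3.mixer.k_proj.weight"))  # hypothetical layer name
# ('model.layers.3.mixer.qkv_proj.weight', 'k')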
