 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.model_loader.weight_utils import (
+    default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.interfaces import (HasInnerState, IsHybrid,
                                                    SupportsLoRA, SupportsPP,
                                                    SupportsQuant)
 from vllm.model_executor.models.mamba_cache import (MambaCacheManager,
                                                     MambaCacheParams)
 from vllm.model_executor.models.utils import (
-    AutoWeightsLoader, make_empty_intermediate_tensors_factory, make_layers,
-    maybe_prefix)
+    AutoWeightsLoader, WeightsMapper, make_empty_intermediate_tensors_factory,
+    make_layers, maybe_prefix)
 from vllm.model_executor.sampling_metadata import SamplingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.configs import NemotronHConfig
@@ -426,38 +427,36 @@ def forward(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        attb_params_mapping = {
-            "q_proj": "q",
-            "k_proj": "k",
-            "v_proj": "v",
-        }
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
 
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
         for name, loaded_weight in weights:
-            if "embeddings" in name:
-                name = name.replace("embeddings", "embed_tokens")
+            if "scale" in name:
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            # load stacked params
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
 
-            if "A_log" in name:
-                name = name.replace("A_log", "A")
-                loaded_weight = loaded_weight.to(torch.float32)
-
-            if "D" in name:
-                loaded_weight = loaded_weight.to(torch.float32)
-
-            if "dt_bias" in name:
-                loaded_weight = loaded_weight.to(torch.float32)
-
-            # load attn params
-            if any(proj in name for proj in ["q_proj", "k_proj", "v_proj"]):
-                weight_name = next(proj
-                                   for proj in ["q_proj", "k_proj", "v_proj"]
-                                   if proj in name)
-                name = name.replace(weight_name, "qkv_proj")
                 param = params_dict[name]
                 weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight,
-                              attb_params_mapping[weight_name])
+                weight_loader(param, loaded_weight, shard_id)
+                break
+
             # load other params
             else:
                 param = params_dict[name]
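The rewritten loop above follows vLLM's stacked-parameter convention: each q_proj/k_proj/v_proj checkpoint tensor is renamed to the fused qkv_proj parameter and handed to that parameter's weight_loader together with a shard_id that says which slice of the fused weight to fill. The sketch below only illustrates that shard mechanism with a hand-rolled loader and made-up sizes; it is not the fused-linear weight_loader vLLM actually installs on qkv_proj.

import torch

# Toy sizes, assumed for illustration only.
hidden, q_rows, kv_rows = 8, 8, 4
qkv_weight = torch.empty(q_rows + 2 * kv_rows, hidden)

def toy_qkv_weight_loader(param: torch.Tensor, loaded: torch.Tensor,
                          shard_id: str) -> None:
    # Copy the shard into its row slice of the fused parameter.
    offsets = {"q": (0, q_rows),
               "k": (q_rows, kv_rows),
               "v": (q_rows + kv_rows, kv_rows)}
    start, rows = offsets[shard_id]
    param[start:start + rows].copy_(loaded)

# Hypothetical checkpoint entries with separate q/k/v projections.
checkpoint = {
    "layers.0.mixer.q_proj.weight": torch.randn(q_rows, hidden),
    "layers.0.mixer.k_proj.weight": torch.randn(kv_rows, hidden),
    "layers.0.mixer.v_proj.weight": torch.randn(kv_rows, hidden),
}
stacked_params_mapping = [("qkv_proj", "q_proj", "q"),
                          ("qkv_proj", "k_proj", "k"),
                          ("qkv_proj", "v_proj", "v")]
for name, w in checkpoint.items():
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            toy_qkv_weight_loader(qkv_weight, w, shard_id)
            break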
@@ -471,6 +470,14 @@ def load_weights(self, weights: Iterable[tuple[str,
 
 class NemotronHForCausalLM(nn.Module, HasInnerState, SupportsLoRA, SupportsPP,
                            IsHybrid, SupportsQuant):
+    hf_to_vllm_mapper = WeightsMapper(
+        orig_to_new_prefix={"backbone": "model"},
+        orig_to_new_substr={
+            "A_log": "A",
+            "embeddings": "embed_tokens"
+        },
+    )
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -622,10 +629,5 @@ def compute_logits(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
-        # update name in weights before passing to loader
-        updated_weights = []
-        for name, loaded_weight in weights:
-            name = name.replace("backbone", "model")
-            updated_weights.append((name, loaded_weight))
         loader = AutoWeightsLoader(self)
-        return loader.load_weights(updated_weights)
+        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
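Net effect: the ad-hoc renaming that previously lived in both load_weights methods is now declared once on the class as hf_to_vllm_mapper and applied by AutoWeightsLoader before dispatching to the inner model. The helper below is a hand-rolled stand-in that mimics the configured renames on a couple of assumed NemotronH checkpoint keys; it is not the WeightsMapper implementation or its API.

def remap_hf_name(name: str) -> str:
    # Prefix rename: HF NemotronH checkpoints root the decoder under "backbone".
    if name.startswith("backbone"):
        name = "model" + name[len("backbone"):]
    # Substring renames matching the mapper configured above.
    for old, new in (("A_log", "A"), ("embeddings", "embed_tokens")):
        name = name.replace(old, new)
    return name

# Example keys (assumed) and the parameter names they resolve to:
#   backbone.layers.3.mixer.A_log  -> model.layers.3.mixer.A
#   backbone.embeddings.weight     -> model.embed_tokens.weight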