@@ -15,7 +15,8 @@
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
@@ -36,8 +37,9 @@
 from .granitemoeshared import GraniteMoeSharedMLP
 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                          SupportsQuant, SupportsV0Only)
-from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory,
-                    make_layers, maybe_prefix)
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)


 class GraniteMoeHybridMambaDecoderLayer(nn.Module):
@@ -220,35 +222,37 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.attention_bias = config.attention_bias
         self.attention_multiplier = config.attention_multiplier
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_key_value_heads
-
-        self.q_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_heads * self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.q_proj")
-
-        self.k_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.k_proj")
-
-        self.v_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.v_proj")
-
-        self.o_proj = ReplicatedLinear(self.hidden_size,
-                                       self.hidden_size,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.o_proj")
+        self.total_num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        # TensorParallel logic
+        tp_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.qkv_proj = QKVParallelLinear(self.hidden_size,
+                                          self.head_dim,
+                                          self.total_num_heads,
+                                          self.total_num_kv_heads,
+                                          bias=self.attention_bias,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.qkv_proj")
+
+        self.o_proj = RowParallelLinear(self.hidden_size,
+                                        self.hidden_size,
+                                        bias=self.attention_bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")

         if config.position_embedding_type == "rope":
             self.rotary_emb = get_rope(
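Side note on the hunk above: the per-rank head bookkeeping follows the usual grouped-query-attention-under-tensor-parallelism pattern (shard query heads always, shard KV heads only when there are enough of them, otherwise replicate). A minimal standalone sketch of that arithmetic; the head counts below are illustrative, not taken from any real Granite config:

```python
# Standalone sketch of the tensor-parallel head partitioning above.
def partition_heads(total_num_heads: int, total_num_kv_heads: int,
                    tp_size: int) -> tuple[int, int]:
    """Return (query heads per rank, KV heads per rank)."""
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if total_num_kv_heads >= tp_size:
        # Enough KV heads to shard them across ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: each rank keeps a replicated KV head.
        assert tp_size % total_num_kv_heads == 0
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    return num_heads, num_kv_heads

print(partition_heads(32, 8, tp_size=4))   # (8, 2) -> KV heads are sharded
print(partition_heads(32, 8, tp_size=16))  # (2, 1) -> KV heads are replicated
```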
@@ -278,9 +282,12 @@ def forward(
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:

-        query = self.q_proj(hidden_states)[0]
-        key = self.k_proj(hidden_states)[0]
-        value = self.v_proj(hidden_states)[0]
+        qkv, _ = self.qkv_proj(hidden_states)
+        query, key, value = qkv.split([
+            self.num_heads * self.head_dim, self.num_key_value_heads *
+            self.head_dim, self.num_key_value_heads * self.head_dim
+        ],
+                                      dim=-1)

         if self.rotary_emb is not None:
             query, key = self.rotary_emb(positions, query, key)
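The fused projection returns one tensor per rank whose last dimension is the per-rank query width plus two per-rank KV widths, which is why the split sizes above use the per-rank `num_heads` and `num_key_value_heads` computed in `__init__`. A self-contained sketch with made-up shapes, where a dummy tensor stands in for the `QKVParallelLinear` output:

```python
# Sketch of the per-rank fused-QKV split. Shapes are illustrative only.
import torch

head_dim = 128
num_heads, num_kv_heads = 8, 2          # per-rank counts, as in the sketch above
q_size = num_heads * head_dim           # 1024
kv_size = num_kv_heads * head_dim       # 256

# Pretend qkv_proj output for a batch of 4 tokens on one rank.
qkv = torch.randn(4, q_size + 2 * kv_size)

query, key, value = qkv.split([q_size, kv_size, kv_size], dim=-1)
print(query.shape, key.shape, value.shape)
# torch.Size([4, 1024]) torch.Size([4, 256]) torch.Size([4, 256])
```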
@@ -401,6 +408,12 @@ def forward(

     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()

@@ -411,6 +424,15 @@ def _load(n, p):
             weight_loader(param, p)
             loaded_params.add(n)

+        def _load_shard(n, p, shard_id):
+            # Skip layers on other devices.
+            if not is_pp_missing_parameter(n, self):
+                param = params_dict[n]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, p, shard_id)
+                loaded_params.add(n)
+
         def _load_expert(n, p, name, shard_id, expert_id):
             param = params_dict[n]
             weight_loader = getattr(param, "weight_loader",
@@ -465,15 +487,29 @@ def _load_expert(n, p, name, shard_id, expert_id):
                                       ".block_sparse_moe.gate.weight")
                 _load(gate_name, p)
             else:
-                _load(n, p)
+                loaded = False
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name in n:
+                        _load_shard(n.replace(weight_name, param_name),
+                                    p,
+                                    shard_id=shard_id)
+                        loaded = True
+                if not loaded:
+                    _load(n, p)

         return loaded_params


 class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
                                   SupportsPP, IsHybrid, SupportsV0Only,
                                   SupportsQuant):
-    packed_modules_mapping = {}
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
     embedding_modules = {
         "embed_tokens": "input_embeddings",
         "lm_head": "output_embeddings",