Commit 3e16096

GraniteMoeHybrid TensorParallel fix, and enabling in tests
Signed-off-by: Stanislaw Wozniak <[email protected]>
1 parent aa0dc77 commit 3e16096

File tree

3 files changed: +74 -82 lines changed


tests/models/language/generation/test_granitemoehybrid.py

Lines changed: 0 additions & 42 deletions
This file was deleted.

tests/models/language/generation/test_hybrid.py

Lines changed: 1 addition & 3 deletions
@@ -25,9 +25,7 @@
 
 HYBRID_MODELS = [
     "ai21labs/Jamba-tiny-dev",
-    # NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as
-    # it is not yet available in huggingface transformers
-    # "ibm-granite/granite-4.0-tiny-preview",
+    "ibm-granite/granite-4.0-tiny-preview",
     # NOTE: Running Plamo2 in transformers implementation requires to install
     # causal-conv1d package, which is not listed as a test dependency as it's
     # not compatible with pip-compile.

vllm/model_executor/models/granitemoehybrid.py

Lines changed: 73 additions & 37 deletions
@@ -15,7 +15,8 @@
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import ReplicatedLinear
+from vllm.model_executor.layers.linear import (QKVParallelLinear,
+                                               RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.mamba2_metadata import (
     Mamba2Metadata, prepare_mamba2_metadata)
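For orientation, and not part of the commit itself: QKVParallelLinear shards the fused QKV weight column-wise (each tensor-parallel rank keeps a slice of the output features), while RowParallelLinear shards the output projection row-wise and sums the partial results across ranks. A minimal standalone torch sketch of why that composition reproduces the unsharded result, simulating two ranks in one process with made-up sizes:

import torch

torch.manual_seed(0)
hidden, qkv_out, tp = 8, 12, 2          # toy sizes; illustrative only
x = torch.randn(3, hidden)              # a batch of 3 token embeddings

# Column parallel: each "rank" owns a slice of the fused output features.
w_qkv = torch.randn(qkv_out, hidden)    # full fused weight
w_qkv_shards = w_qkv.chunk(tp, dim=0)   # split along the output dimension
qkv_parts = [x @ w.t() for w in w_qkv_shards]   # no communication needed here

# Row parallel: each rank owns a slice of the second projection's input
# features, and the partial outputs are summed (the all-reduce in real TP).
w_o = torch.randn(hidden, qkv_out)      # toy second projection
w_o_shards = w_o.chunk(tp, dim=1)       # split along the input dimension
out = sum(part @ w.t() for part, w in zip(qkv_parts, w_o_shards))

# Reference: the unsharded computation gives the same result.
ref = (x @ w_qkv.t()) @ w_o.t()
assert torch.allclose(out, ref, atol=1e-4)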
@@ -36,8 +37,9 @@
 from .granitemoeshared import GraniteMoeSharedMLP
 from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
                          SupportsQuant, SupportsV0Only)
-from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory,
-                    make_layers, maybe_prefix)
+from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
+                    make_empty_intermediate_tensors_factory, make_layers,
+                    maybe_prefix)
 
 
 class GraniteMoeHybridMambaDecoderLayer(nn.Module):
@@ -220,35 +222,37 @@ def __init__(
         self.hidden_size = config.hidden_size
         self.attention_bias = config.attention_bias
         self.attention_multiplier = config.attention_multiplier
-        self.num_heads = config.num_attention_heads
-        self.head_dim = self.hidden_size // self.num_heads
-        self.num_key_value_heads = config.num_key_value_heads
-
-        self.q_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_heads * self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.q_proj")
-
-        self.k_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.k_proj")
-
-        self.v_proj = ReplicatedLinear(self.hidden_size,
-                                       self.num_key_value_heads *
-                                       self.head_dim,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.v_proj")
-
-        self.o_proj = ReplicatedLinear(self.hidden_size,
-                                       self.hidden_size,
-                                       bias=self.attention_bias,
-                                       quant_config=quant_config,
-                                       prefix=f"{prefix}.o_proj")
+        self.total_num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
+        self.total_num_kv_heads = config.num_key_value_heads
+
+        # TensorParallel logic
+        tp_size = get_tensor_model_parallel_world_size()
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.qkv_proj = QKVParallelLinear(self.hidden_size,
+                                          self.head_dim,
+                                          self.total_num_heads,
+                                          self.total_num_kv_heads,
+                                          bias=self.attention_bias,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.qkv_proj")
+
+        self.o_proj = RowParallelLinear(self.hidden_size,
+                                        self.hidden_size,
+                                        bias=self.attention_bias,
+                                        quant_config=quant_config,
+                                        prefix=f"{prefix}.o_proj")
 
         if config.position_embedding_type == "rope":
             self.rotary_emb = get_rope(
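A small sketch of the head-partitioning arithmetic introduced above, using illustrative head counts rather than values from any real granite config:

# Illustrative values only; the real ones come from the model's HF config.
hidden_size = 4096
total_num_heads = 32              # config.num_attention_heads (assumed)
total_num_kv_heads = 8            # config.num_key_value_heads (assumed)
head_dim = hidden_size // total_num_heads   # 128

for tp_size in (1, 2, 4, 8, 16):
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if total_num_kv_heads >= tp_size:
        # Enough KV heads: partition them across ranks.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than ranks: each rank keeps one replicated KV head.
        assert tp_size % total_num_kv_heads == 0
    num_key_value_heads = max(1, total_num_kv_heads // tp_size)
    print(tp_size, num_heads, num_key_value_heads)
# e.g. tp_size=16 gives 2 query heads and 1 (replicated) KV head per rank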
@@ -278,9 +282,12 @@ def forward(
         hidden_states: torch.Tensor,
     ) -> torch.Tensor:
 
-        query = self.q_proj(hidden_states)[0]
-        key = self.k_proj(hidden_states)[0]
-        value = self.v_proj(hidden_states)[0]
+        qkv, _ = self.qkv_proj(hidden_states)
+        query, key, value = qkv.split([
+            self.num_heads * self.head_dim, self.num_key_value_heads *
+            self.head_dim, self.num_key_value_heads * self.head_dim
+        ],
+                                      dim=-1)
 
         if self.rotary_emb is not None:
             query, key = self.rotary_emb(positions, query, key)
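The split sizes in the new forward are per-rank quantities. A sketch with assumed per-rank values (tp_size=4 over 32 query heads and 8 KV heads, head_dim=128), showing the shapes that come out of the split:

import torch

# Assumed per-rank sizes for tp_size=4: 8 query heads, 2 KV heads, head_dim=128.
num_heads, num_key_value_heads, head_dim = 8, 2, 128
tokens = 5
qkv = torch.randn(tokens, (num_heads + 2 * num_key_value_heads) * head_dim)

query, key, value = qkv.split([
    num_heads * head_dim,               # 1024 query features on this rank
    num_key_value_heads * head_dim,     # 256 key features
    num_key_value_heads * head_dim,     # 256 value features
], dim=-1)
print(query.shape, key.shape, value.shape)
# torch.Size([5, 1024]) torch.Size([5, 256]) torch.Size([5, 256])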
@@ -401,6 +408,12 @@ def forward(
 
     def load_weights(self, weights: Iterable[tuple[str,
                                                    torch.Tensor]]) -> set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
         params_dict = dict(self.named_parameters())
         loaded_params: set[str] = set()
 
@@ -411,6 +424,15 @@ def _load(n, p):
             weight_loader(param, p)
             loaded_params.add(n)
 
+        def _load_shard(n, p, shard_id):
+            # Skip layers on other devices.
+            if not is_pp_missing_parameter(n, self):
+                param = params_dict[n]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, p, shard_id)
+                loaded_params.add(n)
+
         def _load_expert(n, p, name, shard_id, expert_id):
             param = params_dict[n]
             weight_loader = getattr(param, "weight_loader",
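Conceptually, the weight_loader that _load_shard calls on a QKVParallelLinear parameter receives the full q/k/v checkpoint tensor plus a shard id, and copies this rank's slice into the matching region of the fused weight. A standalone sketch of that idea with made-up sizes; this is not vLLM's actual loader and it ignores the KV-replication case:

import torch

hidden, head_dim = 64, 16
total_q_heads, total_kv_heads, tp_size, rank = 8, 4, 2, 0
q_per_rank = total_q_heads // tp_size * head_dim     # rows of Q on this rank
kv_per_rank = total_kv_heads // tp_size * head_dim   # rows of K (and of V)

# Fused per-rank parameter: [Q | K | V] rows stacked along dim 0.
fused = torch.empty(q_per_rank + 2 * kv_per_rank, hidden)

def load_qkv_shard(fused, full_weight, shard_id):
    """Copy this rank's slice of a full q/k/v checkpoint tensor into fused."""
    if shard_id == "q":
        offset, width = 0, q_per_rank
    elif shard_id == "k":
        offset, width = q_per_rank, kv_per_rank
    else:  # "v"
        offset, width = q_per_rank + kv_per_rank, kv_per_rank
    shard = full_weight.narrow(0, rank * width, width)  # this rank's rows
    fused[offset:offset + width].copy_(shard)

load_qkv_shard(fused, torch.randn(total_q_heads * head_dim, hidden), "q")
load_qkv_shard(fused, torch.randn(total_kv_heads * head_dim, hidden), "k")
load_qkv_shard(fused, torch.randn(total_kv_heads * head_dim, hidden), "v")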
@@ -465,15 +487,29 @@ def _load_expert(n, p, name, shard_id, expert_id):
                                           ".block_sparse_moe.gate.weight")
                 _load(gate_name, p)
             else:
-                _load(n, p)
+                loaded = False
+                for param_name, weight_name, shard_id in stacked_params_mapping:
+                    if weight_name in n:
+                        _load_shard(n.replace(weight_name, param_name),
+                                    p,
+                                    shard_id=shard_id)
+                        loaded = True
+                if not loaded:
+                    _load(n, p)
 
         return loaded_params
 
 
 class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
                                   SupportsPP, IsHybrid, SupportsV0Only,
                                   SupportsQuant):
-    packed_modules_mapping = {}
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+    }
     embedding_modules = {
         "embed_tokens": "input_embeddings",
         "lm_head": "output_embeddings",
