Merged
42 changes: 0 additions & 42 deletions tests/models/language/generation/test_granitemoehybrid.py

This file was deleted.

5 changes: 3 additions & 2 deletions tests/models/language/generation/test_hybrid.py
@@ -25,8 +25,9 @@

HYBRID_MODELS = [
"ai21labs/Jamba-tiny-dev",
# NOTE: ibm-granite/granite-4.0-tiny-preview are skipped currently as
# it is not yet available in huggingface transformers
# NOTE: Currently the test fails due to an HF transformers issue fixed in:
# https://github.com/huggingface/transformers/pull/39033
# We will enable the vLLM test for Granite after the next HF transformers release.
# "ibm-granite/granite-4.0-tiny-preview",
# NOTE: Running Plamo2 in transformers implementation requires to install
# causal-conv1d package, which is not listed as a test dependency as it's
110 changes: 73 additions & 37 deletions vllm/model_executor/models/granitemoehybrid.py
@@ -15,7 +15,8 @@
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.linear import (QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.mamba2_metadata import (
Mamba2Metadata, prepare_mamba2_metadata)
@@ -36,8 +37,9 @@
from .granitemoeshared import GraniteMoeSharedMLP
from .interfaces import (HasInnerState, IsHybrid, SupportsLoRA, SupportsPP,
SupportsQuant, SupportsV0Only)
from .utils import (AutoWeightsLoader, make_empty_intermediate_tensors_factory,
make_layers, maybe_prefix)
from .utils import (AutoWeightsLoader, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)


class GraniteMoeHybridMambaDecoderLayer(nn.Module):
@@ -220,35 +222,37 @@ def __init__(
self.hidden_size = config.hidden_size
self.attention_bias = config.attention_bias
self.attention_multiplier = config.attention_multiplier
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads

self.q_proj = ReplicatedLinear(self.hidden_size,
self.num_heads * self.head_dim,
bias=self.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")

self.k_proj = ReplicatedLinear(self.hidden_size,
self.num_key_value_heads *
self.head_dim,
bias=self.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.k_proj")

self.v_proj = ReplicatedLinear(self.hidden_size,
self.num_key_value_heads *
self.head_dim,
bias=self.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.v_proj")

self.o_proj = ReplicatedLinear(self.hidden_size,
self.hidden_size,
bias=self.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
self.total_num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.total_num_heads
self.total_num_kv_heads = config.num_key_value_heads

# TensorParallel logic
tp_size = get_tensor_model_parallel_world_size()
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
if self.total_num_kv_heads >= tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert tp_size % self.total_num_kv_heads == 0
self.num_key_value_heads = max(1, self.total_num_kv_heads // tp_size)

self.qkv_proj = QKVParallelLinear(self.hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=self.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj")

self.o_proj = RowParallelLinear(self.hidden_size,
self.hidden_size,
bias=self.attention_bias,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")

if config.position_embedding_type == "rope":
self.rotary_emb = get_rope(
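The tensor-parallel bookkeeping above follows the usual vLLM attention pattern: each rank owns total_num_heads // tp_size query heads, and the KV heads are either partitioned across ranks (when there are at least as many KV heads as ranks) or each rank keeps a single, effectively replicated KV head. A minimal sketch of the resulting per-rank sizes, using made-up config values rather than the actual Granite configuration:

```python
# Illustrative per-rank head arithmetic (config values are made up).
hidden_size = 4096
total_num_heads = 32
total_num_kv_heads = 8
head_dim = hidden_size // total_num_heads  # 128

for tp_size in (1, 4, 16):
    num_heads = total_num_heads // tp_size
    if total_num_kv_heads >= tp_size:
        # Enough KV heads to give every rank its own share.
        num_kv_heads = total_num_kv_heads // tp_size
    else:
        # Fewer KV heads than ranks: each rank keeps one (replicated) KV head.
        num_kv_heads = max(1, total_num_kv_heads // tp_size)
    qkv_out = (num_heads + 2 * num_kv_heads) * head_dim
    print(f"tp={tp_size}: {num_heads} q heads, {num_kv_heads} kv heads, "
          f"qkv_proj output per rank = {qkv_out}")
# tp=1: 32 q heads, 8 kv heads, qkv_proj output per rank = 6144
# tp=4: 8 q heads, 2 kv heads, qkv_proj output per rank = 1536
# tp=16: 2 q heads, 1 kv heads, qkv_proj output per rank = 512
```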
@@ -278,9 +282,12 @@ def forward(
hidden_states: torch.Tensor,
) -> torch.Tensor:

query = self.q_proj(hidden_states)[0]
key = self.k_proj(hidden_states)[0]
value = self.v_proj(hidden_states)[0]
qkv, _ = self.qkv_proj(hidden_states)
query, key, value = qkv.split([
self.num_heads * self.head_dim, self.num_key_value_heads *
self.head_dim, self.num_key_value_heads * self.head_dim
],
dim=-1)

if self.rotary_emb is not None:
query, key = self.rotary_emb(positions, query, key)
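The split above simply slices the fused qkv_proj output along the last dimension into this rank's query, key, and value chunks. A standalone illustration with made-up sizes:

```python
import torch

# Made-up per-rank sizes, for illustration only.
num_heads, num_kv_heads, head_dim, num_tokens = 8, 2, 128, 4
qkv = torch.randn(num_tokens, (num_heads + 2 * num_kv_heads) * head_dim)

# Slice the fused projection into query, key and value chunks.
query, key, value = qkv.split(
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
    dim=-1)
assert query.shape == (num_tokens, num_heads * head_dim)                   # (4, 1024)
assert key.shape == value.shape == (num_tokens, num_kv_heads * head_dim)   # (4, 256)
```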
@@ -401,6 +408,12 @@ def forward(

def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
]
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()

@@ -411,6 +424,15 @@ def _load(n, p):
weight_loader(param, p)
loaded_params.add(n)

def _load_shard(n, p, shard_id):
# Skip layers on other devices.
if not is_pp_missing_parameter(n, self):
param = params_dict[n]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, p, shard_id)
loaded_params.add(n)

def _load_expert(n, p, name, shard_id, expert_id):
param = params_dict[n]
weight_loader = getattr(param, "weight_loader",
@@ -465,15 +487,29 @@ def _load_expert(n, p, name, shard_id, expert_id):
".block_sparse_moe.gate.weight")
_load(gate_name, p)
else:
_load(n, p)
loaded = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name in n:
_load_shard(n.replace(weight_name, param_name),
p,
shard_id=shard_id)
loaded = True
if not loaded:
_load(n, p)
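The stacked_params_mapping loop follows the convention used throughout vLLM's model files: checkpoints still store separate q_proj, k_proj, and v_proj tensors, so their names are rewritten to the fused qkv_proj parameter, and the shard id tells the weight loader which slice to fill. A rough illustration of the name rewriting, using a hypothetical checkpoint entry:

```python
stacked_params_mapping = [
    (".qkv_proj", ".q_proj", "q"),
    (".qkv_proj", ".k_proj", "k"),
    (".qkv_proj", ".v_proj", "v"),
]

# Hypothetical checkpoint name, for illustration only.
name = "model.layers.0.self_attn.k_proj.weight"

for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        fused_name = name.replace(weight_name, param_name)
        print(fused_name, shard_id)
        # -> model.layers.0.self_attn.qkv_proj.weight k
        # The weight loader copies this tensor into the "k" slice of qkv_proj;
        # names that match no entry (e.g. o_proj) fall through to plain _load().
```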
Comment on lines +490 to +498 (Contributor, severity: medium):

The loaded = False variable is initialized but might not be updated in all branches of the else block. If none of the if weight_name in n: conditions are met, loaded will remain False, and _load(n, p) will be called. However, if any of the conditions are met, loaded is set to True, and _load(n, p) is skipped. It's better to ensure that loaded is correctly set in all branches to avoid unexpected behavior. Consider adding a default else clause within the for loop to set loaded = False explicitly.

Suggested change (the original lines, followed by the reviewer's proposed replacement):

loaded = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name in n:
_load_shard(n.replace(weight_name, param_name),
p,
shard_id=shard_id)
loaded = True
if not loaded:
_load(n, p)

loaded = False
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name in n:
_load_shard(n.replace(weight_name, param_name),
p,
shard_id=shard_id)
loaded = True
else:
loaded = False # Add this line
if not loaded:

return loaded_params


class GraniteMoeHybridForCausalLM(nn.Module, HasInnerState, SupportsLoRA,
SupportsPP, IsHybrid, SupportsV0Only,
SupportsQuant):
packed_modules_mapping = {}
packed_modules_mapping = {
"qkv_proj": [
"q_proj",
"k_proj",
"v_proj",
],
}
embedding_modules = {
"embed_tokens": "input_embeddings",
"lm_head": "output_embeddings",