
Commit c344e24

wwl2755 authored and committed, with ywang96 and DarkLight1337 as co-authors
[Bug] [Spec Decode] Fix model_initialization test and mismatch in aux_hidden_layers (vllm-project#24613)
Signed-off-by: wwl2755 <[email protected]>
Signed-off-by: Roger Wang <[email protected]>
Signed-off-by: Cyrus Leung <[email protected]>
Co-authored-by: Roger Wang <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent 51e58a6 commit c344e24

File tree: 3 files changed, +32 −11 lines

tests/models/registry.py

Lines changed: 16 additions & 7 deletions

@@ -97,6 +97,12 @@ class _HfExamplesInfo:
     max_num_seqs: Optional[int] = None
     """Maximum number of sequences to be processed in a single iteration."""
 
+    use_original_num_layers: bool = False
+    """
+    If True, use the original number of layers from the model config
+    instead of minimal layers for testing.
+    """
+
     def check_transformers_version(
         self,
         *,
@@ -597,18 +603,21 @@ def check_available_online(
    "EagleDeepSeekMTPModel": _HfExamplesInfo("eagle618/deepseek-v3-random",
                                             speculative_model="eagle618/eagle-deepseek-v3-random",  # noqa: E501
                                             trust_remote_code=True),
-   "EagleLlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE-LLaMA3-Instruct-8B",
+   "EagleLlamaForCausalLM": _HfExamplesInfo("meta-llama/Meta-Llama-3-8B-Instruct",  # noqa: E501
                                             trust_remote_code=True,
                                             speculative_model="yuhuili/EAGLE-LLaMA3-Instruct-8B",
-                                            tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"),  # noqa: E501
-   "Eagle3LlamaForCausalLM": _HfExamplesInfo("yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",  # noqa: E501
+                                            tokenizer="meta-llama/Meta-Llama-3-8B-Instruct"),  # noqa: E501
+   "Eagle3LlamaForCausalLM": _HfExamplesInfo("meta-llama/Llama-3.1-8B-Instruct",  # noqa: E501
                                             trust_remote_code=True,
-                                            speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
-                                            tokenizer="meta-llama/Llama-3.1-8B-Instruct"),
-   "LlamaForCausalLMEagle3": _HfExamplesInfo("AngelSlim/Qwen3-8B_eagle3",  # noqa: E501
+                                            speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",  # noqa: E501
+                                            tokenizer="meta-llama/Llama-3.1-8B-Instruct",
+                                            use_original_num_layers=True,
+                                            max_model_len=10240),
+   "LlamaForCausalLMEagle3": _HfExamplesInfo("Qwen/Qwen3-8B",  # noqa: E501
                                             trust_remote_code=True,
                                             speculative_model="AngelSlim/Qwen3-8B_eagle3",  # noqa: E501
-                                            tokenizer="Qwen/Qwen3-8B"),
+                                            tokenizer="Qwen/Qwen3-8B",
+                                            use_original_num_layers=True),
    "EagleLlama4ForCausalLM": _HfExamplesInfo(
        "morgendave/EAGLE-Llama-4-Scout-17B-16E-Instruct",
        trust_remote_code=True,
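(For context: EAGLE-3 drafting reads auxiliary hidden states from specific target-model layer indices, so a test that truncates the target to a single dummy layer leaves those aux_hidden_layers indices pointing at layers that no longer exist; the commit title suggests this is the mismatch being fixed. Below is a minimal, hypothetical sketch of what the new flag expresses; ExampleInfo is a trimmed stand-in for _HfExamplesInfo, not the real class.)

# Hypothetical sketch: a trimmed stand-in for _HfExamplesInfo showing how
# the new use_original_num_layers flag travels with a registry entry.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ExampleInfo:
    default: str                             # target model repo
    speculative_model: Optional[str] = None  # draft (EAGLE) model repo
    tokenizer: Optional[str] = None
    use_original_num_layers: bool = False    # keep the checkpoint's real depth

entry = ExampleInfo(
    "meta-llama/Llama-3.1-8B-Instruct",
    speculative_model="yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
    tokenizer="meta-llama/Llama-3.1-8B-Instruct",
    use_original_num_layers=True,
)

# EAGLE-3 entries keep all layers so the draft's hidden-state tap points
# still exist; other entries are shrunk to one layer for speed.
print("shrink to 1 layer:", not entry.use_original_num_layers)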

tests/models/test_initialization.py

Lines changed: 4 additions & 1 deletion

@@ -36,7 +36,10 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
 
     hf_overrides_fn = partial(dummy_hf_overrides,
                               model_arch=model_arch,
-                              exist_overrides=model_info.hf_overrides)
+                              exist_overrides=model_info.hf_overrides,
+                              use_original_num_layers=getattr(
+                                  model_info, 'use_original_num_layers',
+                                  False))
 
     # Avoid calling model.forward()
     def _initialize_kv_caches_v0(self) -> None:
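(The test threads the flag through functools.partial, and getattr with a False default keeps registry entries that predate the field working unchanged. A runnable sketch of that pattern follows; dummy_hf_overrides here is a simplified stand-in for the real helper in tests/models/utils.py.)

from functools import partial
from types import SimpleNamespace

def dummy_hf_overrides(config: dict, *, model_arch: str = "",
                       use_original_num_layers: bool = False) -> dict:
    # Simplified stand-in: shrink the model unless told to keep its depth.
    if not use_original_num_layers:
        config["num_hidden_layers"] = 1
    return config

model_info = SimpleNamespace(use_original_num_layers=True)  # e.g. an EAGLE-3 entry
hf_overrides_fn = partial(
    dummy_hf_overrides,
    model_arch="Eagle3LlamaForCausalLM",
    # The False default tolerates entries that lack the new field.
    use_original_num_layers=getattr(model_info, "use_original_num_layers", False),
)

print(hf_overrides_fn({"num_hidden_layers": 32}))  # {'num_hidden_layers': 32}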

tests/models/utils.py

Lines changed: 12 additions & 3 deletions

@@ -396,6 +396,7 @@ def dummy_hf_overrides(
     *,
     model_arch: str = "",
     exist_overrides: Optional[dict[str, Any]] = None,
+    use_original_num_layers: bool = False,
 ) -> PretrainedConfig:
     """
     Dummy HF overrides function used to create dummy model
@@ -412,10 +413,18 @@ def dummy_hf_overrides(
 
     # we use three layers for Gemma-3n to check
     # both normal layer and kv_shared_layer
-    num_hidden_layers = (3 if model_arch == "Gemma3nForConditionalGeneration"
-                         else 1)
+    if use_original_num_layers:
+        # Use the original number of layers from the config
+        num_layers = getattr(text_config, 'num_layers', 1)
+        num_hidden_layers = getattr(text_config, 'num_hidden_layers', 1)
+    else:
+        # Use minimal layers for testing
+        num_layers = 1
+        num_hidden_layers = (3 if model_arch
+                             == "Gemma3nForConditionalGeneration" else 1)
+
     text_config.update({
-        "num_layers": 1,
+        "num_layers": num_layers,
         "num_hidden_layers": num_hidden_layers,
         "num_experts": num_experts,
         "num_experts_per_tok": 2,
