
Conversation

@nopperl (Contributor) commented Aug 31, 2025

Purpose

This PR follows the great work by @heheda12345 and @tdoublep to finally enable the v1 engine and full CUDA graphs (decode only) for the Plamo2 model architecture.
It also incorporates other recent improvements to MambaMixer2 (such as #21075 and #18218).

Additionally, this PR fixes #22999 and consequently re-enables Plamo2 in the hybrid model unit tests. I also enabled Plamo2 for the full CUDA graph unit tests.

Note: To support (piecewise) CUDA graphs and torch.compile, the Plamo2MambaMixer is added to CompilationConfig._attention_ops by default. I think it's OK to do since the diff is minimal.

This PR is potentially in conflict with #21467 (cc @cyang49).

Fixes #23956.
Incorporates #23520 to fix #22999.
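
For a quick functional check of the new path, here is a minimal smoke-test sketch using the public vLLM offline API (the model name and settings mirror the benchmark commands below; the prompt and sampling values are just illustrative, not part of this PR's diff):

```python
# Minimal smoke-test sketch: load PLaMo2 on the v1 engine and generate a short
# completion. Settings mirror the benchmark commands in this PR description.
from vllm import LLM, SamplingParams

llm = LLM(
    model="pfnet/plamo-2-8b",
    trust_remote_code=True,
    max_model_len=8192,
    enable_prefix_caching=False,
)
outputs = llm.generate(
    ["The capital of Japan is"],
    SamplingParams(max_tokens=16, temperature=0.0),
)
print(outputs[0].outputs[0].text)
```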

Benchmarks

Latency

Latency improves noticeably: average end-to-end latency drops from ~30.8 s to ~26.4 s (roughly 14% faster).

This PR (v1 engine, CUDAGraphMode.FULL_AND_PIECEWISE):

$ VLLM_USE_V1=1 vllm bench latency --model pfnet/plamo-2-8b --trust-remote-code --max-model-len 8192 --no-enable-prefix-caching --input-len 1000 --output-len 1000 --num-iters 3 --num-iters-warmup 3

Avg latency: 26.398034774387877 seconds
10% percentile latency: 26.389116884209216 seconds
25% percentile latency: 26.39132913807407 seconds
50% percentile latency: 26.395016227848828 seconds
75% percentile latency: 26.403231137432158 seconds
90% percentile latency: 26.408160083182157 seconds
99% percentile latency: 26.411117450632155 seconds

main (v0 engine, piecewise):

Avg latency: 30.82136501589169 seconds
10% percentile latency: 30.770767986914144 seconds
25% percentile latency: 30.776328965090215 seconds
50% percentile latency: 30.781959515297785 seconds
75% percentile latency: 30.786822279333137 seconds
90% percentile latency: 30.853961815405636 seconds
99% percentile latency: 31.446262379428372 seconds
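
For reference, a minimal sketch (not the actual `vllm bench latency` internals) of how these summary statistics can be reproduced from per-iteration latencies:

```python
# Sketch only: reproduce the avg/percentile summary from a list of measured
# per-iteration end-to-end latencies (the values below are hypothetical).
import numpy as np

latencies = [26.39, 26.40, 26.41]  # seconds, one entry per benchmark iteration
print(f"Avg latency: {np.mean(latencies)} seconds")
for p in (10, 25, 50, 75, 90, 99):
    print(f"{p}% percentile latency: {np.percentile(latencies, p)} seconds")
```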

Throughput

Throughput is essentially unchanged (12.00 vs. 11.83 requests/s, about 1.4% higher).

This PR (v1 engine, CUDAGraphMode.FULL_AND_PIECEWISE):

$ VLLM_USE_V1=1 vllm bench throughput --model pfnet/plamo-2-8b --trust-remote-code --max-model-len 8192 --no-enable-prefix-caching --dataset-name sharegpt --dataset-path /datasets/ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 1000

Throughput: 12.00 requests/s, 4619.67 total tokens/s, 2144.02 output tokens/s

main (v0 engine, piecewise):

Throughput: 11.83 requests/s, 4532.89 total tokens/s, 2104.20 output tokens/s
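
The throughput summary line is just counts divided by wall-clock time; a hedged sketch with purely hypothetical numbers:

```python
# Sketch only (not the vllm bench throughput internals): derive the summary
# line from request/token counts and elapsed time. All values are hypothetical.
num_requests = 1000
total_tokens = 384_860   # prompt + output tokens across all requests (hypothetical)
output_tokens = 178_600  # output tokens only (hypothetical)
elapsed_s = 83.3         # wall-clock benchmark duration in seconds (hypothetical)

print(
    f"Throughput: {num_requests / elapsed_s:.2f} requests/s, "
    f"{total_tokens / elapsed_s:.2f} total tokens/s, "
    f"{output_tokens / elapsed_s:.2f} output tokens/s"
)
```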

Output quality

No output degradation: GSM8K scores match the v0 baseline within standard error.

This PR (v1 engine):

$ VLLM_USE_V1=1 lm_eval --model vllm  --model_args pretrained=pfnet/plamo-2-8b,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,enable_prefix_caching=False --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5163|±  |0.0138|
|     |       |strict-match    |     5|exact_match|↑  |0.6603|±  |0.0130|

main (v0 engine, piecewise):

$ VLLM_USE_V1=0 lm_eval --model vllm  --model_args pretrained=pfnet/plamo-2-8b,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,enable_prefix_caching=False --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5125|±  |0.0138|
|     |       |strict-match    |     5|exact_match|↑  |0.6581|±  |0.0131|
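
To back the "no degradation" claim, a small check (not part of lm_eval) that the GSM8K deltas are well within the reported standard errors:

```python
# Sanity check: the per-filter deltas between this PR (v1) and main (v0) are
# much smaller than the combined standard error, i.e. consistent with noise.
v1 = {"flexible-extract": (0.5163, 0.0138), "strict-match": (0.6603, 0.0130)}
v0 = {"flexible-extract": (0.5125, 0.0138), "strict-match": (0.6581, 0.0131)}

for flt in v1:
    (a, sa), (b, sb) = v1[flt], v0[flt]
    combined_se = (sa**2 + sb**2) ** 0.5
    print(f"{flt}: delta={a - b:+.4f}, combined stderr={combined_se:.4f}")
```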

Test Result

pytest -s -v tests/models/language/generation/test_hybrid.py passed.



@gemini-code-assist bot left a comment
Code Review

This pull request is a significant and well-executed effort to enable v1 engine and full CUDA graph support for the Plamo2 model architecture. The changes are comprehensive, including updates to the model implementation, test configurations, and documentation. The refactoring of Plamo2MambaMixer to use CustomOp and the v1-style state management is particularly well done. I've identified one critical bug in the handling of state_indices_tensor that could lead to an out-of-bounds error in mixed-batch scenarios. Addressing this issue should make the implementation robust.
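
To make the concern concrete, here is an illustrative example (purely hypothetical, not the PR's actual code) of the kind of slicing involved: in a mixed prefill+decode batch, the decode-only path should only see the decode portion of state_indices_tensor.

```python
# Illustrative only -- not vLLM code. state_indices_tensor holds one mamba state
# slot index per request in the batch; with mixed prefill+decode batches, the
# decode kernel must receive only the decode slice (assuming decode requests are
# ordered first, which is an assumption here). Indexing the kernel with the full
# tensor could run past the number of decode slots and go out of bounds.
import torch

state_indices_tensor = torch.tensor([5, 9, 2, 7])  # hypothetical slot indices
num_decodes = 2                                     # first two requests are decodes (assumption)
decode_state_indices = state_indices_tensor[:num_decodes]
print(decode_state_indices)  # tensor([5, 9])
```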

@nopperl changed the title from "Plamo2 v1 cudagraph" to "[V1] v1 engine + full CUDA graph support for PLaMo2" Aug 31, 2025
@tdoublep (Member) left a comment
Thanks for the great work - I have a few tiny comments.

Do you have any lm_eval results comparing V0 to V1, just so we are confident about correctness? Never mind, I see you included them in the PR description.

@tdoublep (Member) left a comment
LGTM - @nopperl can you please fix the DCO issue?

@mergify bot added the needs-rebase label Sep 3, 2025
@nopperl force-pushed the plamo2-v1-cudagraph branch from 8060211 to bdfbd8b on September 3, 2025 07:22
@mergify bot removed the tpu (Related to Google TPUs) and needs-rebase labels Sep 3, 2025
@nopperl force-pushed the plamo2-v1-cudagraph branch from 07e358d to ad8a800 on September 3, 2025 07:30
@tdoublep enabled auto-merge (squash) September 3, 2025 09:40
@github-actions bot added the ready label (ONLY add when PR is ready to merge/full CI is needed) Sep 3, 2025
@vllm-bot merged commit fa4311d into vllm-project:main Sep 3, 2025
48 of 50 checks passed
@cyang49 (Contributor) commented Sep 3, 2025

@nopperl I tried to run pfnet/plamo-2.1-2b-cpt but it failed with this message. Is this normal?

Command

$ VLLM_USE_V1=1 lm_eval --model vllm  --model_args pretrained=pfnet/plamo-2.1-2b-cpt,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.95,enable_prefix_caching=False --batch_size auto --trust_remote_code  --cache_requests true --tasks gsm8k
(EngineCore_0 pid=2991797) Process EngineCore_0:
(EngineCore_0 pid=2991797) Traceback (most recent call last):
(EngineCore_0 pid=2991797)   File "/net/storage149/mnt/md0/ccyang/miniforge3/envs/vllm_src_torch_nightly/lib/python3.12/multiprocessing/process.py", line 314, in _bootstrap
(EngineCore_0 pid=2991797)     self.run()
(EngineCore_0 pid=2991797)   File "/net/storage149/mnt/md0/ccyang/miniforge3/envs/vllm_src_torch_nightly/lib/python3.12/multiprocessing/process.py", line 108, in run
(EngineCore_0 pid=2991797)     self._target(*self._args, **self._kwargs)
(EngineCore_0 pid=2991797)   File "/net/storage149/mnt/md0/ccyang/github.com/vllm/vllm/v1/engine/core.py", line 716, in run_engine_core
(EngineCore_0 pid=2991797)     raise e
(EngineCore_0 pid=2991797)   File "/net/storage149/mnt/md0/ccyang/github.com/vllm/vllm/v1/engine/core.py", line 703, in run_engine_core
(EngineCore_0 pid=2991797)     engine_core = EngineCoreProc(*args, **kwargs)
(EngineCore_0 pid=2991797)                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(EngineCore_0 pid=2991797)   File "/net/storage149/mnt/md0/ccyang/github.com/vllm/vllm/v1/engine/core.py", line 504, in __init__
(EngineCore_0 pid=2991797)     super().__init__(vllm_config, executor_class, log_stats,
(EngineCore_0 pid=2991797)   File "/net/storage149/mnt/md0/ccyang/github.com/vllm/vllm/v1/engine/core.py", line 90, in __init__
(EngineCore_0 pid=2991797)     self._initialize_kv_caches(vllm_config)
(EngineCore_0 pid=2991797)   File "/net/storage149/mnt/md0/ccyang/github.com/vllm/vllm/v1/engine/core.py", line 192, in _initialize_kv_caches
(EngineCore_0 pid=2991797)     get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
(EngineCore_0 pid=2991797)   File "/net/storage149/mnt/md0/ccyang/github.com/vllm/vllm/v1/core/kv_cache_utils.py", line 1119, in get_kv_cache_config
(EngineCore_0 pid=2991797)     raise NotImplementedError
(EngineCore_0 pid=2991797) NotImplementedError

I also printed the kv_cache_spec before the error:

(EngineCore_0 pid=2991797) kv_cache_spec={
  'model.layers.layers.{1,3,5,...,31}.mixer.attn': SlidingWindowSpec(block_size=1072, num_kv_heads=4, head_size=128, dtype=torch.bfloat16, use_mla=False, sliding_window=32768),
  'model.layers.layers.{0,2,4,...,30}.mixer': MambaSpec(block_size=10485760, shapes=((3, 8192), (64, 128, 64)), dtypes=(torch.bfloat16, torch.bfloat16), page_size_padded=None, mamba_type='mamba2'),
}
(condensed: all 16 odd-numbered attention layers share the same SlidingWindowSpec and all 16 even-numbered mamba layers share the same MambaSpec)
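
A hypothetical inspection helper (not part of vLLM) that summarizes such a kv_cache_spec dict makes the layer split easier to see:

```python
# Hypothetical helper, not vLLM code: count layers per spec type in a dumped
# kv_cache_spec dict like the one above.
from collections import defaultdict

def summarize(kv_cache_spec: dict) -> dict:
    groups = defaultdict(int)
    for spec in kv_cache_spec.values():
        groups[type(spec).__name__] += 1
    return dict(groups)

# For the dump above this reports 16 SlidingWindowSpec (attention) layers and
# 16 MambaSpec (mamba2) layers, with very different block sizes.
```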

@nopperl (Contributor, Author) commented Sep 4, 2025

@cyang49 thanks for reporting! There seems to be an issue with the HF version of PLaMo2. I have created an issue to track this: #24204

@nopperl (Contributor, Author) commented Sep 4, 2025

@cyang49 the issue with PLaMo2.1 is fixed now!

eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
Signed-off-by: Hemmi Shinichi <[email protected]>
Signed-off-by: nopperl <[email protected]>
Co-authored-by: Hemmi Shinichi <[email protected]>
Co-authored-by: Thomas Parnell <[email protected]>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
Signed-off-by: Hemmi Shinichi <[email protected]>
Signed-off-by: nopperl <[email protected]>
Co-authored-by: Hemmi Shinichi <[email protected]>
Co-authored-by: Thomas Parnell <[email protected]>
Projects
Status: Done
Development

Successfully merging this pull request may close these issues.

[Feature]: Support Plamo2 Model in V1
[Bug]: plamo2 broken on main using transformers==4.55.0
5 participants