
Conversation

tdoublep
Member

@tdoublep tdoublep commented Aug 3, 2025

Essential Elements of an Effective PR Description Checklist

  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.

Purpose

This PR enables Minimax-Text-01 in vLLM V1. Currently only eager mode is supported.

I've tried to keep the changes minimal. I hope this PR can serve as a reference/template for how to enable other non-mamba2-based hybrid models in V1.

Test Plan

Deploy using V0 from main:

vllm serve MiniMaxAI/MiniMax-Text-01 \
	--tensor-parallel-size 8 \
	--trust-remote-code \
	--quantization experts_int8  \
	--max_model_len 4096 \
	--dtype bfloat16 \
	--gpu-memory-utilization 0.95

Deploy using V1 from this PR:

VLLM_USE_V1=1 VLLM_ATTENTION_BACKEND=FLASHINFER vllm serve MiniMaxAI/MiniMax-Text-01 \
	--tensor-parallel-size 8 \
	--trust-remote-code \
	--quantization experts_int8  \
	--max_model_len 4096 \
	--dtype bfloat16 \
	--gpu-memory-utilization 0.95 \
	--no-enable-prefix-caching \
	--enforce-eager

Evaluation:

lm_eval   --model local-completions   \
	--model_args base_url=http://localhost:8000/v1/completions,tokenizer=MiniMaxAI/MiniMax-Text-01 \
	--tasks gsm8k  \
	--batch_size 128 \
	--num_fewshot 5 \
	--limit 500

Test Result

V0:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.898|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.892|±  |0.0139|

V1:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.890|±  |0.0140|
|     |       |strict-match    |     5|exact_match|↑  |0.886|±  |0.0142|

(Optional) Documentation Update


github-actions bot commented Aug 3, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Aug 3, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for the Minimax-Text-01 model in V1, which is a significant step for expanding hybrid model compatibility. The changes correctly adapt the model to the V1 architecture, including updates to the attention mechanisms and KV cache handling. The test suite has also been updated to include this new model, which is great to see.

My review identified a few areas for improvement, primarily leftover debugging statements that should be removed before merging. Additionally, a newly added utility method has an incorrect implementation and documentation that need to be addressed to maintain the quality of the codebase.

@qscqesze
Contributor

qscqesze commented Aug 4, 2025

I'm very interested in the tiny model. Would you be able to share it?

@tdoublep
Member Author

tdoublep commented Aug 4, 2025

@NickLucche
Collaborator

I think you forgot to paste the gsm8k V1 results in the description, or I am just getting ahead of myself since it's still a draft, sorry in that case :)

@tdoublep
Member Author

tdoublep commented Aug 4, 2025

@NickLucche yeah haha, I didn't paste them because they don't look good yet lol. Still debugging it

@tdoublep
Member Author

tdoublep commented Aug 4, 2025

It seems to be specifically an issue when the blocks start to get recycled. I see good lm_eval results until blocks start running out and getting re-used. To me this suggests some issue with the way the state is being reset.
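
To illustrate what I mean with a toy sketch (this is not vLLM's actual block pool code, just the failure mode I suspect): if a freed state block is not reset before being handed to the next request, that request silently starts from another request's recurrent state, which would explain why results only degrade once blocks start being re-used.

import torch

class StatePool:
    """Toy pool of per-block recurrent states (hypothetical, for illustration)."""

    def __init__(self, num_blocks: int, state_dim: int):
        self.states = torch.zeros(num_blocks, state_dim)
        self.free_blocks = list(range(num_blocks))

    def allocate(self) -> int:
        # Hand out the most recently freed block (LIFO free list).
        return self.free_blocks.pop()

    def free(self, block_id: int, reset: bool = True) -> None:
        if reset:
            # Without this zeroing, the recycled block leaks the previous
            # request's linear-attention state into the next request.
            self.states[block_id].zero_()
        self.free_blocks.append(block_id)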

Signed-off-by: Thomas Parnell <[email protected]>
@tdoublep tdoublep marked this pull request as ready for review August 4, 2025 21:34
@tdoublep
Member Author

tdoublep commented Aug 4, 2025

@NickLucche Fixed the problem and updated above :)

@tdoublep
Member Author

tdoublep commented Aug 4, 2025

BTW the problem was related to trying to use fp32 for the linear attention state, but fp16 for the normal attention state. This is what happens on V0 right now, but didn't seem to work when I tried it in V1. Not sure if this is some inherent limitation, will think more on that.
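
Just to make the size difference concrete (a toy calculation with made-up shapes, not vLLM's layout code, and I'm not claiming this is the actual mechanism): the same state block needs twice the bytes in fp32 as in fp16, which is the kind of mismatch the page-size accounting has to absorb.

import torch

def page_size_bytes(num_elements: int, dtype: torch.dtype) -> int:
    # Bytes needed to hold num_elements values of the given dtype.
    return num_elements * torch.tensor([], dtype=dtype).element_size()

elements_per_block = 16 * 4096  # made-up block_size * state width, for illustration
print(page_size_bytes(elements_per_block, torch.float16))  # 131072 bytes
print(page_size_bytes(elements_per_block, torch.float32))  # 262144 bytes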

Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
Signed-off-by: Thomas Parnell <[email protected]>
@rogeryoungh

I have tested your pull request: the accuracy on GSM8k is 0.920, and the average accuracy on MMLU is 0.846.

Deployment command:

python3 -m vllm.entrypoints.api_server --model /data/xxx/model/MiniMax-Text-01/ --tensor-parallel-size 8 --trust-remote-code --quantization experts_int8 --max_model_len 4096 --dtype bfloat16 --no-enable-prefix-caching --enforce-eager

I used the sglang benchmark scripts. The detailed results are as follows:

GSM8k test:

python3 bench_other.py --num-questions 500 --num-shots 5 --backend vllm --port 8000 --host http://127.0.0.1
# ...
Accuracy: 0.920
Invalid: 0.000
Latency: 155.861 s

MMLU test:

python3 bench_other.py --nsub 200 --backend vllm --port 8000 --host http://127.0.0.1
# ...
Total latency: 1511.974
Average accuracy: 0.846

Collaborator

@NickLucche NickLucche left a comment


lgtm, but let's wait for someone more familiar with the model to chime in.
It's a bit of a pity we can't pair this with a v0 deprecation task and just get rid of the v0 branching.

@tdoublep
Member Author

tdoublep commented Aug 6, 2025

@NickLucche Re: V0 code, I agree in general. There will still be a performance gap to V0 until we merge #21401, so I'm not sure we should remove V0 code until performance is at least matching. I think we may also want to wait until hybrid models are supported using FlashAttention (e.g., after #21549 is merged).


mergify bot commented Aug 7, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @tdoublep.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot removed the needs-rebase label Aug 7, 2025
@tdoublep tdoublep mentioned this pull request Aug 7, 2025
@youkaichao
Member

cc @heheda12345

@qscqesze
Contributor

qscqesze commented Aug 8, 2025

I think this code is fine. Great job! Looking forward to its early merge.

Collaborator

@heheda12345 heheda12345 left a comment


LGTM! But can you move all ssm-related attention backends to v1/attention/backends/mamba? Either in this PR or a follow-up PR is OK for me.

And let's do more refactoring in the future to make these integrations more pluggable.

@heheda12345 heheda12345 enabled auto-merge (squash) August 8, 2025 23:12
@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Aug 8, 2025
@heheda12345
Collaborator

And thanks to the MiniMax team for verifying the correctness of this PR.

@vllm-bot vllm-bot merged commit 6ade99e into vllm-project:main Aug 9, 2025
47 of 56 checks passed
Member

@DarkLight1337 DarkLight1337 left a comment


Thanks, can you open a follow-up PR to update the docs accordingly?

@zhiyuan1i

> BTW the problem was related to trying to use fp32 for the linear attention state, but fp16 for the normal attention state. This is what happens on V0 right now, but didn't seem to work when I tried it in V1. Not sure if this is some inherent limitation, will think more on that.

Thank you for the elegant design and meticulous engineering in this work—your code has already become the go-to reference for many of us.

I’m currently aligning our in-house, multi-branch hybrid model for an upcoming sync back to main. We also maintain a dual-precision linear cache (some tensors stay in bf16 while others stay in fp32), so your implementation will be an invaluable blueprint for how to handle the mixed-precision paths cleanly.

I believe the limitation exists because of this code:
vllm/v1/kv_cache_interface.py

@dataclass(frozen=True)
class MambaSpec(KVCacheSpec):
    shapes: tuple[tuple[int, ...], ...]
    dtype: torch.dtype
    page_size_padded: Optional[int] = None
    mamba_type: str = "mamba2"

    @property
    def page_size_bytes(self) -> int:
        num_elements = sum(prod(shape) for shape in self.shapes)
        page_size = num_elements * get_dtype_size(self.dtype)
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        # We allocate 1 block for each request now, so max_memory_usage_bytes is
        # the same as page_size_bytes.
        # Need to update this when supporting prefix caching.
        return self.page_size_bytes

This code only considers a single dtype for all of the state tensors, right?
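
As a hypothetical sketch of what I mean (not vLLM's actual API, just an illustration), expressing our bf16/fp32 mix would need something like one dtype per state tensor when computing the page size:

from dataclasses import dataclass
from math import prod
from typing import Optional

import torch

@dataclass(frozen=True)
class MixedDtypeMambaSpec:
    # Hypothetical variant of MambaSpec: one dtype per state tensor instead of
    # a single shared dtype, so e.g. the linear-attention state could stay in
    # fp32 while other tensors stay in bf16.
    shapes: tuple[tuple[int, ...], ...]
    dtypes: tuple[torch.dtype, ...]
    page_size_padded: Optional[int] = None

    @property
    def page_size_bytes(self) -> int:
        page_size = sum(
            prod(shape) * torch.tensor([], dtype=dtype).element_size()
            for shape, dtype in zip(self.shapes, self.dtypes)
        )
        if self.page_size_padded is not None:
            assert self.page_size_padded >= page_size
            return self.page_size_padded
        return page_size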

paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
yiliu30 pushed a commit to yiliu30/vllm-fork that referenced this pull request Aug 19, 2025
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
xiao-llm pushed a commit to xiao-llm/vllm that referenced this pull request Aug 28, 2025
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
10 participants