Feature/spec decode draft model #24322
Conversation
👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI, which covers only a small and essential subset of CI tests to quickly catch errors. You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces support for speculative decoding using a draft model. The changes are comprehensive, touching configuration, model loading, scheduling, and the core speculative decoding logic. New tests and benchmark modifications are also included to validate and measure the new feature. The overall implementation appears solid. However, I've identified a critical issue in a refactoring of the `bind_kv_cache` utility function, which removes an important safety check and could lead to incorrect behavior for certain model architectures.
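To make the concern concrete, here is a minimal, hypothetical sketch of the kind of guard being referred to. The names and structure are illustrative assumptions, not vLLM's actual `bind_kv_cache` code:

```python
from typing import Dict, List

import torch


def bind_kv_cache_sketch(
    layer_to_cache_index: Dict[str, int],
    kv_caches: List[torch.Tensor],
    forward_context: Dict[str, torch.Tensor],
) -> None:
    """Attach each pre-allocated KV-cache tensor to the layer that owns it.

    A mis-built mapping that binds two layers to the same cache slot would
    silently alias their caches; an explicit check surfaces that instead.
    """
    seen: Dict[int, str] = {}
    for layer_name, cache_idx in layer_to_cache_index.items():
        if cache_idx in seen:
            raise ValueError(
                f"KV-cache slot {cache_idx} is already bound to "
                f"{seen[cache_idx]}; refusing to also bind {layer_name}"
            )
        seen[cache_idx] = layer_name
        forward_context[layer_name] = kv_caches[cache_idx]
```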
@tomasruizt - Thank you for the PR!
What is the TP you are using for Qwen3-32B? By default, the draft model TP is equal to the target model TP. Since Qwen3-1.7B is a small model, running it at high TP might be incurring NCCL communication cost. Try setting the draft TP to 1.
I ran the benchmarks with TP=1 and num_draft_tokens=3, so we can rule out TP communication issues.
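As an aside, the draft-model TP discussed above can be pinned independently of the target model's TP through the speculative config. A sketch, assuming the `draft_tensor_parallel_size` field (verify the exact key against the SpeculativeConfig in your vLLM version):

```python
from vllm import LLM

# Sketch: target model sharded across 4 GPUs, draft model kept on a single GPU.
# "draft_tensor_parallel_size" is assumed here; check SpeculativeConfig for the
# exact field name in your vLLM version.
llm = LLM(
    model="Qwen/Qwen3-32B",
    tensor_parallel_size=4,
    speculative_config={
        "model": "Qwen/Qwen3-1.7B",
        "method": "draft_model",
        "num_speculative_tokens": 3,
        "draft_tensor_parallel_size": 1,
    },
)
```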
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from 7de2ae1 to 2e0fb65, then from 2e0fb65 to f27261a.
@ekagra-ranjan I updated the metrics with your suggestions. Acceptance lengths look good IMO. After applying CUDA graphs to the draft model, its speed improved dramatically. Throughput with SD is now higher than without it (in particular TPOT).
```diff
@@ -44,18 +45,20 @@ class EagleAttentionMetadata(Protocol):
     slot_mapping: torch.Tensor


-class EagleProposer:
+class SpecDecodeProposer:
```
So maybe this file name should be changed.
Maybe. I feel like a refactoring is needed to cleanly separate EAGLE from "draft_model", to simplify the if-else statements and reduce the mental overhead when reading the code.
But perhaps that should be a follow-up PR to avoid changing too much code in a single PR.
@tomasruizt - The online inference cmd shared in the PR has this flag. Also, please provide the cmd you used to profile the model, for reproducibility.
This pull request has merge conflicts that must be resolved before it can be merged.
Purpose
Enabling draft models for speculative decoding (SD).
E.g. `Qwen3-1.7B` as the draft model and `Qwen3-32B` as the target model. This type of SD requires no specially trained heads (as EAGLE or Medusa do).
Example usage:
```bash
vllm serve \
    --model=Qwen/Qwen3-4B \
    --speculative-config '{"model": "Qwen/Qwen3-0.6B", "method": "draft_model", "num_speculative_tokens": 3, "max-model-len": 2000}' \
    --max-model-len 2000
```
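The same setup can also be sketched with the offline Python API, assuming the `speculative_config` engine argument accepts the same dict as the `--speculative-config` JSON above:

```python
from vllm import LLM, SamplingParams

# Offline sketch mirroring the serve command above; speculative_config is
# assumed to take the same keys as the --speculative-config JSON.
llm = LLM(
    model="Qwen/Qwen3-4B",
    max_model_len=2000,
    speculative_config={
        "model": "Qwen/Qwen3-0.6B",
        "method": "draft_model",
        "num_speculative_tokens": 3,
    },
)
outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
print(outputs[0].outputs[0].text)
```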
Get a generation:
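For example, against the server started above (assuming the default OpenAI-compatible endpoint on port 8000):

```python
from openai import OpenAI

# Assumes `vllm serve` from above is listening on the default port 8000.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
resp = client.completions.create(
    model="Qwen/Qwen3-4B",
    prompt="The quick brown fox",
    max_tokens=64,
    temperature=0.0,
)
print(resp.choices[0].text)
```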
Status
Acceptance Length
As suggested by @ekagra-ranjan, I benchmarked acceptance length (AL) with the command below:
The AL values within the Qwen3 family seem good, both with temperatures of 0.0 (greedy) and 1.0.
As a sanity check, I benchmarked Llama-3.2-1B as both target and draft, which had almost perfect AL (3.97/4), suggesting it's working as intended.
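For context on these numbers: with `num_speculative_tokens=3`, the maximum AL is 4, since each verification step emits the accepted draft tokens plus one token from the target model. A small illustrative sketch of the bookkeeping (not the benchmark's actual code):

```python
def acceptance_length(accepted_per_step: list[int]) -> float:
    """Mean tokens emitted per target forward pass: accepted drafts + 1.

    With num_speculative_tokens=3, a perfect run yields 4.0.
    """
    return sum(a + 1 for a in accepted_per_step) / len(accepted_per_step)


# e.g. all 3 drafts accepted in most steps, occasionally fewer:
print(acceptance_length([3, 3, 2, 3, 3]))  # 3.8
```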
I have not run the default model `meta-llama/Llama-3.1-8B-Instruct`, because I didn't find a good draft model for it, but feel free to suggest one and I can run the benchmarks.

Temperature t=0:
Temperature t=1.0:
Using t=1.0, the AL metric degrades. However, spec-decode with probabilities, which is needed for lossless rejection sampling, is not yet implemented. This is being addressed at the moment in #20459. After that PR is merged, the AL for non-greedy spec-decode should improve.
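For reference, the lossless scheme referred to here is the standard speculative-sampling rejection test: accept a drafted token x with probability min(1, p_target(x) / p_draft(x)), otherwise resample from the normalized residual distribution. A sketch (illustrative, not the #20459 implementation):

```python
import numpy as np

rng = np.random.default_rng(0)


def accept_or_resample(x: int, p_target: np.ndarray, p_draft: np.ndarray) -> int:
    """Rejection test for one draft token x over full-vocabulary distributions.

    Accept x with prob min(1, p_target[x] / p_draft[x]); otherwise draw a
    replacement from the residual max(p_target - p_draft, 0), renormalized.
    """
    if rng.random() < min(1.0, p_target[x] / p_draft[x]):
        return x
    residual = np.maximum(p_target - p_draft, 0.0)
    residual /= residual.sum()
    return int(rng.choice(len(residual), p=residual))
```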
All scripts and logs used for the benchmarks can be found in this Google Drive.
Online Throughput Metrics
I measured online throughput metrics using the commands below. After making sure the draft model also uses CUDA graphs, SD yields higher throughput than running without SD. See the tables below.
The metrics (lower is better) are:
Batch Size = 1
For Temperature = 0.0:

Using SD, runtimes and TPOT are shorter by ~50%.
Batch Size = 100
For Temperature = 0.0:

For Temperature = 1.0:

This scenario with batch size 100 is a more realistic inference case. Using SD, runtimes and TPOT are shorter.
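TPOT here is the usual time-per-output-token, i.e. the mean inter-token latency after the first token; a minimal sketch of how it and TTFT are derived from client-side timestamps (illustrative, not the benchmark script):

```python
def ttft_and_tpot(request_start: float, first_token_time: float,
                  last_token_time: float, num_output_tokens: int) -> tuple[float, float]:
    """TTFT = time to first token; TPOT = mean latency per subsequent token."""
    ttft = first_token_time - request_start
    tpot = (last_token_time - first_token_time) / max(num_output_tokens - 1, 1)
    return ttft, tpot


# 101 output tokens, first after 0.25 s, last after 2.25 s -> TPOT = 20 ms/token
print(ttft_and_tpot(0.0, 0.25, 2.25, 101))  # (0.25, 0.02)
```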
Profiling
This section was removed, since using CUDA graphs on the draft model significantly improved its speed.
Profiling script
I used the command below to profile the generation process and identify that the draft model was previously running too slowly.

Note: the command uses the `--profile` flag, which I introduce in #24575.

Test Plan
The added unit test checks the correctness metrics. To run it:
```bash
cd tests/v1/e2e/
pytest test_spec_decode.py -k test_draft_model_correctness
```
EAGLE testing
I tested that the EAGLE implementation stays unaffected using the command below:
The results are in line with previous measurements like #17504 (comment)
Follow-up Optimizations
Pass `next_token_ids` together with `target_token_ids` in the first forward pass of the draft model. This reduces the number of forward passes needed in each drafting phase by one, speeding up drafting (see the sketch below).
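A rough sketch of the idea (shapes and names are hypothetical, not the actual proposer code): instead of one draft forward pass over the verified target tokens followed by a separate pass for the bonus token, the bonus token is appended so a single pass covers both.

```python
import torch


# Hypothetical illustration of the follow-up optimization, not vLLM's code.
def first_draft_pass(draft_model, target_token_ids: torch.Tensor,
                     next_token_ids: torch.Tensor) -> torch.Tensor:
    # target_token_ids: [batch, seq_len] tokens already verified by the target
    # next_token_ids:   [batch, 1] bonus token sampled from the target model
    combined = torch.cat([target_token_ids, next_token_ids], dim=1)
    logits = draft_model(combined)           # one forward pass instead of two
    return logits[:, -1].argmax(dim=-1)      # first proposed draft token
```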
(Optional) Documentation Update
Essential Elements of an Effective PR Description Checklist
Update `supported_models.md` and `examples` for a new model.