
Conversation

@benchislett benchislett commented Aug 25, 2025

Purpose

This PR allows the model runner to operate asynchronously when async scheduling is enabled, fully overlapping the CPU operations (including prepare_inputs) with the model forward pass. The diff is functional but does not yet support speculative decoding, pipeline parallelism (PP), or guided decoding.

Expected speedup is 5-10% over the current async scheduling.

This PR is pending some light refactoring (see inline comments) and testing but is otherwise ready for review.

Design Analysis

Overview

This PR implements overlapped model execution by allowing the model runner to return a CUDA tensor reference to the sampled token ids instead of the pythonized token ids. The result is passed via a queue to an output worker thread, which blocks until the value is ready and then places the pythonized result on the main output queue. This way, the model runner can run ahead and handle new inputs before the GPU has finished processing the previous iteration, eliminating the CPU overhead of input preparation and sampling from the critical path.
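
As a rough illustration of this pattern, here is a minimal sketch with assumed names (`AsyncOutput`, `copy_stream`, `output_worker`); the PR's actual `AsyncGPUModelRunnerOutput` differs in detail:

```python
import queue
import threading

import torch


class AsyncOutput:
    """Holds a GPU reference to the sampled token ids plus an event that
    fires once their non-blocking copy into pinned CPU memory completes."""

    def __init__(self, sampled_token_ids: torch.Tensor,
                 copy_stream: torch.cuda.Stream):
        self._gpu = sampled_token_ids  # keep the GPU tensor alive until copied
        self._cpu = torch.empty(sampled_token_ids.shape,
                                dtype=sampled_token_ids.dtype,
                                device="cpu", pin_memory=True)
        copy_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(copy_stream):
            self._cpu.copy_(sampled_token_ids, non_blocking=True)
            self._event = torch.cuda.Event()
            self._event.record(copy_stream)

    def get_output(self) -> list[list[int]]:
        self._event.synchronize()  # blocks only the output worker thread
        return self._cpu.tolist()  # pythonize after the copy has landed


def output_worker(pending: queue.Queue, outputs: queue.Queue) -> None:
    """Drains GPU-side results and forwards pythonized ids to the main output queue."""
    while (item := pending.get()) is not None:
        outputs.put(item.get_output())


pending_q: queue.Queue = queue.Queue()
output_q: queue.Queue = queue.Queue()
threading.Thread(target=output_worker, args=(pending_q, output_q),
                 daemon=True).start()
```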

To implement this, output_token_ids and token_ids_cpu are no longer updated after the sampling step. Instead, a reference to the previous sampled_token_ids is kept, and the ids are copied into self.input_ids during the next step's prepare_inputs phase. This means the approach is not compatible with n-gram speculative decoding (which requires the output tokens to be known on the CPU) and will need to be adapted for other speculative decoding methods (which can, with some modification, receive their inputs directly from the GPU tensor).
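
For illustration, the on-device carry-over into the next step's input_ids could look roughly like the sketch below (the names `fill_input_ids_from_prev_step` and `decode_positions` are assumptions, not the PR's exact fields):

```python
import torch


def fill_input_ids_from_prev_step(
    input_ids_gpu: torch.Tensor,           # [max_num_tokens], device tensor
    prev_sampled_token_ids: torch.Tensor,  # [num_decodes, 1], cached from last step
    decode_positions: torch.Tensor,        # [num_decodes], target slots in input_ids_gpu
) -> None:
    # Runs entirely on the GPU: no .item()/.tolist() calls, so prepare_inputs
    # introduces no GPU->CPU synchronization while the previous forward runs.
    input_ids_gpu.index_copy_(0, decode_positions, prev_sampled_token_ids.view(-1))
```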

Compatibility with Key Features

Currently, this PR is not compatible with:

  • Pipeline Parallelism
  • Speculative Decoding
  • Guided Decoding

I expect that for Speculative Decoding and Guided Decoding, the integration will be straightforward.

For Guided Decoding, an open PR #23224 introduces a refactoring of the structured outputs manager into its own process, allowing the filled bitmask to be received directly by the model runner(s) just before it needs to be applied, enabling overlapped computation of the bitmask.

For Speculative Decoding, the work on enabling MLA+MTP in #22684 refactors the speculative decoding runtime to eliminate all GPU->CPU synchronizations. It should therefore integrate nicely into this async execution framework by simply caching the rectangular tensor of sampled token ids (which includes the speculated tokens) and copying it into input_ids in the same manner this PR already does.

Drawbacks

This PR's correctness can be enforced by straightforward end-to-end testing with async scheduling enabled. However, it is not so easy to guarantee the absence of synchronizations in the execute_model code. Currently, the FlashInfer implementation has two such synchronizations, so Flash Attention is used for benchmarking instead. This is a notable flaw in the design: if accepted, it will require vigilant regression testing for performance degradation caused by accidentally introduced synchronization points.

Further, enforcing fully sync-free execution limits compatibility with features such as n-gram speculative decoding, which inherently requires the sampled token ids to be copied to the host. Future techniques may also limit our ability to maintain zero synchronizations, and therefore limit compatibility with this style of async scheduling.
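
One possible way to regression-test for accidentally introduced sync points (a sketch, not part of this PR) is PyTorch's CUDA sync debug mode, which raises on synchronizing calls issued from Python:

```python
import torch


def run_sync_free(fn, *args, **kwargs):
    """Fail loudly if fn performs a synchronizing CUDA call such as
    .item(), a blocking .cpu() copy, or an explicit synchronize()."""
    torch.cuda.set_sync_debug_mode("error")
    try:
        return fn(*args, **kwargs)
    finally:
        torch.cuda.set_sync_debug_mode("default")
```

This does not catch every possible stall, so profiling with nsys remains the ground truth for performance regressions.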

Profile Results

The following is a snapshot of an nsys profile of a decode iteration (BS=1) before and after this change. The setup is Llama 3.2 1B-Instruct on 1xB200 with async scheduling and full cudagraph enabled in both cases.

Before:

before-image

After:

after-image

To reproduce, run:

VLLM_ATTENTION_BACKEND=FLASH_ATTN vllm serve meta-llama/Llama-3.2-1B-Instruct --compilation-config '{"full_cuda_graph": true}' --no-enable-prefix-caching --async-scheduling

Test Plan

Tests are coming soon. See the discussion under Drawbacks above for testing considerations.

Test Result

@mergify mergify bot added the v1 label Aug 25, 2025

mergify bot commented Aug 25, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @benchislett.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Aug 25, 2025
@benchislett benchislett added the performance Performance-related issues label Aug 25, 2025
@benchislett benchislett marked this pull request as ready for review August 25, 2025 21:02

@njhill njhill left a comment


Thanks @benchislett, the idea looks good to me!

@mergify mergify bot removed the needs-rebase label Aug 26, 2025
@benchislett benchislett requested a review from njhill August 26, 2025 18:59
@benchislett benchislett requested a review from njhill August 27, 2025 20:58

woodlgz commented Sep 8, 2025

@benchislett hi, I was reading this code and wondered whether the current step's prepare_inputs, which performs in-place updates, could pollute the CPU->GPU copy buffers still being read by the copy_ operations issued in the previous step's prepare_inputs.

Maybe an event recorded on the previous step's _prepare_inputs operations should also be synchronized in the AsyncGPUModelRunnerOutput.get_output method, if such an event exists.


njhill commented Sep 8, 2025

@woodlgz these operations all happen on a single CUDA stream, though, so they should be serialized.

AsyncGPUModelRunnerOutput and the GPU->CPU copy stream deal only with the sampled_token_ids tensor, and a new one of these is allocated by the sampler every step.


woodlgz commented Sep 9, 2025

> @woodlgz these operations all happen on a single CUDA stream, though, so they should be serialized.
>
> AsyncGPUModelRunnerOutput and the GPU->CPU copy stream deal only with the sampled_token_ids tensor, and a new one of these is allocated by the sampler every step.

Hi @njhill, thanks for the reply. I understand your point, but my confusion is mainly about prepare_inputs, which involves CPU->GPU copies and reuses CPU buffers (input_ids, for example) every step. GPU operations will of course be serialized, but what if a CPU buffer is rewritten before its associated cudaMemcpyAsync (H2D) has finished? This can happen when the current step's prepare_inputs host work (populating the CPU buffers) overlaps with the previous step's device work (the CPU->GPU copies and model execution). Usually cudaMemcpyAsync (H2D) takes very little time, so it may not show up in practice, but the risk is still there, right?
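
A minimal sketch of the kind of guard being discussed here (hypothetical names; not necessarily how the follow-up fix is implemented): record an event right after the async H2D copy and make the host wait on it before the reused pinned buffer is overwritten in the next step.

```python
import torch

h2d_done = torch.cuda.Event()


def prepare_inputs_step(input_ids_cpu: torch.Tensor,   # reused pinned host buffer
                        input_ids_gpu: torch.Tensor,
                        new_token_ids: torch.Tensor,
                        num_tokens: int) -> None:
    # Before the host rewrites the reused CPU buffer, make sure the previous
    # step's cudaMemcpyAsync (H2D) out of that buffer has actually finished.
    h2d_done.synchronize()

    input_ids_cpu[:num_tokens].copy_(new_token_ids)          # host-side write
    input_ids_gpu[:num_tokens].copy_(input_ids_cpu[:num_tokens],
                                     non_blocking=True)      # async H2D copy
    h2d_done.record()  # completes once this step's H2D copy is done
```

In practice the event is usually already complete by the time the next step's host work starts, so the wait should cost little.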


wangqia0309 commented Sep 9, 2025

I found a bug: when asynchronous scheduling is enabled, vLLM crashes if the request includes a frequency_penalty (greater than 0).

/pytorch/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:163: operator(): block: [0,0,0], thread: [0,0,0] Assertion `idx_dim >= 0 && idx_dim < index_size && "scatter gather kernel index out of bounds"` failed.
torch.AcceleratorError: CUDA error: device-side assert triggered

Stack trace:
  File "/usr/local/lib/python3.12/dist-packages/vllm/v1/sample/ops/penalties.py", line 24, in apply_all_penalties
    return apply_penalties(logits, prompt_token_ids, output_tokens_t,
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/utils.py", line 73, in apply_penalties
    output_bin_counts, output_mask = get_token_bin_counts_and_mask(
  File "/usr/local/lib/python3.12/dist-packages/vllm/model_executor/layers/utils.py", line 45, in get_token_bin_counts_and_mask
    bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))


njhill commented Sep 9, 2025

@woodlgz ah I see what you mean. I think you're right, I've opened a PR to address it, please take a look: #24527

@wangqia0309 async scheduling support is still in progress - it's not yet compatible with various other features but we'll be addressing this soon.

eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
@JaheimLee

> I found a bug: when asynchronous scheduling is enabled, vLLM crashes if the request includes a frequency_penalty (greater than 0). [...]

Same problem. I think it's pretty important because many quantitative models cannot be used without it. @njhill

MengqingCao pushed a commit to vllm-project/vllm-ascend that referenced this pull request Sep 11, 2025
This PR is based on top of
[#23569](vllm-project/vllm#23569) and
[#24219](vllm-project/vllm#24219).

### What this PR does / why we need it?
This PR allows the model runner to function asynchronously when using
async scheduling. This allows full overlap of the cpu operations
(including prepare_inputs) and the model forward pass. This diff is
functional and does not support speculative decoding, PP, or guided
decoding.

Expected speedup is 5-10% over the current async scheduling.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
server
```
python -m vllm.entrypoints.openai.api_server --model=Qwen3-32B\
	--trust-remote-code --enforce-eager \
	--distributed-executor-backend=mp \
	-tp=4 \
	--port 8006 \
	--max-model-len 32000 \
	--block-size 128 \
	--gpu-memory-utilization 0.99
```
client
```
python $TEST_PY --backend vllm --trust-remote-code --model Qwen3-32B \
  --dataset-name random --random-input-len 2048 --random-output-len 2048 \
  --ignore-eos\
  --num-prompts 48 --max-concurrency 48  --request-rate inf --temperature 0 \
  --metric-percentiles 90  --base-url http://localhost:8006 --save-result \
  --result-dir $PROFILER_DIR
```

Benchmark based on Qwen3-32B, TPOT (lower is better):

| | forward async | scheduler async | sync |
|-|-|-|-|
| avg | 41.73 | 41.86 | 44.20 |
| improvement vs scheduler async | 0.3% | - | - |
| improvement vs sync | 5.58% | - | - |

Benchmark based on Qwen2.5-VL-7B-Instruct, TPOT (lower is better):

| | forward async | sync |
|-|-|-|
| avg | 23.22 | 29.16 |
| improvement vs sync | 20.3% | - |


- vLLM version: main
- vLLM main:
vllm-project/vllm@e93f4cc

Signed-off-by: jiangpeng36 <[email protected]>
Signed-off-by: Ronald1995 <[email protected]>
Co-authored-by: jiangpeng36 <[email protected]>
Co-authored-by: Ronald1995 <[email protected]>
yiz-liu pushed a commit to linfeng-yuan/vllm-ascend that referenced this pull request Sep 12, 2025
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
xuechendi added a commit to vllm-project/vllm-gaudi that referenced this pull request Sep 15, 2025
Dependent on vllm-project/vllm#23569

---------

Signed-off-by: Tianmu Li <[email protected]>
Co-authored-by: Chendi.Xue <[email protected]>
offline893 pushed a commit to offline893/vllm-ascend that referenced this pull request Sep 16, 2025
kdamaszk pushed a commit to kdamaszk/vllm-gaudi that referenced this pull request Sep 18, 2025

njhill commented Sep 19, 2025

@wangqia0309 @JaheimLee could you try again with #25279? It's possible that was the reason for the penalty crash.

@JaheimLee

> @wangqia0309 @JaheimLee could you try again with #25279? It's possible that was the reason for the penalty crash.

Still has the problem.

@wangqia0309

> @wangqia0309 @JaheimLee could you try again with #25279? It's possible that was the reason for the penalty crash.

Thanks, I have fixed this issue in-house. It resulted from asynchronous scheduling no longer filling in the sampled output token and using -1 as a placeholder instead, so it needs to be filled asynchronously with the correct value in another thread.
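
For context, a minimal repro of the failure mode described above (an assumption based on the stack trace: the -1 placeholder reaches the penalty kernels before being back-filled); requires a CUDA device:

```python
import torch

vocab_size = 8
# One request whose "output token" is still the -1 placeholder left by
# async scheduling instead of a real sampled token id.
tokens = torch.tensor([[-1]], device="cuda")

bin_counts = torch.zeros(1, vocab_size + 1, dtype=torch.long, device="cuda")
# Index -1 is out of range for scatter_add_ on CUDA, tripping the
# "scatter gather kernel index out of bounds" device-side assert above.
bin_counts.scatter_add_(1, tokens, torch.ones_like(tokens))
torch.cuda.synchronize()  # the asynchronous assert surfaces here
```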

slokesha pushed a commit to slokesha/vllm-gaudi that referenced this pull request Sep 24, 2025
wangxiaoteng888 pushed a commit to LCAIZJ/vllm-ascend that referenced this pull request Sep 25, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
chopper0126 pushed a commit to chopper0126/vllm-ascend that referenced this pull request Sep 26, 2025