
Conversation

chanh
Contributor

@chanh chanh commented Apr 4, 2025

Summary

Support capturing a single CUDA graph for the entire model's forward pass, instead of piecewise graphs. This requires creating persistent buffers to make attention graphable. Credit to @tlrmchlsmth for the original implementation.
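
For readers unfamiliar with the pattern, below is a minimal, generic sketch of full-graph capture with persistent buffers (illustrative only, not vLLM's actual implementation; the model, buffer names, and shapes are placeholders):

import torch

@torch.inference_mode()
def capture_full_graph(model, max_tokens, hidden_size, device="cuda"):
    # Persistent buffers: graph replay reads/writes these exact tensors,
    # so inputs (and attention metadata) must live in fixed storage.
    static_input = torch.zeros(max_tokens, hidden_size, device=device)
    static_output = torch.empty_like(static_input)

    # Warm up on a side stream before capture.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        static_output.copy_(model(static_input))
    torch.cuda.current_stream().wait_stream(s)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output.copy_(model(static_input))
    return graph, static_input, static_output

def replay(graph, static_input, static_output, new_input):
    # Copy fresh data into the persistent buffer, then replay the graph:
    # one launch for the whole forward pass instead of one per layer/op.
    n = new_input.shape[0]
    static_input[:n].copy_(new_input)
    graph.replay()
    return static_output[:n]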

Limitations:

  1. This only works with V1 + FA3, since FA2 currently is not graphable due to an optimization for GQA.
  2. This doesn't work with Cascade Attention.

Work in progress:

  1. Investigating changes needed to make this work with Llama4 / local attention

This reduces median TPOT by 7% for small models like Qwen 2.5 1.5B.

Before

With piecewise graphs, there are multiple kernel launches per layer, with more gaps between kernel executions (13 ms to decode one token in profiling mode):
[Screenshot: profiler trace, piecewise CUDA graphs]

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  103.15    
Total input tokens:                      100000    
Total generated tokens:                  10000     
Request throughput (req/s):              0.97      
Output token throughput (tok/s):         96.95     
Total Token throughput (tok/s):          1066.46   
---------------Time to First Token----------------
Mean TTFT (ms):                          29.08     
Median TTFT (ms):                        28.89     
P99 TTFT (ms):                           36.17     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          5.75      
Median TPOT (ms):                        5.75      
P99 TPOT (ms):                           6.00      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.75      
Median ITL (ms):                         5.70      
P99 ITL (ms):                            6.58      
==================================================

After

There is now a single launch for the whole forward pass, with almost no gaps between kernel executions (6 ms to decode one token in profiling mode):
[Screenshot: profiler trace, full CUDA graph]

============ Serving Benchmark Result ============
Successful requests:                     100       
Benchmark duration (s):                  103.10    
Total input tokens:                      100000    
Total generated tokens:                  10000     
Request throughput (req/s):              0.97      
Output token throughput (tok/s):         96.99     
Total Token throughput (tok/s):          1066.92   
---------------Time to First Token----------------
Mean TTFT (ms):                          29.52     
Median TTFT (ms):                        30.47     
P99 TTFT (ms):                           39.97     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          5.31      
Median TPOT (ms):                        5.33      
P99 TPOT (ms):                           5.56      
---------------Inter-token Latency----------------
Mean ITL (ms):                           5.31      
Median ITL (ms):                         5.27      
P99 ITL (ms):                            6.18      
==================================================

Above benchmarks were performed with:

VLLM_FLASH_ATTN_VERSION=3 VLLM_USE_V1=1 vllm serve Qwen/Qwen2.5-1.5B-Instruct  --enable-prefix-caching --dtype float16 --disable-log-requests -O3 (or -O4)

vllm bench serve \
        --model Qwen/Qwen2.5-1.5B-Instruct \
        --request-rate 1 \
        --num-prompts 100 \
        --random-input-len 1000 \
        --random-output-len 100 \
        --tokenizer Qwen/Qwen2.5-1.5B-Instruct \
        --ignore-eos


github-actions bot commented Apr 4, 2025

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Apr 4, 2025
@mgoin mgoin requested a review from tlrmchlsmth April 4, 2025 21:34
@WoosukKwon WoosukKwon self-assigned this Apr 4, 2025
Chanh Nguyen added 2 commits April 7, 2025 20:57
Signed-off-by: Chanh Nguyen <[email protected]>
Signed-off-by: Chanh Nguyen <[email protected]>

mergify bot commented Apr 8, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @chanh.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Apr 8, 2025
Chanh Nguyen added 2 commits April 8, 2025 07:38
Signed-off-by: Chanh Nguyen <[email protected]>
@mergify mergify bot added ci/build and removed needs-rebase labels Apr 8, 2025
@chanh chanh marked this pull request as ready for review April 8, 2025 08:44
@alexm-redhat
Collaborator

Thanks for the PR, @chanh. I tested Llama 8B on my side with your PR and see a ~7% improvement in TPOT. Great work!

Before PR:

============ Serving Benchmark Result ============
Successful requests:                     50        
Benchmark duration (s):                  45.05     
Total input tokens:                      25600     
Total generated tokens:                  12800     
Request throughput (req/s):              1.11      
Output token throughput (tok/s):         284.11    
Total Token throughput (tok/s):          852.34    
---------------Time to First Token----------------
Mean TTFT (ms):                          22.43     
Median TTFT (ms):                        22.10     
P99 TTFT (ms):                           27.48     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.63      
Median TPOT (ms):                        7.63      
P99 TPOT (ms):                           7.77      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.63      
Median ITL (ms):                         7.61      
P99 ITL (ms):                            8.45      
==================================================

After PR:

============ Serving Benchmark Result ============
Successful requests:                     50        
Benchmark duration (s):                  44.93     
Total input tokens:                      25600     
Total generated tokens:                  12800     
Request throughput (req/s):              1.11      
Output token throughput (tok/s):         284.88    
Total Token throughput (tok/s):          854.64    
---------------Time to First Token----------------
Mean TTFT (ms):                          22.72     
Median TTFT (ms):                        22.93     
P99 TTFT (ms):                           27.49     
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          7.14      
Median TPOT (ms):                        7.14      
P99 TPOT (ms):                           7.28      
---------------Inter-token Latency----------------
Mean ITL (ms):                           7.14      
Median ITL (ms):                         7.12      
P99 ITL (ms):                            8.05      
==================================================

@sarckk
Collaborator

sarckk commented Apr 9, 2025

Work in progress:

  1. Investigating changes needed to make this work with Llama4 / local attention

just a heads up @zou3519

@chanh
Contributor Author

chanh commented Apr 9, 2025

Thanks for the PR, @chanh. I tested Llama 8B on my side with your PR and see a ~7% improvement in TPOT. Great work!


Thanks to @alexm-redhat for verifying!

Collaborator

@alexm-redhat alexm-redhat left a comment


@chanh went over the PR in detail; it looks really good. Left some comments. Thanks for adding the test; I think it can be expanded to cover CUDA graph edge cases a bit better.

@mgoin mgoin self-requested a review April 11, 2025 14:54
@alexm-redhat
Collaborator

@chanh let me know if you need help extending the tests; I can do it on my side.

@WoosukKwon
Collaborator

Thanks for the PR! I will review it this weekend (maybe Tyler and Rob, too).

@dblincoe

I ran some latency-focused testing on this PR using LLaMA 3.2 1B Instruct with a small batch size (~1-2) in a highly latency-constrained setting where minimizing CUDA graph launches can significantly improve GPU utilization. Here are the results:

Before PR:

Average latency: 56.82 ms
p50 latency: 53.00 ms
p90 latency: 64.00 ms
p95 latency: 68.00 ms
p99 latency: 82.23 ms

After PR:

Average latency: 50.30 ms
p50 latency: 48.00 ms
p90 latency: 58.00 ms
p95 latency: 61.00 ms
p99 latency: 67.00 ms

This shows a notable improvement across the board, particularly in tail latencies. Great work!

@mergify mergify bot removed the needs-rebase label May 7, 2025
@chanh chanh requested a review from WoosukKwon May 7, 2025 11:26
@tlrmchlsmth
Member

@chanh Thanks for pushing this through!

@LucasWilkinson
Collaborator

I think we may need to disable ahead-of-time scheduling for FA3 when using full cuda-graph:

if self.aot_schedule:
    return get_scheduler_metadata(
        batch_size=batch_size,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_seq_len,
        cache_seqlens=seqlens,
        num_heads_q=self.num_heads_q,
        num_heads_kv=self.num_heads_kv,
        headdim=self.headdim,
        page_size=self.page_size,
        cu_seqlens_q=cu_query_lens,
        causal=causal,
        window_size=self.aot_sliding_window,
    )

since this scheduler may choose a different number of splits than what the graph was captured with

do we have lm-eval accuracy results with full cuda-graphs on?

@chanh
Contributor Author

chanh commented May 7, 2025

I think we may need to disable ahead-of-time scheduling for FA3 when using full cuda-graph: […] do we have lm-eval accuracy results with full cuda-graphs on?

Will discuss with you over Slack

Signed-off-by: Chanh Nguyen <[email protected]>
@chanh
Contributor Author

chanh commented May 7, 2025

I think we may need to disable ahead-of-time scheduling for FA3 when using full cuda-graph: […]

Okay disabled it for now.
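
For reference, a minimal sketch of what such gating could look like (illustrative only; it assumes a full_cuda_graph flag reachable from the compilation config and reuses the attributes from the snippet quoted above, which may not match the PR exactly):

# Skip FA3's ahead-of-time scheduler when the whole model runs as one CUDA
# graph, since the AOT scheduler may pick a split count different from the
# one the graph was captured with. Flag/attribute paths here are assumptions.
if self.aot_schedule and not self.compilation_config.full_cuda_graph:
    scheduler_metadata = get_scheduler_metadata(
        batch_size=batch_size,
        max_seqlen_q=max_query_len,
        max_seqlen_k=max_seq_len,
        cache_seqlens=seqlens,
        num_heads_q=self.num_heads_q,
        num_heads_kv=self.num_heads_kv,
        headdim=self.headdim,
        page_size=self.page_size,
        cu_seqlens_q=cu_query_lens,
        causal=causal,
        window_size=self.aot_sliding_window,
    )
else:
    scheduler_metadata = None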

fix
Signed-off-by: Chanh Nguyen <[email protected]>
@chanh
Contributor Author

chanh commented May 7, 2025


lm-eval results

[Current branch, Full CUDA Graph flag enabled, modified lm-eval to pass the compilation_config JSON properly to vLLM]
VLLM_FLASH_ATTN_VERSION=3 VLLM_USE_V1=1 \
lm_eval --model vllm \
  --model_args "pretrained=Qwen/Qwen2-1.5B-Instruct,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1,compilation_config={\"full_cuda_graph\": true}" \
  --tasks gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5982|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.5898|±  |0.0135|


[Main branch]
VLLM_FLASH_ATTN_VERSION=3 VLLM_USE_V1=1 \
lm_eval --model vllm \
  --model_args "pretrained=Qwen/Qwen2-1.5B-Instruct,tensor_parallel_size=1,dtype=auto,gpu_memory_utilization=0.8,data_parallel_size=1" \
  --tasks gsm8k

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.5951|±  |0.0135|
|     |       |strict-match    |     5|exact_match|↑  |0.5891|±  |0.0136|

Collaborator

@LucasWilkinson LucasWilkinson left a comment


LGTM thanks!

@simon-mo simon-mo merged commit 7ea2adb into vllm-project:main May 8, 2025
51 checks passed
princepride pushed a commit to princepride/vllm that referenced this pull request May 10, 2025
Signed-off-by: Chanh Nguyen <[email protected]>
Co-authored-by: Chanh Nguyen <[email protected]>
Signed-off-by: 汪志鹏 <[email protected]>
RichardoMrMu pushed a commit to RichardoMrMu/vllm that referenced this pull request May 12, 2025
Signed-off-by: Chanh Nguyen <[email protected]>
Co-authored-by: Chanh Nguyen <[email protected]>
Signed-off-by: Mu Huai <[email protected]>
@renjie0

renjie0 commented May 13, 2025

Work in progress:

  1. Investigating changes needed to make this work with Llama4 / local attention

just a heads up @zou3519

What is special about local attention?

mawong-amd pushed a commit to ROCm/vllm that referenced this pull request May 14, 2025
Signed-off-by: Chanh Nguyen <[email protected]>
Co-authored-by: Chanh Nguyen <[email protected]>
zzzyq pushed a commit to zzzyq/vllm that referenced this pull request May 24, 2025
Signed-off-by: Chanh Nguyen <[email protected]>
Co-authored-by: Chanh Nguyen <[email protected]>
Signed-off-by: Yuqi Zhang <[email protected]>
@Juelianqvq
Contributor

@chanh It seems that full CUDA graph support outputs garbage on the latest main. Do you have any idea?

@ProExpertProg
Collaborator

@chanh +1 - it seems like the test was never added to CI (needs to be added manually to .buildkite/test-pipeline.yml). When I run the test locally, the first shape works and all the other shapes output garbage.

- with set_forward_context(None,
+ with set_forward_context(attn_metadata,
Contributor

@hidva hidva Jun 11, 2025


Considering that self.maybe_setup_kv_connector(scheduler_output) is not executed here, in the full CUDA graph scenario the call chain unified_attention_with_output -> maybe_save_kv_layer_to_connector -> connector.save_kv_layer() will cause the connector to read uninitialized metadata.

https://github.com/LMCache/LMCache/blob/680fbdf84e2ee1040bf4e084d43c9155a91b8d5c/lmcache/integration/vllm/vllm_v1_adapter.py#L609-L610

Does this mean full CUDA graph is incompatible with the KV connector?

@simon-mo
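
For context, a rough sketch of what the highlighted change above implies during capture (simplified; the config argument and persistent buffer names are assumptions, not necessarily vLLM's exact code):

# With full CUDA graph, real attention metadata (backed by persistent buffers)
# must be visible via the forward context while capturing, instead of None.
# As a side effect, anything hooked into the attention call (e.g. KV-connector
# save paths) also runs during capture/replay, which is the concern raised above.
with set_forward_context(attn_metadata, self.vllm_config):
    hidden_states = self.model(
        input_ids=self.input_ids[:num_tokens],
        positions=self.positions[:num_tokens],
    )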

@Lmywl

Lmywl commented Jun 19, 2025

  1. This only works with V1 + FA3, since FA2 currently is not graphable due to an optimization for GQA.

Hello, I changed the code to enable full CUDA graph capture with FA2, and the results show that FA2 also works correctly.
So I'm curious what "FA2 currently is not graphable due to an optimization for GQA" specifically refers to.

@happierpig

@WoosukKwon This may be helpful. Regarding FA2, FlashInfer (flashinfer-ai/flashinfer#1137) recently merged a PR that implements a persistent-style FA2 template. That PR unifies prefill and decode, which supports a single CUDA graph for all batch sizes and sequence lengths.

@xsank
Contributor

xsank commented Aug 28, 2025

@chanh #23739, do you have any idea about this problem?
