Skip to content

Conversation

varun-sundar-rabindranath
Copy link
Contributor

@varun-sundar-rabindranath varun-sundar-rabindranath commented Feb 11, 2025

Add shrink and expand triton kernels for V1.

Why do we need a new set of kernels:

  • V0 sorts/groups requests based on LoRA ID. The SGMV kernels take advantage of this and groups the compute within thread blocks.
  • V1 doesn't group requests based on LoRA ID. The new set of kernels have information about which input tokens map to which LoRA ID and they use this information to load the appropriate input tokens. The rest of the matmul is very similar to the SGMV kernels.

Kernel Code Change:
The new kernels re-use a lot of the code from the existing SGMV kernels. The main changes are,

  1. Kernel Launch Grid formulation (this was required so the kernels are CUDAGraph compatible. Note that SGMV kernels are not)
  2. Loading of the input tokens (A matrix) for the matmul.
    All other kernel code is the same as the existing SGMV kernels. I refactored the code so it can be reused.

benchmark serving numbers :

server command : VLLM_USE_V1="1" vllm serve meta-llama/Llama-2-7b-hf --max-loras 4 --max-lora-rank 8 --enable-lora --lora-modules lora1=yard1/llama-2-7b-sql-lora-test lora2=yard1/llama-2-7b-sql-lora-test lora3=yard1/llama-2-7b-sql-lora-test lora4=yard1/llama-2-7b-sql-lora-test --no-enable-prefix-caching

benchmark command : python3 benchmarks/benchmark_serving.py --model meta-llama/Llama-2-7b-hf --dataset-name sharegpt --dataset-path ./ShareGPT_V3_unfiltered_cleaned_split.json --num-prompts 500 --request-rate inf --lora-modules lora1 lora2 lora3 lora4

V1 LoRA - This PR:

============ Serving Benchmark Result ============
Successful requests:                     500       
Benchmark duration (s):                  68.81     
Total input tokens:                      117316    
Total generated tokens:                  110942    
Request throughput (req/s):              7.27      
Output token throughput (tok/s):         1612.20   
Total Token throughput (tok/s):          3317.02   
---------------Time to First Token----------------
Mean TTFT (ms):                          7103.84   
Median TTFT (ms):                        7136.82   
P99 TTFT (ms):                           14040.12  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          126.93    
Median TPOT (ms):                        94.96     
P99 TPOT (ms):                           231.31    
---------------Inter-token Latency----------------
Mean ITL (ms):                           76.47     
Median ITL (ms):                         58.02     
P99 ITL (ms):                            236.28    
==================================================

V1 LoRA - Main:

============ Serving Benchmark Result ============
Successful requests:                     500       
Benchmark duration (s):                  117.84    
Total input tokens:                      117316    
Total generated tokens:                  110942    
Request throughput (req/s):              4.24      
Output token throughput (tok/s):         941.44    
Total Token throughput (tok/s):          1936.96   
---------------Time to First Token----------------
Mean TTFT (ms):                          10277.45  
Median TTFT (ms):                        9370.56   
P99 TTFT (ms):                           22882.99  
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          259.87    
Median TPOT (ms):                        236.82    
P99 TPOT (ms):                           445.65    
---------------Inter-token Latency----------------
Mean ITL (ms):                           173.39    
Median ITL (ms):                         164.76    
P99 ITL (ms):                            459.73    
==================================================

Kernel micro benchmark:

Please find kernel microbenchmark here - https://docs.google.com/spreadsheets/d/1b_8KsDGdiSGWlHODMszug_-do7OlSPlSzoV84VkmVPc/edit?usp=sharing (sheet : "V1 : Dont Sort Tokens By LoRA ")

Note : The V0 SGMV and BGMV kernels are not tuned. But the V1 kernels are tuned with triton auto-tuner. Therefore the discrepancy between the V1 and SGMV/BGMV kernels could be partially explained by the tuning.
The SGMV kernel depends heavily on the input being sorted. V1 kernels aren't affected as much.

Copy link

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which starts running only a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on Buildkite UI (linked in the PR checks section) and unblock them. If you do not have permission to unblock, ping simon-mo or khluu to add you in our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Feb 11, 2025
@varun-sundar-rabindranath varun-sundar-rabindranath force-pushed the varun/v1-lora-kernels branch 2 times, most recently from 1e6caf2 to d78cd57 Compare February 20, 2025 13:02
@varun-sundar-rabindranath varun-sundar-rabindranath marked this pull request as ready for review February 20, 2025 14:15
@@ -0,0 +1,102 @@
# SPDX-License-Identifier: Apache-2.0

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Requesting reviews on this file. The utilities in this file deal with loading the stored triton configs.
cc @tlrmchlsmth @mgoin

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

QQ: Is there a significant gain to stored tuned config?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not significant - but I do see some gains. On average I see about 200 tokens/s more throughput when using the tuned configs.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TBH, storing these configs is a bit crazy, I'm not sure if this is the right direction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It appears, it is common to have these configs for triton kernels - the fused_moe kernels also have such configs here https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/fused_moe/configs

Arguments for storing configs:

  • The default/static config isn't going to be optimal for all cases.
  • Generating these configs is easy. With new GPUs and Triton version updates we can simply run the triton autotuner and drop in these configs without having to do any code change.

I did a tuned-kernels vs untuned-kernels micro benchmark run -
https://docs.google.com/spreadsheets/d/1f87TKWuwJXVK2-8YHGBoVbo6ueEWGpLPPUIkgYs6dMo/edit?usp=sharing .
TLDR the tuned shrink kernels are better than the untuned versions in most cases. For expand kernels the tuned version is much better at low batch size regimes

E2E performance:

  • The 200 tokens / s number I shared is on the A100 GPU. The E2E performance is limited in this case. But I believe with CUDA Graphs enabled, the tuned kernels will have a bigger impact.
  • Also, I haven't checked the E2E performance on H100 GPU. it might be better.

What do you think ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Spoke with @jeejeelee IRL - The plan is to remove the triton configs from this PR and introduce them in a separate PR so we can reason about them separately.

@jeejeelee I have removed the configs from this PR. can you take another pass at the PR when you find the time please ! Thanks 🙏

self._v1_prepare_metadata_tensors(self.token_lora_indices,
self.sampler_indices)
else:
# Forward to base class update_metadata
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a better way to call the base class method ? (Note that this class inherits from multiple classes. )

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe super().update_metadata is better

@jeejeelee
Copy link
Collaborator

jeejeelee commented Feb 28, 2025

All the LoRA tests have failed again

@varun-sundar-rabindranath
Copy link
Contributor Author

All the LoRA tests have failed again

Looking into this now 👍

@varun-sundar-rabindranath
Copy link
Contributor Author

Update : I enabled tests in tests/lora/test_layers.py for V1. The tests work locally but OOM's on the CI - I am tracking this down.

Copy link

mergify bot commented Mar 3, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @varun-sundar-rabindranath.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Mar 3, 2025
@varun-sundar-rabindranath varun-sundar-rabindranath force-pushed the varun/v1-lora-kernels branch 2 times, most recently from 6b9fadf to 2e9eb8b Compare March 3, 2025 21:44
@mergify mergify bot removed the needs-rebase label Mar 3, 2025
@jeejeelee
Copy link
Collaborator

It seems these modifications have significantly increased the time consumption for lora testing
image

@varun-sundar-rabindranath
Copy link
Contributor Author

varun-sundar-rabindranath commented Mar 4, 2025

It seems these modifications have significantly increased the time consumption for lora testing image

Yes. This PR adds the v1_kernel tests in test_punica_ops.py and enables test_layers.py to run for V1 also. I believe most of it is coming from the test_layers.py that now runs for both V1 and V0 (effectively doubling its run time) - Ill see what we can do here.

[Edit]
@jeejeelee

Update : Reduced the tests in commits a18d273 and ba94947
The times are now,
Screenshot 2025-03-04 at 11 38 20 AM

a maximum of 7 minute increase. Do you think we should prune further ?

Copy link
Collaborator

@jeejeelee jeejeelee left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you very much. Let's continue moving forward in the direction we discussed on Slack.

@jeejeelee jeejeelee added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 10, 2025
Varun Sundar Rabindranath added 20 commits March 10, 2025 11:06
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
nit
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
@robertgshaw2-redhat robertgshaw2-redhat merged commit 5ff0d32 into vllm-project:main Mar 10, 2025
36 checks passed
lulmer pushed a commit to lulmer/vllm that referenced this pull request Apr 7, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
shreyankg pushed a commit to shreyankg/vllm that referenced this pull request May 3, 2025
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed v1
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants