
Conversation

shixianc
Contributor

@shixianc shixianc commented Aug 17, 2025

Purpose

Integrate the permute/unpermute CUDA kernels from #17934 into run_cutlass_moe_fp8 to speed up the ops before and after the CUTLASS matmuls:

  1. Preprocessing: call moe_permute to compute expert_offsets and permuted_hidden_states, and expose the compute_problem_sizes kernel at the Python level to compute problem_sizes on its own (NCU profiling shows this step takes only 1-2% of layer latency, so no further fusion was done).
  2. Postprocessing: call moe_unpermute, which fuses the weight multiply and the unpermute into one kernel; TopKWeightAndReduceNoOP is therefore used to override finalize_weight_and_reduce_impl(). A minimal reference sketch of this flow follows this list.
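
For reference, here is a minimal pure-PyTorch sketch of the permute → grouped matmul → fused weight-multiply/unpermute flow that these CUDA kernels implement. This is an illustrative reference only, not the vLLM kernel API; the function name, tensor names, and shapes are assumptions.

```python
import torch

def moe_permute_matmul_unpermute_reference(hidden_states, w, topk_ids, topk_weights):
    """Reference flow: permute -> per-expert matmul -> weighted unpermute.

    hidden_states: (M, K), w: (E, K, N), topk_ids / topk_weights: (M, topk).
    """
    M, K = hidden_states.shape
    E, _, N = w.shape
    topk = topk_ids.shape[1]

    # Permute: sort (token, expert) pairs by expert id so rows routed to the
    # same expert become contiguous; this also yields per-expert row counts
    # (the "problem sizes" / expert offsets used by the grouped GEMM).
    flat_expert_ids = topk_ids.reshape(-1)            # (M * topk,)
    sort_order = torch.argsort(flat_expert_ids)       # permuted row -> flat slot
    token_of_row = sort_order // topk                 # source token of each permuted row
    permuted = hidden_states[token_of_row]            # (M * topk, K)
    tokens_per_expert = torch.bincount(flat_expert_ids, minlength=E)

    # Grouped GEMM stand-in: one matmul per expert over its contiguous rows.
    out_permuted = torch.empty(M * topk, N, dtype=hidden_states.dtype,
                               device=hidden_states.device)
    start = 0
    for e in range(E):
        n_e = int(tokens_per_expert[e])
        if n_e:
            out_permuted[start:start + n_e] = permuted[start:start + n_e] @ w[e]
        start += n_e

    # Unpermute fused with the weight multiply: scale each permuted row by its
    # top-k weight and scatter-add it back to its source token.
    row_weights = topk_weights.reshape(-1)[sort_order].unsqueeze(1).to(out_permuted.dtype)
    out = torch.zeros(M, N, dtype=hidden_states.dtype, device=hidden_states.device)
    out.index_add_(0, token_of_row, out_permuted * row_weights)
    return out
```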

Additionally:

  1. Reuse workspace13 and workspace2 alternately to avoid creating intermediate tensors (e.g., outputs for the permuted hidden states and fp8 quant), and avoid the final output copy.
  2. I took the idea of pre-calculating ab_strides and c_strides from [Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE #20762; since that PR was reverted, I reapplied part of the original commit (see the stride sketch below). Also note that the current kernel is faster than the optimized moe_shuffle proposed in that PR (benchmark comparison attached below).
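
As a side note on item 2, here is a minimal sketch of what the stride pre-calculation can look like: for non-blockwise FP8 grouped GEMM the per-expert strides are constant, so they can be built once at layer init instead of on every forward pass. The function name, the fused gate/up (2 * intermediate_size) output stride, and the int64 dtype are assumptions for illustration, not the exact vLLM code.

```python
import torch

def build_grouped_gemm_strides(num_experts: int, hidden_size: int,
                               intermediate_size: int, device: str = "cuda"):
    # GEMM 1 (gate/up projection): inputs have row stride K (= hidden_size);
    # outputs have row stride 2 * N when gate and up weights are fused.
    ab_strides1 = torch.full((num_experts,), hidden_size,
                             dtype=torch.int64, device=device)
    c_strides1 = torch.full((num_experts,), 2 * intermediate_size,
                            dtype=torch.int64, device=device)
    # GEMM 2 (down projection): inputs have row stride N, outputs row stride K.
    ab_strides2 = torch.full((num_experts,), intermediate_size,
                             dtype=torch.int64, device=device)
    c_strides2 = torch.full((num_experts,), hidden_size,
                            dtype=torch.int64, device=device)
    return ab_strides1, c_strides1, ab_strides2, c_strides2
```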

Note:
Profiling shows that the Triton FusedMoE kernel still beats the CUTLASS kernel at lower batch sizes because of the heavier CUTLASS preprocessing steps (e.g. permutation, more input tensors to prepare). I plan a follow-up PR to route lower batch sizes to Triton.
NCU profiling comparison between Triton and CUTLASS at M=16: the C3X grouped GEMM matmuls take only 28% of MoE latency vs. 63% for Triton. This overhead shrinks at larger M, where CUTLASS beats Triton.
CUTLASS Grouped GEMM (top 2 kernels are the 2 matmuls, followed by a long list of kernel calls): [NCU screenshot]
Triton FusedMoE (far fewer kernels launched): [NCU screenshot]

Test Plan

  1. pytest tests/kernels/moe/*
  2. Kernel Comparison (unit: ms) vs. the shuffle_rows path from PR #20762 ([Performance] Performance improvements in non-blockwise fp8 CUTLASS MoE)
  3. Performance Test: layer benchmark against the baseline CUTLASS MoE (a generic timing sketch follows the commands below)
  4. Quality Test:
lm_eval --model vllm --model_args pretrained=RedHatAI/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=8,max_model_len=2048,gpu_memory_utilization=0.9,max_num_seqs=32 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto

lm_eval --model vllm --model_args pretrained=RedHatAI/Mixtral-8x7B-Instruct-v0.1-FP8,tensor_parallel_size=8,max_model_len=2048,gpu_memory_utilization=0.9 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
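
Regarding item 3, the benchmark script is not included here; a generic CUDA-event timing loop along the following lines is one way to measure per-layer latency. The moe_layer_forward callable is a placeholder for the layer under test, not an actual vLLM function.

```python
import torch

def benchmark_layer(moe_layer_forward, hidden_states, warmup: int = 10, iters: int = 100):
    """Time a MoE layer forward with CUDA events; returns mean latency in ms."""
    for _ in range(warmup):
        moe_layer_forward(hidden_states)
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        moe_layer_forward(hidden_states)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```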

Test Result

1. pytest tests/kernels/moe/*

33 failed, 6218 passed, 1852 skipped, 7 warnings in 3771.32s (1:02:51)
All failures are in:
FAILED tests/kernels/moe/test_block_fp8.py::test_w8a8_block_fp8_fused_moe

        tol = 0.035 if M < 40000 else 0.039
>       torch.testing.assert_close(out, ref_out, atol=tol, rtol=tol)
E       AssertionError: Tensor-likes are not close!
E
E       Mismatched elements: 1 / 14680064 (0.0%)
E       Greatest absolute difference: 0.039794921875 at index (1316, 3337) (up to 0.035 allowed)
E       Greatest relative difference: 0.90234375 at index (1316, 3337) (up to 0.035 allowed)

The error is related to TritonExperts, and I see the same error on main without my changes. I'll see if I can track this down to a particular commit...

2. Kernel Comparison (unit: ms) vs. PR #20762 shuffle_rows

| Model / Config (M, K, N) | shuffle_rows | cuda (un)permute | % improvement |
|---|---|---|---|
| DeepSeekV2-Lite (1, 2048, 1408) | 2.0 | 1.9 | +5% |
| DeepSeekV2-Lite (16, 2048, 1408) | 4.8 | 4.9 | -2% |
| DeepSeekV2-Lite (512, 2048, 1408) | 8.0 | 6.9 | +14% |
| DeepSeekV2-Lite (4096, 2048, 1408) | 35.6 | 26.8 | +25% |
| DeepSeekV2-Lite (10240, 2048, 1408) | 83.7 | 60.8 | +27% |
| Granite-1B (1, 1024, 1024) | 1.6 | 1.6 | 0% |
| Granite-1B (16, 1024, 1024) | 2.5 | 2.4 | +4% |
| Granite-1B (512, 1024, 1024) | 4.8 | 4.1 | +15% |
| Granite-1B (4096, 1024, 1024) | 23.7 | 17.7 | +25% |
| Granite-1B (10240, 1024, 1024) | 54.4 | 39.4 | +28% |
| Granite-3B (1, 1024, 1536) | 1.8 | 1.7 | +6% |
| Granite-3B (16, 1024, 1536) | 3.1 | 3.1 | 0% |
| Granite-3B (512, 1024, 1536) | 5.9 | 5.2 | +12% |
| Granite-3B (4096, 1024, 1536) | 29.6 | 23.6 | +20% |
| Granite-3B (10240, 1024, 1536) | 68.3 | 53.4 | +22% |
| Mixtral-8x7B (1, 4096, 28672) | 6.8 | 7.6 | -12% |
| Mixtral-8x7B (16, 4096, 28672) | 18.7 | 18.7 | 0% |
| Mixtral-8x7B (512, 4096, 28672) | 34.5 | 31.8 | +8% |
| Mixtral-8x7B (4096, 4096, 28672) | 201.4 | 195.8 | +3% |
| Mixtral-8x7B (10240, 4096, 28672) | 486.9 | 474.2 | +3% |
| Mixtral-8x7B (1, 14336, 4096) | 4.6 | 4.6 | 0% |
| Mixtral-8x7B (16, 14336, 4096) | 10.7 | 10.6 | +1% |
| Mixtral-8x7B (512, 14336, 4096) | 21.2 | 19.1 | +10% |
| Mixtral-8x7B (4096, 14336, 4096) | 122.1 | 103.4 | +15% |
| Mixtral-8x7B (10240, 14336, 4096) | 302.6 | 248.7 | +18% |

3. Performance Test: layer benchmark against baseline cutlass moe

Mixtral: RedHatAI/Mixtral-8x7B-Instruct-v0.1-FP8
LLaMA4: RedHatAI/Llama-4-Maverick-17B-128E-Instruct-FP8
QWEN3: Qwen/Qwen3-235B-A22B

| M | llama4 baseline | llama4 cuda (un)permute | mixtral-tp1 baseline | mixtral-tp1 cuda (un)permute | mixtral-tp8 baseline | mixtral-tp8 cuda (un)permute | qwen3-235B baseline | qwen3-235B cuda (un)permute |
|---|---|---|---|---|---|---|---|---|
| 2 | 53.16 | 52.25 | 196.27 | 193.03 | 56.28 | 58.21 | 116.69 | 118.80 |
| 4 | 63.20 | 58.28 | 270.79 | 272.83 | 67.95 | 68.56 | 177.25 | 180.44 |
| 8 | 74.47 | 72.37 | 339.88 | 338.88 | 77.96 | 77.40 | 275.92 | 281.86 |
| 16 | 102.80 | 101.55 | 370.55 | 368.43 | 81.45 | 83.15 | 416.73 | 422.01 |
| 64 | 235.17 | 232.31 | 416.76 | 413.67 | 86.31 | 86.63 | 623.87 | 627.96 |
| 128 | 356.50 | 352.24 | 434.55 | 425.84 | 89.78 | 89.58 | 651.17 | 652.36 |
| 256 | 487.17 | 480.44 | 555.46 | 529.08 | 121.31 | 121.01 | 708.23 | 700.80 |
| 512 | 540.10 | 530.19 | 704.87 | 661.02 | 136.81 | 132.21 | 741.94 | 720.87 |
| 1024 | 576.02 | 557.74 | 1200.26 | 1108.55 | 198.02 | 188.55 | 938.46 | 888.24 |
| 2048 | 624.00 | 588.86 | 2174.25 | 1983.22 | 356.12 | 319.08 | 1365.11 | 1250.25 |
| 4096 | 717.53 | 640.28 | 4170.79 | 3717.44 | 614.92 | 563.19 | 2352.36 | 2194.96 |
| 8192 | 961.51 | 813.88 | 8052.83 | 7343.35 | 1128.27 | 1002.53 | 4479.57 | 4144.66 |
| 10240 | 1159.86 | 982.33 | 10007.20 | 9020.98 | 1415.05 | 1262.71 | 5447.69 | 5062.40 |
| 20480 | 1802.10 | 1455.03 | 19984.04 | 17919.30 | 2750.78 | 2417.12 | 10309.76 | 9457.82 |
| 76800 | 7232.99 | 5474.50 | 69537.11 | 68032.06 | 10499.43 | 9109.40 | 39176.71 | 35258.72 |

4. Quality Test:

lm_eval --model vllm --model_args pretrained=RedHatAI/Llama-4-Maverick-17B-128E-Instruct-FP8,tensor_parallel_size=8,max_model_len=2048,gpu_memory_utilization=0.9,max_num_seqs=32 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.9143 | ± 0.0077 |
| | | strict-match | 5 | exact_match | 0.9166 | ± 0.0076 |
lm_eval --model vllm --model_args pretrained=RedHatAI/Mixtral-8x7B-Instruct-v0.1-FP8,tensor_parallel_size=8,max_model_len=2048,gpu_memory_utilization=0.9 --trust_remote_code --tasks gsm8k --num_fewshot 5 --batch_size auto
| Tasks | Version | Filter | n-shot | Metric | Value | Stderr |
|---|---|---|---|---|---|---|
| gsm8k | 3 | flexible-extract | 5 | exact_match | 0.6080 | ± 0.0134 |
| | | strict-match | 5 | exact_match | 0.6065 | ± 0.0135 |

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, which covers a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the performance (Performance-related issues) label on Aug 17, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request integrates CUDA permute/unpermute kernels for MoE FP8 operations, aiming to improve performance. The changes include refactoring the MoE data preparation and finalization steps, introducing new CUDA kernels, and updating the Python bindings and benchmarks accordingly. My review found a few areas for improvement. There is a recurring typo in a variable name in one of the CUDA files, which should be corrected for consistency and readability. Additionally, there are redundant attribute assignments in two classes which can be removed to make the code cleaner and more maintainable.

Collaborator

@yewentao256 yewentao256 left a comment


Thanks for the work! Please also fix the issues Gemini flagged, as well as the DCO and pre-commit issues.

@mgoin
Member

mgoin commented Aug 17, 2025

Excellent analysis and work! We were just talking about unreverting Eliza's work this week, so this is timely.
I didn't see the accuracy evals reported yet, so will wait on those and Wentao's comments. Looking forward to getting this in.

@shixianc shixianc force-pushed the cutlass-permute-integ branch from de5e9d0 to ca9e4f6 on August 17, 2025 16:51
@shixianc
Contributor Author

@mgoin @yewentao256 Thanks for the quick review!

Addressed all comments and attached quality test in the description.
The only unit-test issue is the 33 failures in test_block_fp8.py::test_w8a8_block_fp8_fused_moe; I see the same failures on a previous commit, so they are not related to this PR, but I'll see if I can track them down to a particular commit.

@shixianc
Contributor Author

@mgoin @yewentao256 Regarding the unrelated pytest error: I was able to track it down to #21083 (cuda fp8 block quant kernel), which uses atol=0.15 in its unit test, whereas the failed fused_moe test uses atol=0.035. If I increase it to 0.075, all tests pass. Perhaps we should consider increasing the tolerance.

@mgoin mgoin added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Aug 19, 2025
@shixianc shixianc force-pushed the cutlass-permute-integ branch from 4fb9926 to 1b8b4b1 on August 19, 2025 04:19
@ElizaWszola
Contributor

Nice work, I really like the speedups!

Regarding the failed fused_moe tests, did you manually inspect the ground truth vs. CUTLASS MoE outputs to confirm that they look similar and max diff is triggered for only a few elements?

@shixianc
Contributor Author

shixianc commented Aug 19, 2025

@ElizaWszola Thanks for reviewing! The failed unit test is not for the CUTLASS path but for the Triton MoE; please see my previous comment:

> @mgoin @yewentao256 Regarding the unrelated pytest error: I was able to track it down to #21083 (cuda fp8 block quant kernel), which uses atol=0.15 in its unit test, whereas the failed fused_moe test uses atol=0.035. If I increase it to 0.075, all tests pass. Perhaps we should consider increasing the tolerance.

but yes, it's only triggered for < 0.01% of elements.

Member

@mgoin mgoin left a comment


Great work, thank you!

@mgoin mgoin merged commit b17109b into vllm-project:main Aug 20, 2025
70 checks passed