Conversation

@varun-sundar-rabindranath (Contributor) commented Jul 29, 2025

Purpose

When used with DeepSeek models, the DeepEPHighThroughput All2All kernel dispatches tokens in a 16-bit datatype and quantizes after dispatch. This is inefficient for two reasons:

  • More data moves over the communication link
  • More data has to be quantized after dispatch

This PR fixes the ordering: quantize to fp8 first, then dispatch the fp8 tensor.
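
Below is a minimal sketch of the two orderings, assuming DeepSeek-style 1x128 block quantization; the names quant_fp8_blockwise and dispatch are placeholders, not the actual vLLM or DeepEP APIs.

```python
# Illustrative sketch only; `dispatch` stands in for DeepEP's all2all.
import torch

def quant_fp8_blockwise(x: torch.Tensor, group: int = 128):
    """One fp8 scale per `group`-wide slice of the hidden dimension."""
    m, k = x.shape
    g = x.view(m, k // group, group)
    scales = g.abs().amax(dim=-1).clamp(min=1e-4) / 448.0  # 448 = fp8 e4m3 max
    q = (g / scales.unsqueeze(-1)).to(torch.float8_e4m3fn)
    return q.view(m, k), scales

def prepare_before(x, topk_ids, dispatch):
    # Old ordering: ship 16-bit activations, then quantize on each rank.
    x_dispatched = dispatch(x, topk_ids)
    return quant_fp8_blockwise(x_dispatched)

def prepare_after(x, topk_ids, dispatch):
    # New ordering: quantize once, ship fp8 data plus fp32 scales,
    # roughly halving the bytes moved over the interconnect.
    q, scales = quant_fp8_blockwise(x)
    return dispatch((q, scales), topk_ids)
```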

Test Plan

canhazgpu run -g2 -- pytest -s tests/kernels/moe/test_modular_kernel_combinations.py

canhazgpu run -g2 -- pytest tests/kernels/moe/test_deepep_deepgemm_moe.py

VLLM_ALL2ALL_BACKEND="deepep_high_throughput" VLLM_USE_DEEP_GEMM=1  canhazgpu run -g 2 --  vllm serve Qwen/Qwen3-30B-A3B-FP8  --trust-remote-code --enable-expert-parallel --data-parallel-size 2 --port 9010 --no-enable-prefix-caching

Test Result

All tests pass for canhazgpu run -g2 -- pytest -s tests/kernels/moe/test_modular_kernel_combinations.py

All tests pass for canhazgpu run -g2 -- pytest tests/kernels/moe/test_deepep_deepgemm_moe.py

|Tasks|Version|Filter          |n-shot|Metric       |Value|Stderr |
|-----|------:|----------------|-----:|-------------|----:|------:|
|gsm8k|      3|flexible-extract|     5|exact_match ↑| 0.86|±0.0349|
|     |       |strict-match    |     5|exact_match ↑| 0.94|±0.0239|
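
(For reference, gsm8k numbers like these typically come from pointing lm-eval at the server started above; a plausible invocation, where the base_url, port, and --limit 100 are assumptions inferred from the serve command and the reported stderr, would be:)

lm_eval --model local-completions --tasks gsm8k --num_fewshot 5 --limit 100 --model_args model=Qwen/Qwen3-30B-A3B-FP8,base_url=http://localhost:9010/v1/completions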


👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default; only fastcheck CI runs, covering a small and essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run full CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify (bot) added the deepseek (Related to DeepSeek models) label on Jul 29, 2025
@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a performance optimization for MoE layers using DeepEPHighThroughput with block quantization (e.g., for DeepSeek models). The change correctly modifies the logic to quantize the activations before dispatching them, which reduces communication overhead and is more efficient.

The implementation is clean and effective. The condition for pre-quantization is correctly expanded to include block-quantized cases, and the call to the quantization kernel is updated to pass the correct parameters, which also fixes a potential bug that the logical change would have otherwise introduced.

Overall, the changes look solid and align well with the stated purpose. I couldn't find any issues of high or critical severity.

@varun-sundar-rabindranath (Contributor, Author)

@tlrmchlsmth @bnellnm PTAL! Thanks 🙌

@bnellnm (Contributor) commented Jul 29, 2025

So we still go down the "quantize after" codepath if the quantization is per-tensor? Is there some reason that quantization can't happen beforehand in that case also? Or does DeepEP not support that?

@varun-sundar-rabindranath
Copy link
Contributor Author

> So we still go down the "quantize after" codepath if the quantization is per-tensor? Is there some reason that quantization can't happen beforehand in that case also? Or does DeepEP not support that?

It is a DeepEP limitation. DeepEP doesn't support that.

@bnellnm (Contributor) commented Jul 29, 2025

> So we still go down the "quantize after" codepath if the quantization is per-tensor? Is there some reason that quantization can't happen beforehand in that case also? Or does DeepEP not support that?
>
> It is a DeepEP limitation. DeepEP doesn't support that.

Would it make sense to fake it out by replicating the scale and then resizing/truncating them after the dispatch?
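
For concreteness, that workaround might look like the following sketch; it was not adopted in this PR, and every name here is hypothetical.

```python
# Hypothetical sketch of the scale-replication idea (not adopted).
# Per-tensor quantization yields a single scale, while DeepEP's fp8
# dispatch expects per-group scales: replicate the scale to that layout
# before dispatch, then collapse it back to one scale afterwards.
import torch

def dispatch_per_tensor_as_blockwise(q, scale, topk_ids, dispatch, group=128):
    m, k = q.shape                                          # fp8 activations
    fake_scales = scale.expand(m, k // group).contiguous()  # scale is (1, 1)
    q_out, scales_out = dispatch((q, fake_scales), topk_ids)
    return q_out, scales_out[:, :1]  # truncate back to a single scale
```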

@varun-sundar-rabindranath (Contributor, Author) commented Jul 30, 2025

> So we still go down the "quantize after" codepath if the quantization is per-tensor? Is there some reason that quantization can't happen beforehand in that case also? Or does DeepEP not support that?
>
> It is a DeepEP limitation. DeepEP doesn't support that.
>
> Would it make sense to fake it out by replicating the scale and then resizing/truncating them after the dispatch?

I went back and looked at the DeepEP documentation. It suggests that only block quantization is supported, but the dispatch function seemingly also supports per-token quantization (we have unit tests exercising this that have been passing).

However, it looks like we are one assert in the DeepEP repo away from crashing. To be safe, I have updated the code so that only block quantization takes the "Quant-then-Dispatch" path; any other quantization scheme falls back to "Dispatch-then-Quant".
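
In code shape, the resulting branch is roughly the following sketch; the names (prepare, quant_config, is_block_quantized, quantize, dispatch) are placeholders rather than the actual vLLM identifiers.

```python
# Sketch of the control flow described above; all names are placeholders.
def prepare(x, topk_ids, quant_config, dispatch, quantize):
    if quant_config.is_fp8 and quant_config.is_block_quantized:
        # Quant-then-Dispatch: the layout DeepEP's fp8 dispatch documents.
        q, scales = quantize(x)
        return dispatch((q, scales), topk_ids)
    # Dispatch-then-Quant: everything else (e.g. per-tensor scales),
    # staying clear of DeepEP's assert on non-blockwise fp8 inputs.
    x_dispatched = dispatch(x, topk_ids)
    return quantize(x_dispatched)
```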

cc @tlrmchlsmth

@tlrmchlsmth (Member) left a comment


Thanks!

@tlrmchlsmth added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Jul 31, 2025
@tlrmchlsmth enabled auto-merge (squash) on July 31, 2025 14:33
Varun Sundar Rabindranath added 2 commits August 1, 2025 06:32
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Varun Sundar Rabindranath <[email protected]>
auto-merge was automatically disabled August 1, 2025 06:33

Head branch was pushed to by a user without write access

@varun-sundar-rabindranath force-pushed the varun/ht-quant-dispatch-ordering branch from 80cb125 to fcf2fe9 on August 1, 2025 06:33
@varun-sundar-rabindranath changed the title from "[Bugfix] [Performance] DeepEPHighThroughput + DeepSeek : Quant and then Dispatch" to "[Bugfix] [Performance] DeepEPHighThroughput + DeepSeek : Quant before Dispatch" on Aug 1, 2025
@vllm-bot vllm-bot merged commit ac45c44 into vllm-project:main Aug 1, 2025
41 of 44 checks passed
npanpaliya pushed a commit to odh-on-pz/vllm-upstream that referenced this pull request Aug 6, 2025
… Dispatch (vllm-project#21837)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
jinzhen-lin pushed a commit to jinzhen-lin/vllm that referenced this pull request Aug 9, 2025
… Dispatch (vllm-project#21837)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Jinzhen Lin <[email protected]>
noamgat pushed a commit to noamgat/vllm that referenced this pull request Aug 9, 2025
… Dispatch (vllm-project#21837)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Noam Gat <[email protected]>
paulpak58 pushed a commit to paulpak58/vllm that referenced this pull request Aug 13, 2025
… Dispatch (vllm-project#21837)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Paul Pak <[email protected]>
diegocastanibm pushed a commit to diegocastanibm/vllm that referenced this pull request Aug 15, 2025
… Dispatch (vllm-project#21837)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
Signed-off-by: Diego-Castan <[email protected]>
epwalsh pushed a commit to epwalsh/vllm that referenced this pull request Aug 28, 2025
… Dispatch (vllm-project#21837)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>
zhewenl pushed a commit to zhewenl/vllm that referenced this pull request Aug 28, 2025
… Dispatch (vllm-project#21837)

Signed-off-by: Varun Sundar Rabindranath <[email protected]>
Co-authored-by: Varun Sundar Rabindranath <[email protected]>