
Conversation

@Wangzheee commented Aug 11, 2025

A new FP8 MLA pipeline, referencing FlashMLA and the fp8 PR (#54)

Improvements

Compared with FlashMLA:

  • Use FP8 WGMMA.
  • Use FP8 dtypes for Q and KV.
  • Save shared memory, allocate shared memory for sP1, and remove retrieve_rP_for_SP(sQ(8)) and the RS WGMMA of rQ(8)*sK.

Compared with the fp8 implementation (#54):

  • Fine-grained optimization of the pipeline between the TMA copies and the WGMMAs.
  • Rebuilt pipeline for the transposed V, using 4 named barriers to switch between 4 buffers (V0L, V0R, V1L, V1R).
  • Additional shared-memory buffers sP0, sP1, sVt0, and sVt1 for ping-pong.
  • 128-bit STSM and 128-bit LDSM for copying between rP and sP in both directions (the output rP of the Q@K FP8 WGMMA cannot be copied directly into the input sP layout of the P@V FP8 SS WGMMA; therefore rP is copied to sP, and sP is copied back to rP for the P@V FP8 RS WGMMA).
  • Most importantly, with Q@K implemented at a fine granularity (576/64 = 9 tiles), it becomes possible to compute RoPE in BF16, which resolves the previous accuracy issue (see the sketch after this list).
  • More features, following the SM90 programming style.
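As a rough illustration of the precision split that the 9-tile Q@K decomposition enables, here is a host-side PyTorch sketch (not the kernel): the 576-dim head is treated as a 512-dim NoPE part in FP8 plus one 64-dim RoPE tile kept in BF16. The per-tensor scales and the assumption that the RoPE dims occupy the last tile are illustrative only.

```python
import torch

# Illustrative host-side sketch, not the kernel: split the 576-dim MLA head
# into a 512-dim NoPE part (quantized to FP8) and a 64-dim RoPE part kept in
# BF16, mirroring the 576/64 = 9-tile decomposition described above.
d, dv, d_rope = 576, 512, 64
q = torch.randn(64, d, dtype=torch.bfloat16)
k = torch.randn(8192, d, dtype=torch.bfloat16)

q_nope, q_rope = q.split([dv, d_rope], dim=1)
k_nope, k_rope = k.split([dv, d_rope], dim=1)

# Hypothetical per-tensor scales so the NoPE values fit e4m3's range (~448).
q_scale = q_nope.abs().amax().float() / 448.0
k_scale = k_nope.abs().amax().float() / 448.0
q_nope_fp8 = (q_nope.float() / q_scale).to(torch.float8_e4m3fn)
k_nope_fp8 = (k_nope.float() / k_scale).to(torch.float8_e4m3fn)

# Q@K = FP8 NoPE contribution (dequantized here for emulation) + BF16 RoPE part.
s = (q_nope_fp8.float() * q_scale) @ (k_nope_fp8.float() * k_scale).t() \
    + q_rope.float() @ k_rope.float().t()
print(s.shape)  # torch.Size([64, 8192])
```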

TODO

  • Redesign the output rP of the Q@K FP8 WGMMA so it can be copied directly into the input sP layout of the P@V FP8 SS WGMMA.
  • Combine the fill_oob_KV operations for each tile (each fill_oob_KV incurs a warpgroup synchronization).
  • Implement RoPE in BF16.

Performance (H20)

*Performance is limited by the warpgroup synchronization in fill_oob_KV.
MLA: cache length = 8196; head_num = 64

|  | 64 (batch_size=32, MTP=2) | 96 (batch_size=48, MTP=2) | 128 (batch_size=64, MTP=2) | 256 (batch_size=128, MTP=2) |
| --- | --- | --- | --- | --- |
| bf16 | 0.569 ms, 128 TFLOPS | 0.819 ms, 134 TFLOPS | 1.090 ms, 134 TFLOPS | 2.158 ms, 135 TFLOPS |
| Previous fp8 (#54) | 0.354 ms, 206 TFLOPS | 0.506 ms, 217 TFLOPS | 0.671 ms, 218 TFLOPS | 1.308 ms, 223 TFLOPS |
| New fp8 | 0.337 ms, 217 TFLOPS | 0.485 ms, 226 TFLOPS | 0.639 ms, 229 TFLOPS | 1.243 ms, 235 TFLOPS |
| Performance improvement | 69% vs bf16, 5% vs fp8 (#54) | 62% vs bf16, 5% vs fp8 (#54) | 62% vs bf16, 5% vs fp8 (#54) | 74% vs bf16, 5% vs fp8 (#54) |
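For readers reproducing these numbers: the TFLOPS figures in the table are consistent with counting 2 * tokens * h_q * s_k * (d + dv) floating-point operations per call; that convention is my assumption, not taken from the benchmark script. A quick cross-check:

```python
# Cross-check of the table's TFLOPS figures, assuming the benchmark counts
# 2 * tokens * h_q * s_k * (d + dv) FLOPs per call (my convention, not
# necessarily the one used by the actual benchmark script).
def tflops(tokens, time_ms, s_k=8196, h_q=64, d=576, dv=512):
    flops = 2 * tokens * h_q * s_k * (d + dv)
    return flops / (time_ms * 1e-3) / 1e12

print(round(tflops(64, 0.569)))   # ~128, matches the bf16 entry at 64 tokens
print(round(tflops(64, 0.337)))   # ~217, matches the new fp8 entry at 64 tokens
print(round(tflops(256, 1.243)))  # ~235, matches the new fp8 entry at 256 tokens
```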

@MicroZHY

Hi @Wangzheee,

Thank you for the updated FP8 implementation! Could you please briefly summarize:

1. The key incremental optimizations that give the new FP8 kernel an extra 8–10% performance boost over the previous PR #54.
2. Whether any ideas or code snippets were borrowed from MatthewBonanni’s branch feature/fp8_mla_flashmla (https://github.com/MatthewBonanni/vllm/tree/feature/fp8_mla_flashmla).
Looking forward to your insights—much appreciated!

Best regards

@MicroZHY

Hi @Wangzheee,

Below are the parameters and results I obtained when running test_fp8_flash_mla.py on an H800:

b = 64, s_q = 2, mean_sk = 4096, h_q = 64, h_kv = 1, d = 576, dv = 512,
causal = True, varlen = False, torch_dtype = torch.float8_e4m3fn

out: cos_diff = 0.0006639208807784902,
     RMSE     = 0.0017831785204429822,
     amax_diff= 0.04801611602306366

lse: cos_diff = 3.260247627423496e-11,
     RMSE     = 7.119073855497076e-05,
     amax_diff= 0.0002155303955078125

0.744 ms, 98 TFLOPS, 221 GB/s
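For context, a rough sketch of how error metrics of this kind (cos_diff, RMSE, amax_diff) are commonly computed against a reference output; the exact definitions used by test_fp8_flash_mla.py may differ:

```python
import torch

# Rough sketch of the reported error metrics; the exact formulas used by
# test_fp8_flash_mla.py may differ from these.
def compare(x: torch.Tensor, y: torch.Tensor):
    x, y = x.double(), y.double()
    cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
    rmse = ((x - y) ** 2).mean().sqrt().item()
    amax_diff = (x - y).abs().max().item()
    return cos_diff, rmse, amax_diff

ref = torch.randn(64, 2, 64, 512)          # e.g. a higher-precision reference output
out = ref + 1e-3 * torch.randn_like(ref)   # e.g. the fp8 kernel's output
print(compare(ref, out))
```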

The measured performance appears lower than expected. Could you share your own benchmark results and the corresponding test script so I can compare and identify possible bottlenecks?

@Wangzheee (Author)

> Hi @Wangzheee,
>
> Thank you for the updated FP8 implementation! Could you please briefly summarize:
>
> 1. The key incremental optimizations that give the new FP8 kernel an extra 8–10% performance boost over the previous PR #54.
> 2. Whether any ideas or code snippets were borrowed from MatthewBonanni’s branch feature/fp8_mla_flashmla (https://github.com/MatthewBonanni/vllm/tree/feature/fp8_mla_flashmla).
>
> Looking forward to your insights—much appreciated!
>
> Best regards

Hi~ A detailed introduction has been added to the PR description.

@Wangzheee (Author) commented Aug 19, 2025

> Below are the parameters and results I obtained when running test_fp8_flash_mla.py on an H800:
> […]
> 0.744 ms, 98 TFLOPS, 221 GB/s
>
> The measured performance appears lower than expected. Could you share your own benchmark results and the corresponding test script so I can compare and identify possible bottlenecks?

Hi~ On an H20 I get:
b=64, s_q=2, mean_sk=4096, h_q=64, h_kv=1, d=576, dv=512, causal=True, varlen=False, torch_dtype=torch.float8_e4m3fn
0.342 ms, 213 TFLOPS, 467 GB/s
For the H800, could you compare this PR against the other dtypes? (A possible driver for such a comparison is sketched below.)
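A minimal driver for such a comparison could look like the sketch below; it assumes test_flash_mla_sm90.py accepts --dtype for both values (the "bfloat16" spelling of the flag value is my guess).

```python
import subprocess

# Hypothetical driver: run the SM90 test once per dtype and compare the
# printed timings/TFLOPS. Assumes the script accepts both --dtype values;
# the "bfloat16" spelling is a guess.
for dtype in ("bfloat16", "float8_e4m3fn"):
    print(f"===== dtype = {dtype} =====", flush=True)
    subprocess.run(
        ["python3", "test_flash_mla_sm90.py", f"--dtype={dtype}"],
        check=True,
    )
```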

@MicroZHY

> A new FP8 MLA pipeline, referencing FlashMLA and the fp8 PR (#54)
> […]
> TODO
> […]
>   • Implement RoPE in BF16.

Thank you for the detailed description in the PR. If I understand correctly, you plan to implement RoPE in BF16. Is there any theoretical justification that implementing RoPE in BF16 can improve precision, or is RoPE inherently more sensitive to precision?
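One way to probe RoPE's precision sensitivity empirically is to measure the raw quantization error of the rotation factors themselves in e4m3 FP8 versus BF16; a small sketch (this only looks at the factors, not at the full attention path or the kernel's actual data flow):

```python
import torch

# Small numeric probe: rounding error of RoPE rotation factors (cos/sin of
# the rotary angles) when stored as e4m3 FP8 versus BF16.
theta = torch.rand(4096, 64, dtype=torch.float32) * 2 * torch.pi
rot = torch.cat([theta.cos(), theta.sin()], dim=-1)

err_fp8 = (rot.to(torch.float8_e4m3fn).float() - rot).abs()
err_bf16 = (rot.to(torch.bfloat16).float() - rot).abs()
print(f"max abs error  e4m3: {err_fp8.max().item():.4f}   bf16: {err_bf16.max().item():.4f}")
print(f"mean abs error e4m3: {err_fp8.mean().item():.5f}  bf16: {err_bf16.mean().item():.5f}")
```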

@Fengming-Zhang

Hi @Wangzheee,
I followed the config and ran python3 test_flash_mla_sm90.py --dtype=float8_e4m3fn on an H20, but I didn't get the gains you mentioned.
(two screenshots of the test output; not reproduced here)
As the screenshots show, it has equal performance but lower precision, so I'm confused about the purpose of this patch.
Am I making a mistake somewhere?
torch version = 2.6.0, CUDA 12.8

Thanks.
