update fp8 support #82
base: main
Conversation
Hi @Wangzheee, thank you for the updated FP8 implementation! Could you briefly summarize the key incremental optimizations that give the new FP8 kernel the extra 8–10% performance boost over the previous PR #54? Best regards
Hi @Wangzheee, Below are the parameters and results I obtained when running test_fp8_flash_mla.py on an H800:
The measured performance appears lower than expected. Could you share your own benchmark results and the corresponding test script so I can compare and identify possible bottlenecks?
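For reference, a minimal CUDA-event timing loop of the kind such a benchmark typically uses is sketched below; `run_fp8_mla` is a hypothetical placeholder for the actual FlashMLA kernel call, not a real API.

```python
# Minimal sketch of a CUDA-event timing harness (assumed benchmark structure;
# `run_fp8_mla` is a hypothetical stand-in for the real kernel invocation).
import torch

def benchmark(run_fp8_mla, warmup: int = 10, iters: int = 100) -> float:
    """Return the mean per-iteration kernel time in milliseconds."""
    for _ in range(warmup):          # warm up caches / clocks before timing
        run_fp8_mla()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        run_fp8_mla()
    end.record()
    torch.cuda.synchronize()         # wait for all timed work to finish
    return start.elapsed_time(end) / iters
```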
Hi~ I used an H20:
Thank you for the detailed description in the PR. If I understand correctly, you plan to implement RoPE in BF16. Is there any theoretical justification that implementing RoPE in BF16 improves precision, or is RoPE inherently more sensitive to precision?
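For illustration only (this is not the PR's actual code), one way to read "RoPE in BF16" is to apply the rotary embedding to BF16 Q/K and only cast down to FP8 afterwards, so the small sin/cos products are not computed at FP8 precision. The helper below is a hypothetical sketch; per-tensor scaling is omitted for brevity.

```python
# Hypothetical sketch: rotate Q/K in BF16, then quantize to FP8 (e4m3).
# Shapes and names are illustrative only, not the PR's implementation.
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope_then_quantize(q: torch.Tensor, k: torch.Tensor,
                             cos: torch.Tensor, sin: torch.Tensor):
    """q, k: [batch, heads, seq, d] in bfloat16; cos/sin: [seq, d] rotary tables."""
    # The rotation happens in BF16, so the sin/cos products keep more mantissa
    # bits than an FP8 representation would allow.
    q_rot = q * cos + rotate_half(q) * sin
    k_rot = k * cos + rotate_half(k) * sin
    # Cast to FP8 only for the matrix-multiply operands.
    return q_rot.to(torch.float8_e4m3fn), k_rot.to(torch.float8_e4m3fn)
```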
Hi @Wangzheee, Thanks.
A new FP8 MLA pipeline, referencing FlashMLA and the fp8 PR (#54)
Improvements
Compared with FlashMLA:
- Use WGMMA FP8.
- Use FP8 dtypes for Q and KV (see the quantization sketch after this list).
- Saved shared memory: allocated shared memory for sP1, removed retrieve_rP_for_SP(sQ(8)), and removed the RS WGMMA of rQ(8)*sK.
Compared with the fp8 implementation (#54):
TODO
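As a rough illustration of the "FP8 dtypes for Q and KV" point above (an assumption-laden sketch, not the PR's code), per-tensor scaling to e4m3 before the kernel call might look like the following; the shapes are placeholders and the kernel is assumed to take the FP8 tensors plus scales to descale the WGMMA accumulator outputs.

```python
# Hypothetical per-tensor FP8 (e4m3) quantization of Q and KV before the kernel.
import torch

def quantize_fp8(x: torch.Tensor, dtype=torch.float8_e4m3fn):
    finfo = torch.finfo(dtype)
    # Symmetric per-tensor scale so the max magnitude maps to the FP8 max.
    scale = x.abs().amax().clamp(min=1e-12) / finfo.max
    x_fp8 = (x / scale).clamp(finfo.min, finfo.max).to(dtype)
    return x_fp8, scale.float()

# Illustrative shapes only (batch 1, one decode token, 64 query heads,
# cache length 8196, assumed MLA head dim 576); not the benchmark's exact setup.
q = torch.randn(1, 1, 64, 576, dtype=torch.bfloat16, device="cuda")
kv = torch.randn(1, 8196, 1, 576, dtype=torch.bfloat16, device="cuda")
q_fp8, q_scale = quantize_fp8(q)
kv_fp8, kv_scale = quantize_fp8(kv)
```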
Performance (H20)
*Performance is limited by the warpgroup synchronization of fill_oob_KV.
MLA: cache length = 8196; head_num = 64
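For context, achieved TFLOPS for this decode shape can be estimated roughly as below. This is an assumed accounting (d_qk = 576, d_v = 512, one query token per request); the actual test script may count FLOPs differently.

```python
# Rough TFLOPS estimate for the stated configuration (assumptions noted above).
cache_len = 8196      # KV cache length from the PR description
head_num = 64         # query heads from the PR description
d_qk, d_v = 576, 512  # standard MLA head dims (assumed)
s_q, batch = 1, 1     # one decode token per request (assumed)

# QK^T and PV matmuls: 2 * seq_q * seq_k * dim FLOPs each, summed per head.
flops = 2 * batch * s_q * cache_len * head_num * (d_qk + d_v)

def tflops(time_ms: float) -> float:
    """Convert a measured kernel time (ms) into achieved TFLOPS."""
    return flops / (time_ms * 1e-3) / 1e12
```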