[gpt-oss] Cache permute indices for faster MXFP4 MoE layer loading #24154
Conversation
This pull request was exported from Phabricator. Differential Revision: D81544286
Code Review
This pull request introduces a caching mechanism for permutation indices to accelerate weight loading in MoE layers, which is a valuable optimization that demonstrates significant performance gains. The overall implementation is solid, but I've identified a critical bug where an incorrect device is used for a tensor operation, which could lead to runtime errors or incorrect behavior.
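For readers unfamiliar with the pattern, here is a minimal sketch of such an index cache (hypothetical names and a stand-in index construction; the PR's actual implementation lives in the MXFP4 MoE weight-loading path):

```python
import torch

# Hypothetical sketch of the caching idea: the permutation indices depend only
# on the weight layout, so compute them once per distinct size and reuse them
# for every expert instead of rebuilding them on each load.
_permute_cache: dict[int, torch.Tensor] = {}

def cached_permute_indices(n: int) -> torch.Tensor:
    if n not in _permute_cache:
        # Stand-in for the real (expensive) index construction.
        _permute_cache[n] = torch.arange(n - 1, -1, -1)
    return _permute_cache[n]
```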
```python
w2_weight_scale[i]
    .view(torch.uint8)[
        permute_sf_indices.to(w13_weight_scale.device)
    ]
    .contiguous()
```
There appears to be a typo on line 426: the device should be `w2_weight_scale.device` instead of `w13_weight_scale.device`. While this might work if both tensors are on the same device, it is safer and more correct to use the device of the tensor being processed, avoiding potential runtime errors or incorrect behavior.
Suggested change:

```diff
 w2_weight_scale[i]
     .view(torch.uint8)[
-        permute_sf_indices.to(w13_weight_scale.device)
+        permute_sf_indices.to(w2_weight_scale.device)
     ]
     .contiguous()
```
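As a standalone illustration of the suggested pattern (hypothetical shapes, not the PR's actual tensors), the indices are moved to the device of the tensor they index, so the code stays correct even if w13 and w2 end up on different devices:

```python
import torch

# Hypothetical shapes; only the device handling is the point here.
w2_scale = torch.arange(32, dtype=torch.uint8).reshape(2, 16)
permute_sf_indices = torch.randperm(16)

permuted = (
    w2_scale[0]
    .view(torch.uint8)[permute_sf_indices.to(w2_scale.device)]
    .contiguous()
)
```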
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed from ba70a5e to f01f213
Summary: On GB200, the MoE MXFP4 weight transpose takes quite a long time. Add a cache for the weight transpose indices so that the expert weight transpose time can be reduced.

**20b:**

Before: Model loading took 94 sec
```
(EngineCore_0 pid=3397977) INFO 09-01 19:27:08 [default_loader.py:267] Loading weights took 2.83 seconds
(EngineCore_0 pid=3397977) INFO 09-01 19:28:41 [gpu_model_runner.py:1977] Model loading took 14.1643 GiB and 94.110470 seconds
```

After: Model loading took 5.9 sec
```
(EngineCore_0 pid=3005216) INFO 09-02 16:54:43 [default_loader.py:267] Loading weights took 2.54 seconds
(EngineCore_0 pid=3005216) INFO 09-02 16:54:47 [gpu_model_runner.py:1977] Model loading took 14.1693 GiB and 5.918206 seconds
```

**120b:**

**Loading time verification:**

**Before, P1928776629**
E2E predictor warm-up takes: 17:28:53 ~ 17:39:59 = 11 min 6 sec
Model loading takes 568.133048 seconds
```
(EngineCore_0 pid=344869) INFO 09-02 17:29:45 [default_loader.py:267] Loading weights took 8.25 seconds
(EngineCore_0 pid=344869) INFO 09-02 17:39:05 [gpu_model_runner.py:1977] Model loading took 68.7019 GiB and 568.133048 seconds
```

**After, P1928762318**
E2E predictor warm-up takes: 17:26:12 ~ 17:28:15 = 2 min 3 sec
Model loading takes 15.083996 seconds
```
(EngineCore_0 pid=156514) INFO 09-02 17:27:05 [default_loader.py:267] Loading weights took 9.18 seconds
(EngineCore_0 pid=156514) INFO 09-02 17:27:12 [gpu_model_runner.py:1977] Model loading took 68.7093 GiB and 15.083996 seconds
```

**Accuracy verification:**
```
aime25 medium: P1928806083 [{'eval_name': 'aime25', 'model_name': 'gpt-oss-120b-medium_temp1.0_20250902_175112', 'metric': 0.7875}]
aime25 high: P1928898566 [{'eval_name': 'aime25', 'model_name': 'gpt-oss-120b-high_temp1.0_20250902_180141', 'metric': 0.9}]
```

Test Plan: Compared the transposed weights; they match between before and after. P1928725920

python test_eq.py
```python
import torch

[g1w, g1s, g1b] = torch.load("/tmp/gemm1_wei.pt")
[g1w2, g1s2, g1b2] = torch.load("/tmp/gemm1_wei2.pt")
for i in range(len(g1w)):
    print(i)
    print(torch.equal(g1w[i], g1w2[i]))
    print(torch.equal(g1s[i], g1s2[i]))
    print(torch.equal(g1b[i], g1b2[i]))

[g2w, g2s, g2b] = torch.load("/tmp/gemm2_wei.pt")
[g2w2, g2s2, g2b2] = torch.load("/tmp/gemm2_wei2.pt")
for i in range(len(g2w)):
    print(i)
    print(torch.equal(g2w[i], g2w2[i]))
    print(torch.equal(g2s[i], g2s2[i]))
    print(torch.equal(g2b[i], g2b2[i]))
```

Rollback Plan:

Reviewed By: zixi-qi

Differential Revision: D81544286
Force-pushed from f01f213 to 00b87a0
Force-pushed from 00b87a0 to 41d88d2
Force-pushed from 41d88d2 to d28ec0e
Force-pushed from d28ec0e to 80d85d5
Force-pushed from 80d85d5 to 44d4e65
Looks good! Let's add a unit test?
Thanks for the change! This is huge! Could you update your PR title and replace the internal pastebin links with a gist? :)
Thanks @yeqcharlotte and @22quinn for the review. I have updated this PR as suggested.
Great fix! Just curious, do we know why this issue is so much more noticeable on GB200 than on other GPUs? It seems like this improvement is backend agnostic.
@jwfromm, the issue was raised while enabling MXFP4 for the MoE weights. AFAIK, only Blackwell supports this format.
…llm-project#24154) Signed-off-by: Wei Wei <[email protected]>
…llm-project#24154) Signed-off-by: Wei Wei <[email protected]> Signed-off-by: rogeryoungh <[email protected]>
…llm-project#24154) Signed-off-by: Wei Wei <[email protected]> Signed-off-by: bruceszchen <[email protected]>
…llm-project#24154) Signed-off-by: Wei Wei <[email protected]> Signed-off-by: bruceszchen <[email protected]>
Summary:
On GB200, the MoE MXFP4 weight transpose takes quite a long time when the gpt-oss model is loaded.
Add a cache for the weight transpose (permute) indices so that the per-expert transpose time is reduced.
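Schematically, the win comes from computing the indices once and reusing them across experts; a toy before/after sketch (hypothetical shapes and helper name, not the PR's exact code):

```python
import torch

def build_permute_indices(n: int) -> torch.Tensor:
    # Deterministic stand-in for the expensive per-layout index construction.
    return torch.arange(n - 1, -1, -1)

num_experts, n = 128, 1024
w = torch.randn(num_experts, n)

# Before: the indices are rebuilt inside the per-expert loop.
slow = [w[e][build_permute_indices(n)].contiguous() for e in range(num_experts)]

# After: build once, reuse for every expert.
idx = build_permute_indices(n)
fast = [w[e][idx].contiguous() for e in range(num_experts)]

assert all(torch.equal(a, b) for a, b in zip(slow, fast))
```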
20b:
Before: Model loading took 94 sec
After: Model loading took 5.9 sec
120b:
Loading time verification:
Before, P1928776629
E2E predictor warm-up takes: 17:28:53 ~ 17:39:59 = 11 min 6 sec
Model loading takes 568.133048 seconds
After, P1928762318
E2E predictor warm-up takes: 17:26:12 ~ 17:28:15 = 2 min 3 sec
Model loading takes 15.083996 seconds
Accuracy verification:
aime25 medium (P1928806083): 0.7875
aime25 high (P1928898566): 0.9
Test Plan:
Compared the transposed weights before and after the change; they match. [link]
python test_eq.py
Rollback Plan:
Differential Revision: D81544286