Commit ba70a5e

frank-wei authored and facebook-github-bot committed
reduce the weight loading time
Summary:
ATT. On GB200, the MoE MXFP4 weight transpose takes quite a long time. Add a cache for the weight-transpose indices so that the per-expert weight transpose time is reduced.

**20b:**

Before: model loading took 94 sec

```
(EngineCore_0 pid=3397977) INFO 09-01 19:27:08 [default_loader.py:267] Loading weights took 2.83 seconds
(EngineCore_0 pid=3397977) INFO 09-01 19:28:41 [gpu_model_runner.py:1977] Model loading took 14.1643 GiB and 94.110470 seconds
```

After: model loading took 5.9 sec

```
(EngineCore_0 pid=3005216) INFO 09-02 16:54:43 [default_loader.py:267] Loading weights took 2.54 seconds
(EngineCore_0 pid=3005216) INFO 09-02 16:54:47 [gpu_model_runner.py:1977] Model loading took 14.1693 GiB and 5.918206 seconds
```

**120b:**

**Loading time verification:**

Before (P1928776629): E2E predictor warm-up takes 17:28:53 ~ 17:39:59 = 11 min 6 sec; model loading takes 568.133048 seconds

```
(EngineCore_0 pid=344869) INFO 09-02 17:29:45 [default_loader.py:267] Loading weights took 8.25 seconds
(EngineCore_0 pid=344869) INFO 09-02 17:39:05 [gpu_model_runner.py:1977] Model loading took 68.7019 GiB and 568.133048 seconds
```

After (P1928762318): E2E predictor warm-up takes 17:26:12 ~ 17:28:15 = 2 min 3 sec; model loading takes 15.083996 seconds

```
(EngineCore_0 pid=156514) INFO 09-02 17:27:05 [default_loader.py:267] Loading weights took 9.18 seconds
(EngineCore_0 pid=156514) INFO 09-02 17:27:12 [gpu_model_runner.py:1977] Model loading took 68.7093 GiB and 15.083996 seconds
```

**Accuracy verification:**

```
aime25 medium: P1928806083
[{'eval_name': 'aime25', 'model_name': 'gpt-oss-120b-medium_temp1.0_20250902_175112', 'metric': 0.7875}]
aime25 high: P1928898566
[{'eval_name': 'aime25', 'model_name': 'gpt-oss-120b-high_temp1.0_20250902_180141', 'metric': 0.9}]
```

Test Plan:
Compared the transposed weights before and after the change; they match. P1928725920

python test_eq.py

```
import torch

# Compare the shuffled gemm1 (w13) weights, scales, and biases dumped
# before and after this change.
[g1w, g1s, g1b] = torch.load("/tmp/gemm1_wei.pt")
[g1w2, g1s2, g1b2] = torch.load("/tmp/gemm1_wei2.pt")
for i in range(len(g1w)):
    print(i)
    print(torch.equal(g1w[i], g1w2[i]))
    print(torch.equal(g1s[i], g1s2[i]))
    print(torch.equal(g1b[i], g1b2[i]))

# Compare the shuffled gemm2 (w2) weights, scales, and biases.
[g2w, g2s, g2b] = torch.load("/tmp/gemm2_wei.pt")
[g2w2, g2s2, g2b2] = torch.load("/tmp/gemm2_wei2.pt")
for i in range(len(g2w)):
    print(i)
    print(torch.equal(g2w[i], g2w2[i]))
    print(torch.equal(g2s[i], g2s2[i]))
    print(torch.equal(g2b[i], g2b2[i]))
```

Rollback Plan:

Differential Revision: D81544286
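The gist of the change, as a minimal self-contained sketch (the names `compute_row_permutation` and `shuffle_with_cache` are illustrative stand-ins, not the vLLM or flashinfer API): every expert's weight slice has the same shape, so the row-permutation indices behind the shuffle can be computed once per distinct shape, cached, and applied to each expert as a cheap gather.

```
from typing import Dict

import torch


def compute_row_permutation(w: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
    # Stand-in for the expensive index computation (flashinfer's
    # get_shuffle_matrix_a_row_indices in the real code); the result
    # depends only on the tensor's shape and the tile size.
    return torch.arange(w.shape[0]).flip(0)


_cache: Dict[torch.Size, torch.Tensor] = {}


def shuffle_with_cache(w: torch.Tensor, epilogue_tile_m: int) -> torch.Tensor:
    # Key the cache on the tensor shape: all experts share the same
    # shape, so the permutation is computed once and reused.
    indices = _cache.get(w.shape)
    if indices is None:
        indices = compute_row_permutation(w, epilogue_tile_m).to(w.device)
        _cache[w.shape] = indices
    return w[indices].contiguous()


num_experts, rows, cols = 8, 256, 128
weights = torch.randn(num_experts, rows, cols)
shuffled = torch.stack(
    [shuffle_with_cache(weights[i], epilogue_tile_m=128) for i in range(num_experts)])
assert len(_cache) == 1  # one cache entry per distinct shape, not per expert
```

In the actual diff below, `_maybe_get_cached_permute_indices` plays the role of the cached lookup above, keyed on the tensor's `torch.Size`, and a plain index gather (plus `nvfp4_block_scale_interleave` for the scale factors) replaces the per-expert `shuffle_matrix_a` / `shuffle_matrix_sf_a` calls.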
1 parent 426cc86 commit ba70a5e

File tree

1 file changed (+91, -16 lines)
  • vllm/model_executor/layers/quantization
vllm/model_executor/layers/quantization/mxfp4.py

Lines changed: 91 additions & 16 deletions
```diff
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Callable, Optional
+from typing import Callable, Dict, Optional, Union
 
 import torch
 from torch.nn.parameter import Parameter
@@ -122,7 +122,8 @@ def __init__(self, moe: FusedMoEConfig):
                 "MXFP4 MoE is enabled on Blackwell but FlashInfer "
                 "is not available. This may result in degraded performance. "
                 "Please `pip install vllm[flashinfer]` for best results.")
-
+        self._cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}
+
     def _should_use_marlin(self):
         if envs.VLLM_MXFP4_USE_MARLIN is not None:
             return envs.VLLM_MXFP4_USE_MARLIN
@@ -261,12 +262,37 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         )
         layer.register_parameter("w2_bias", w2_bias)
         set_weight_attrs(w2_bias, extra_weight_attrs)
+
+    def _maybe_get_cached_permute_indices(
+        self,
+        dst_w_weight: torch.Tensor,
+        epilogue_tile_m: int,
+        num_elts_per_sf: Union[None, int] = None,
+    ) -> torch.Tensor:
+        from flashinfer.utils import get_shuffle_matrix_a_row_indices, get_shuffle_matrix_sf_a_row_indices
+        key = self._cache_permute_indices.get(dst_w_weight.shape)
+        if key is None:
+            if num_elts_per_sf is None:
+                permute1 = get_shuffle_matrix_a_row_indices(
+                    dst_w_weight, epilogue_tile_m=epilogue_tile_m
+                )
+            else:
+                permute1 = get_shuffle_matrix_sf_a_row_indices(
+                    dst_w_weight,
+                    epilogue_tile_m=epilogue_tile_m,
+                    num_elts_per_sf=num_elts_per_sf,
+                )
+            self._cache_permute_indices[dst_w_weight.shape] = permute1.to(
+                dst_w_weight.device
+            )
+        permute_indices = self._cache_permute_indices[dst_w_weight.shape]
+        return permute_indices
 
     def process_weights_after_loading(self, layer):
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
         elif should_use_flashinfer_mxfp4():
-            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
+            from flashinfer.fp4_quantization import nvfp4_block_scale_interleave
             layer.gemm1_alpha = Parameter(torch.tensor(
                 [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                 requires_grad=False)
@@ -343,25 +369,74 @@ def swap_every_two_rows(x, axis=-1):
             gemm2_bias_shuffled = []
             epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
             for i in range(self.num_experts):
+                # w13 weight shuffling
+                permute_indices = self._maybe_get_cached_permute_indices(
+                    w13_weight[i].view(torch.uint8),
+                    epilogue_tile_m,
+                )
                 gemm1_weights_mxfp4_shuffled.append(
-                    shuffle_matrix_a(w13_weight[i].view(torch.uint8),
-                                     epilogue_tile_m))
+                    w13_weight[i]
+                    .view(torch.uint8)[permute_indices.to(w13_weight.device)]
+                    .contiguous()
+                )
+                # w13 scale shuffling
+                permute_sf_indices = self._maybe_get_cached_permute_indices(
+                    w13_weight_scale[i].view(torch.uint8),
+                    epilogue_tile_m,
+                    num_elts_per_sf=16,
+                )
                 gemm1_scales_mxfp4_shuffled.append(
-                    shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
-                                        epilogue_tile_m))
+                    nvfp4_block_scale_interleave(
+                        w13_weight_scale[i]
+                        .view(torch.uint8)[
+                            permute_sf_indices.to(w13_weight_scale.device)
+                        ]
+                        .contiguous()
+                    )
+                )
+                # w13 bias shuffling
+                permute_bias_indices = self._maybe_get_cached_permute_indices(
+                    w13_bias[i].clone().reshape(-1, 1),
+                    epilogue_tile_m,
+                )
                 gemm1_bias_shuffled.append(
-                    shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
-                                     epilogue_tile_m))
-
+                    w13_bias[i].clone().reshape(-1, 1)[permute_bias_indices.to(w13_bias.device)]
+                    .contiguous()
+                )
+                # w2 weight shuffling
+                permute_indices = self._maybe_get_cached_permute_indices(
+                    w2_weight[i].view(torch.uint8),
+                    epilogue_tile_m,
+                )
                 gemm2_weights_mxfp4_shuffled.append(
-                    shuffle_matrix_a(w2_weight[i].view(torch.uint8),
-                                     epilogue_tile_m))
+                    w2_weight[i]
+                    .view(torch.uint8)[permute_indices.to(w2_weight.device)]
+                    .contiguous()
+                )
+                # w2 scale shuffling
+                permute_sf_indices = self._maybe_get_cached_permute_indices(
+                    w2_weight_scale[i].view(torch.uint8),
+                    epilogue_tile_m,
+                    num_elts_per_sf=16,
+                )
                 gemm2_scales_mxfp4_shuffled.append(
-                    shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
-                                        epilogue_tile_m))
+                    nvfp4_block_scale_interleave(
+                        w2_weight_scale[i]
+                        .view(torch.uint8)[
+                            permute_sf_indices.to(w13_weight_scale.device)
+                        ]
+                        .contiguous()
+                    )
+                )
+                # w2 bias shuffling
+                permute_indices = self._maybe_get_cached_permute_indices(
+                    w2_bias[i].clone().reshape(-1, 1),
+                    epilogue_tile_m,
+                )
                 gemm2_bias_shuffled.append(
-                    shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
-                                     epilogue_tile_m))
+                    w2_bias[i].clone().reshape(-1, 1)[permute_indices.to(w2_bias.device)]
+                    .contiguous()
+                )
 
             w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
             w13_weight_scale = torch.stack(
```

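For anyone reproducing the test plan's equivalence claim directly against flashinfer, a hedged spot check along these lines could be used. It is a sketch, not part of the commit: it assumes flashinfer is installed and a CUDA device is available, only exercises the plain-weight path (`shuffle_matrix_a`), and uses an illustrative uint8 shape rather than the model's real packed per-expert weight shape (the kernel may impose shape or alignment constraints).

```
import torch
from flashinfer import shuffle_matrix_a
from flashinfer.utils import get_shuffle_matrix_a_row_indices

# Illustrative shape only; in vLLM this is a per-expert MXFP4-packed
# weight viewed as uint8. epilogue_tile_m matches the constant in mxfp4.py.
epilogue_tile_m = 128
w = torch.randint(0, 256, (2560, 1024), dtype=torch.uint8, device="cuda")

# Old path: shuffle each expert's weight directly.
expected = shuffle_matrix_a(w, epilogue_tile_m)

# New path: compute the row permutation once (cacheable by shape) and
# apply it as a gather.
indices = get_shuffle_matrix_a_row_indices(w, epilogue_tile_m=epilogue_tile_m)
actual = w[indices.to(w.device)].contiguous()

print(torch.equal(expected.cpu(), actual.cpu()))  # expected to print True
```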