@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Callable, Optional
+from typing import Callable, Dict, Optional, Union
 
 import torch
 from torch.nn.parameter import Parameter
@@ -122,7 +122,8 @@ def __init__(self, moe: FusedMoEConfig):
                 "MXFP4 MoE is enabled on Blackwell but FlashInfer "
                 "is not available. This may result in degraded performance. "
                 "Please `pip install vllm[flashinfer]` for best results.")
-
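+        # Row-permutation indices for weight shuffling, keyed by weight
+        # shape; all experts share a shape, so each permutation is computed
+        # only once and then reused.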
+        self._cache_permute_indices: Dict[torch.Size, torch.Tensor] = {}
+
     def _should_use_marlin(self):
         if envs.VLLM_MXFP4_USE_MARLIN is not None:
             return envs.VLLM_MXFP4_USE_MARLIN
@@ -261,12 +262,37 @@ def create_weights(self, layer: torch.nn.Module, num_experts: int,
         )
         layer.register_parameter("w2_bias", w2_bias)
         set_weight_attrs(w2_bias, extra_weight_attrs)
+
+    def _maybe_get_cached_permute_indices(
+        self,
+        dst_w_weight: torch.Tensor,
+        epilogue_tile_m: int,
+        num_elts_per_sf: Union[None, int] = None,
+    ) -> torch.Tensor:
+        from flashinfer.utils import (get_shuffle_matrix_a_row_indices,
+                                      get_shuffle_matrix_sf_a_row_indices)
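+        # On a cache miss, derive the permutation once for this shape:
+        # plain weights use get_shuffle_matrix_a_row_indices, scale factors
+        # use the _sf_ variant (which also needs the number of elements
+        # covered by each scale factor).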
+        if dst_w_weight.shape not in self._cache_permute_indices:
+            if num_elts_per_sf is None:
+                permute_indices = get_shuffle_matrix_a_row_indices(
+                    dst_w_weight, epilogue_tile_m=epilogue_tile_m)
+            else:
+                permute_indices = get_shuffle_matrix_sf_a_row_indices(
+                    dst_w_weight,
+                    epilogue_tile_m=epilogue_tile_m,
+                    num_elts_per_sf=num_elts_per_sf)
+            self._cache_permute_indices[dst_w_weight.shape] = (
+                permute_indices.to(dst_w_weight.device))
+        return self._cache_permute_indices[dst_w_weight.shape]
 
     def process_weights_after_loading(self, layer):
         if self.use_marlin:
             prepare_moe_fp4_layer_for_marlin(layer)
         elif should_use_flashinfer_mxfp4():
-            from flashinfer import shuffle_matrix_a, shuffle_matrix_sf_a
+            from flashinfer.fp4_quantization import (
+                nvfp4_block_scale_interleave)
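+            # 1.702 is the sigmoid scaling constant from the GELU
+            # approximation x * sigmoid(1.702 * x), replicated per expert.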
             layer.gemm1_alpha = Parameter(torch.tensor(
                 [1.702] * self.num_experts, dtype=torch.float32).cuda(),
                                           requires_grad=False)
@@ -343,25 +369,74 @@ def swap_every_two_rows(x, axis=-1):
             gemm2_bias_shuffled = []
             epilogue_tile_m = 128  # FIXME: this depends on the kernel internals
             for i in range(self.num_experts):
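+                # Shuffle each tensor by indexing with the cached permutation
+                # rows instead of re-deriving the shuffle for every expert;
+                # see _maybe_get_cached_permute_indices above.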
+                # w13 weight shuffling
+                permute_indices = self._maybe_get_cached_permute_indices(
+                    w13_weight[i].view(torch.uint8),
+                    epilogue_tile_m,
+                )
                 gemm1_weights_mxfp4_shuffled.append(
-                    shuffle_matrix_a(w13_weight[i].view(torch.uint8),
-                                     epilogue_tile_m))
+                    w13_weight[i].view(torch.uint8)[permute_indices.to(
+                        w13_weight.device)].contiguous())
+                # w13 scale shuffling
+                permute_sf_indices = self._maybe_get_cached_permute_indices(
+                    w13_weight_scale[i].view(torch.uint8),
+                    epilogue_tile_m,
+                    num_elts_per_sf=16,
+                )
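+                # Scale factors additionally pass through
+                # nvfp4_block_scale_interleave to produce the swizzled block
+                # layout the kernel expects.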
                 gemm1_scales_mxfp4_shuffled.append(
-                    shuffle_matrix_sf_a(w13_weight_scale[i].view(torch.uint8),
-                                        epilogue_tile_m))
+                    nvfp4_block_scale_interleave(
+                        w13_weight_scale[i].view(torch.uint8)[
+                            permute_sf_indices.to(
+                                w13_weight_scale.device)].contiguous()))
+                # w13 bias shuffling
+                permute_bias_indices = self._maybe_get_cached_permute_indices(
+                    w13_bias[i].clone().reshape(-1, 1),
+                    epilogue_tile_m,
+                )
                 gemm1_bias_shuffled.append(
-                    shuffle_matrix_a(w13_bias[i].clone().reshape(-1, 1),
-                                     epilogue_tile_m))
-
+                    w13_bias[i].clone().reshape(-1, 1)[
+                        permute_bias_indices.to(w13_bias.device)].contiguous())
+                # w2 weight shuffling
+                permute_indices = self._maybe_get_cached_permute_indices(
+                    w2_weight[i].view(torch.uint8),
+                    epilogue_tile_m,
+                )
                 gemm2_weights_mxfp4_shuffled.append(
-                    shuffle_matrix_a(w2_weight[i].view(torch.uint8),
-                                     epilogue_tile_m))
+                    w2_weight[i].view(torch.uint8)[permute_indices.to(
+                        w2_weight.device)].contiguous())
+                # w2 scale shuffling
+                permute_sf_indices = self._maybe_get_cached_permute_indices(
+                    w2_weight_scale[i].view(torch.uint8),
+                    epilogue_tile_m,
+                    num_elts_per_sf=16,
+                )
                 gemm2_scales_mxfp4_shuffled.append(
-                    shuffle_matrix_sf_a(w2_weight_scale[i].view(torch.uint8),
-                                        epilogue_tile_m))
+                    nvfp4_block_scale_interleave(
+                        w2_weight_scale[i].view(torch.uint8)[
+                            permute_sf_indices.to(
+                                w2_weight_scale.device)].contiguous()))
+                # w2 bias shuffling
+                permute_indices = self._maybe_get_cached_permute_indices(
+                    w2_bias[i].clone().reshape(-1, 1),
+                    epilogue_tile_m,
+                )
                 gemm2_bias_shuffled.append(
-                    shuffle_matrix_a(w2_bias[i].clone().reshape(-1, 1),
-                                     epilogue_tile_m))
+                    w2_bias[i].clone().reshape(-1, 1)[
+                        permute_indices.to(w2_bias.device)].contiguous())
 
             w13_weight = torch.stack(gemm1_weights_mxfp4_shuffled)
             w13_weight_scale = torch.stack(