15 changes: 8 additions & 7 deletions vllm/model_executor/layers/fused_moe/config.py
@@ -190,12 +190,6 @@ def use_deepep_ll_kernels(self):
         return (self.use_all2all_kernels
                 and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
 
-    @property
-    def use_flashinfer_cutlass_kernels(self):
-        return (envs.VLLM_USE_FLASHINFER_MOE_FP4
-                and has_flashinfer_cutlass_fused_moe()
-                and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
-
     @staticmethod
     def make(tp_size_: int, dp_size_: int,
              vllm_parallel_config: ParallelConfig) -> "FusedMoEParallelConfig":
@@ -404,7 +398,14 @@ def use_deepep_ll_kernels(self):
 
     @property
     def use_flashinfer_cutlass_kernels(self):
-        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+        """
+        Whether to use FlashInfer cutlass kernels for NVFP4 MoE.
+        """
+        return (self.quant_config is not None
+                and self.quant_config.quant_dtype == "nvfp4"
+                and envs.VLLM_USE_FLASHINFER_MOE_FP4
+                and has_flashinfer_cutlass_fused_moe()
+                and envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput")
 
     @staticmethod
     def make(
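The rewritten property gates the FlashInfer CUTLASS path on the quantization dtype in addition to the environment flags and kernel availability. Below is a minimal, self-contained sketch of that gating logic; the parameter names are illustrative stand-ins for `self.quant_config.quant_dtype`, `envs.VLLM_USE_FLASHINFER_MOE_FP4`, `has_flashinfer_cutlass_fused_moe()`, and `envs.VLLM_FLASHINFER_MOE_BACKEND`, not vLLM APIs.

# Sketch only: mirrors the condition in the new property, with plain arguments
# standing in for the config attributes, env vars, and availability check.
from typing import Optional


def should_use_flashinfer_cutlass(
    quant_dtype: Optional[str],
    use_flashinfer_moe_fp4: bool,
    cutlass_fused_moe_available: bool,
    flashinfer_backend: str,
) -> bool:
    """Return True only when every condition for the FlashInfer CUTLASS
    NVFP4 MoE path holds; otherwise other kernels are used."""
    return (quant_dtype == "nvfp4"
            and use_flashinfer_moe_fp4
            and cutlass_fused_moe_available
            and flashinfer_backend == "throughput")


# Example: quantization is nvfp4 but the backend is set for latency,
# so the throughput-oriented CUTLASS path is skipped.
assert should_use_flashinfer_cutlass("nvfp4", True, True, "latency") is False
assert should_use_flashinfer_cutlass("nvfp4", True, True, "throughput") is True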
8 changes: 4 additions & 4 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -920,7 +920,7 @@ def __init__(
         self.batched_router_logits: Optional[torch.Tensor] = None
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels
-                or self.moe_parallel_config.use_flashinfer_cutlass_kernels):
+                or self.moe_config.use_flashinfer_cutlass_kernels):
             self.batched_hidden_states = torch.zeros(
                 (moe.max_num_tokens, self.hidden_size),
                 dtype=moe.in_dtype,
@@ -974,7 +974,7 @@ def use_deepep_ll_kernels(self):
 
     @property
     def use_flashinfer_cutlass_kernels(self):
-        return self.moe_parallel_config.use_flashinfer_cutlass_kernels
+        return self.moe_config.use_flashinfer_cutlass_kernels
 
     def update_expert_map(self):
         # ep_size and ep_rank should already be updated
@@ -1665,7 +1665,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
         # only when data parallelism (DP) is enabled.
         use_flashinfer_cutlass_kernels = (
             self.dp_size > 1
-            and self.moe_parallel_config.use_flashinfer_cutlass_kernels)
+            and self.moe_config.use_flashinfer_cutlass_kernels)
         if (self.moe_parallel_config.use_pplx_kernels
                 or self.moe_parallel_config.use_deepep_ll_kernels
                 or use_flashinfer_cutlass_kernels):
@@ -1674,7 +1674,7 @@ def forward_impl(self, hidden_states: torch.Tensor,
         do_naive_dispatch_combine: bool = (
             self.dp_size > 1
             and not self.moe_parallel_config.use_deepep_ht_kernels
-            and not self.moe_parallel_config.use_flashinfer_cutlass_kernels)
+            and not self.moe_config.use_flashinfer_cutlass_kernels)
         if do_naive_dispatch_combine:
            hidden_states, router_logits = get_ep_group().dispatch(
                hidden_states, router_logits)
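These layer.py edits only change which config object the property is read from (`moe_config` instead of `moe_parallel_config`); the dispatch logic itself is unchanged. A hedged sketch of how `forward_impl` derives its two flags from that property and `dp_size`, with plain local names standing in for the attributes on `self`:

# Illustrative sketch, not vLLM code.
from typing import Tuple


def dispatch_flags(dp_size: int,
                   config_use_flashinfer_cutlass: bool,
                   use_deepep_ht_kernels: bool) -> Tuple[bool, bool]:
    # The FlashInfer CUTLASS all-to-all path only participates when data
    # parallelism is enabled.
    use_flashinfer_cutlass = dp_size > 1 and config_use_flashinfer_cutlass
    # Naive dispatch/combine is skipped when a specialized path handles it.
    do_naive_dispatch_combine = (dp_size > 1
                                 and not use_deepep_ht_kernels
                                 and not config_use_flashinfer_cutlass)
    return use_flashinfer_cutlass, do_naive_dispatch_combine


# DP disabled: neither flag fires, even if the config property is True.
assert dispatch_flags(1, True, False) == (False, False)
# DP enabled with FlashInfer CUTLASS: it replaces the naive dispatch/combine.
assert dispatch_flags(2, True, False) == (True, False)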
6 changes: 2 additions & 4 deletions vllm/model_executor/layers/quantization/mxfp4.py
@@ -623,8 +623,6 @@ def apply(
 
         if should_use_flashinfer_mxfp4():
             from flashinfer import mxfp8_quantize, trtllm_fp4_block_scale_moe
-            assert not self.moe.use_ep, (
-                "EP is not supported for flashinfer mxfp4 moe backend yet.")
             if _should_use_flashinfer_mxfp4_bf16():
                 assert x.dtype == torch.bfloat16
                 x_quant = x
@@ -650,12 +648,12 @@ def apply(
                 None,  # output1_scale_scalar
                 None,  # output1_scale_gate_scalar
                 None,  # output2_scale_scalar
-                self.num_experts,
+                global_num_experts,
                 top_k,
                 None,  # n_group
                 None,  # topk_group
                 self.intermediate_size,  # padded to multiple of 256
-                0,  # local_expert_offset
+                layer.ep_rank * layer.local_num_experts,  # local_expert_offset
                 self.num_experts,  # local num experts
                 None,
                 self._get_tile_tokens_dim(x, top_k),
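Dropping the EP assert and passing `layer.ep_rank * layer.local_num_experts` as `local_expert_offset` (with `global_num_experts` as the total expert count) is what lets each expert-parallel rank point the kernel at its own contiguous slice of experts. A small worked example of that offset arithmetic, assuming an even, contiguous split of experts across EP ranks:

# Sketch only: the offset formula used in the diff, applied to example numbers.
def local_expert_offset(ep_rank: int, local_num_experts: int) -> int:
    """First global expert id owned by this EP rank."""
    return ep_rank * local_num_experts


# 64 experts over 4 EP ranks -> 16 local experts per rank.
global_num_experts, ep_size = 64, 4
local_num_experts = global_num_experts // ep_size
offsets = [local_expert_offset(r, local_num_experts) for r in range(ep_size)]
assert offsets == [0, 16, 32, 48]  # rank 0 owns experts 0-15, rank 1 owns 16-31, ...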