86 changes: 57 additions & 29 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -463,6 +463,11 @@ def forward_cuda(
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if enable_eplb:
assert expert_load_view is not None
assert logical_to_physical_map is not None
assert logical_replica_count is not None
assert isinstance(layer, FusedMoE)

topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
@@ -481,7 +486,8 @@ def forward_cuda(
expert_map=expert_map,
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count)
logical_replica_count=logical_replica_count,
fused_experts_method=self.fused_experts)

if self.rocm_aiter_moe_enabled:
return self.rocm_aiter_fused_experts(
@@ -508,6 +514,7 @@ def forward_cuda(
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
expert_load_view=expert_load_view,
)
else:
assert fused_experts is not None
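
For context, the EPLB state the new guard expects callers to supply looks roughly like this. A minimal sketch: the (num_physical_experts,) shape of expert_load_view follows the comment in select_experts below, while the dtypes and the single-replica layout are assumptions, not part of this diff.

import torch

num_physical_experts = 16  # hypothetical deployment size

# Per-physical-expert token counter that forward_cuda updates in place.
expert_load_view = torch.zeros(num_physical_experts, dtype=torch.int64)

# Identity placement: logical expert i lives at physical slot i, with a
# single replica each. An EPLB rearrangement would permute/replicate these.
logical_to_physical_map = torch.arange(num_physical_experts).unsqueeze(-1)
logical_replica_count = torch.ones(num_physical_experts, dtype=torch.int64)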
@@ -1444,6 +1451,7 @@ def select_experts(
expert_load_view: Optional[torch.Tensor] = None,
logical_to_physical_map: Optional[torch.Tensor] = None,
logical_replica_count: Optional[torch.Tensor] = None,
fused_experts_method: Optional[Callable] = None
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Route the input hidden states to the top-k experts based on the
@@ -1526,34 +1534,41 @@ def select_experts(

# 2. Record expert load metrics.

# TODO(bowen): When using `FusedMoEModularKernel`, this
# can be done in a more unified way, since
# `FusedMoEPrepareAndFinalize` will return the expert
# token count, in some cases directly from the kernel.
# However, now there are many code paths not using
# the modular kernel, e.g. calling `fused_experts`,
# so we decide to keep the logic here.
#
# If later refactor moved all the MoE kernel calls
# to the modular kernel, we can move this logic there
# to achieve better efficiency.

# `expert_load_view`: (num_physical_experts,)

topk_ids_flatten = topk_ids.flatten()

# Performance optimization:
# `masked_fill` is significantly faster than `masked_select`
invalid_mask = topk_ids_flatten < 0
# Replace invalid expert ids with 0 (just a dummy position)
# to avoid out-of-bounds errors in scatter_add_
index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
# `src` is the valid mask, which is 1 for valid and 0 for invalid
src = ~invalid_mask

expert_load_view.scatter_add_(dim=0,
index=index.long(),
src=src.to(expert_load_view))
# When using FusedMoEModularKernel, expert load statistics are
# updated directly in the kernel from
# ExpertTokensMetadata.expert_num_tokens, which is faster. For
# other implementations, or when that metadata is unavailable,
# we fall back to accumulating the load here.

# DeepEPHTPrepareAndFinalize does not populate expert_num_tokens
# in its expert_tokens_meta, so it is not supported here for now
# and still takes the fallback path below.
# TODO: consider supporting DeepEPHTPrepareAndFinalize as well.
skip_expert_load_scatter_add = (
    isinstance(fused_experts_method, FusedMoEModularKernel)
    and (fused_experts_method.prepare_finalize.__class__.__name__
         != "DeepEPHTPrepareAndFinalize"))

if not skip_expert_load_scatter_add:
logger.debug("expert_load_view update from topk_ids.")
topk_ids_flatten = topk_ids.flatten()

# Performance optimization:
# `masked_fill` is significantly faster than `masked_select`
invalid_mask = topk_ids_flatten < 0
# Replace invalid expert ids with 0 (just a dummy position)
# to avoid out-of-bounds errors in scatter_add_
index = topk_ids_flatten.masked_fill_(invalid_mask, 0)
# `src` is the valid mask,
# which is 1 for valid and 0 for invalid
src = ~invalid_mask

expert_load_view.scatter_add_(dim=0,
index=index.long(),
src=src.to(expert_load_view))
else:
logger.debug("expert_load_view update in modular_kernel.")

topk_ids = topk_ids.to(dtype=indices_type)
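
Pulled out of the layer, the fallback accounting amounts to this self-contained sketch (toy sizes; masked_fill is used out-of-place here so the caller's topk_ids are untouched):

import torch

def accumulate_expert_load(expert_load_view: torch.Tensor,
                           topk_ids: torch.Tensor) -> None:
    flat = topk_ids.flatten()
    # Negative ids mean "no expert"; redirect them to slot 0 so the
    # scatter index stays in bounds, and add 0 there via the mask.
    invalid = flat < 0
    index = flat.masked_fill(invalid, 0)
    src = (~invalid).to(expert_load_view.dtype)
    expert_load_view.scatter_add_(dim=0, index=index.long(), src=src)

load = torch.zeros(8, dtype=torch.int64)
accumulate_expert_load(load, torch.tensor([[0, 3], [3, -1]]))
# load is now [1, 0, 0, 2, 0, 0, 0, 0]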

@@ -1856,6 +1871,19 @@ def extra_repr(self) -> str:

return s

def update_map(self, new_expert_map):
self.expert_map = new_expert_map

def get_map(self):
return self.expert_map

def get_log2phy_map(self):
return self.logical_to_physical_map

def clear_expert_load_view(self):
if self.expert_load_view is not None:
self.expert_load_view.zero_()
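
These accessors give an external EPLB controller what it needs for a rebalance step. A hedged usage sketch, where rebalance() stands in for any placement algorithm and layer is a FusedMoE instance; neither name is defined by this diff:

def rebalance_layer(layer, rebalance) -> None:
    # Read the current placement.
    old_map = layer.get_map()            # physical expert map
    log2phy = layer.get_log2phy_map()    # logical -> physical ids
    # Install the new placement, then reset the load statistics so the
    # next accounting window reflects the new arrangement.
    layer.update_map(rebalance(old_map, log2phy))
    layer.clear_expert_load_view()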


def moe_forward(
hidden_states: torch.Tensor,
33 changes: 33 additions & 0 deletions vllm/model_executor/layers/fused_moe/modular_kernel.py
@@ -520,6 +520,9 @@ def __init__(
self.prepare_finalize = prepare_finalize
self.fused_experts = fused_experts
self.shared_experts = shared_experts
# EPLB state: cached expert map and the corresponding
# local -> global physical expert indices.
self.local_to_global_physical_experts = None
self.expert_map = None
assert prepare_finalize.activation_format == \
fused_experts.activation_formats[0], (
f"{prepare_finalize.__class__.__name__}."
@@ -755,6 +758,7 @@ def forward(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
expert_load_view: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
"""
This function computes a Mixture of Experts (MoE) layer using two sets
@@ -787,6 +791,10 @@ def forward(
- apply_router_weight_on_input (bool): When true, the topk weights are
applied directly on the inputs. This is only applicable when topk is
1.
- expert_load_view (Optional[torch.Tensor]): Optional tensor for
tracking expert load statistics. If provided, the kernel will
update it using ExpertTokensMetadata.expert_num_tokens for
better performance.

Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
@@ -840,6 +848,31 @@ def forward(
(a1q, a1q_scale, expert_tokens_meta, _expert_topk_ids,
_expert_topk_weights) = receiver()

# In EPLB, update expert load from expert_num_tokens.
if (expert_tokens_meta is not None and expert_load_view is not None
and expert_tokens_meta.expert_num_tokens is not None
and expert_map is not None):
# Cache the mapping of local physical experts to global
# physical experts; recompute it only when expert_map
# changes (e.g. after an EPLB rearrangement).
# `expert_load_view`: (num_physical_experts,)
# `expert_num_tokens`: (local_num_physical_experts,)
if (self.expert_map is None
        or not torch.equal(self.expert_map, expert_map)):
    self.expert_map = expert_map.clone()
    self.local_to_global_physical_experts = torch.nonzero(
        expert_map != -1, as_tuple=False).squeeze(-1)
# Use the pre-computed expert token counts from the kernel
# metadata; cast so dtypes match for scatter_add_.
expert_load_view.scatter_add_(
    dim=0,
    index=self.local_to_global_physical_experts,
    src=expert_tokens_meta.expert_num_tokens.to(
        expert_load_view.dtype))
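
In isolation, the index bookkeeping above works like this (a toy 4-expert placement where the current rank owns experts 1 and 2):

import torch

# Global physical expert id -> local slot, -1 where not local.
expert_map = torch.tensor([-1, 0, 1, -1])
local_to_global = torch.nonzero(expert_map != -1,
                                as_tuple=False).squeeze(-1)
# local_to_global == tensor([1, 2])

expert_load_view = torch.zeros(4, dtype=torch.int64)
expert_num_tokens = torch.tensor([5, 7])  # tokens per local expert
expert_load_view.scatter_add_(dim=0, index=local_to_global,
                              src=expert_num_tokens)
# expert_load_view is now [0, 5, 7, 0]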

# Maybe prepare gathered topk_ids and topk_weights from other EP ranks.
topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids
topk_weights = (topk_weights if _expert_topk_weights is None else
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/quantization/fp8.py
@@ -1055,7 +1055,7 @@ def apply(
expert_load_view=expert_load_view,
logical_to_physical_map=logical_to_physical_map,
logical_replica_count=logical_replica_count,
)
fused_experts_method=self.fused_experts)

if self.rocm_aiter_moe_enabled:
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
@@ -1114,6 +1114,7 @@ def apply(
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
expert_load_view=expert_load_view,
)
else:
return flashinfer_cutlass_moe_fp8(