
Commit ef5dd7b

new change
1 parent 868343e commit ef5dd7b

2 files changed: +22, -37 lines

vllm/model_executor/layers/fused_moe/layer.py

Lines changed: 21 additions & 20 deletions
@@ -449,25 +449,11 @@ def forward_cuda(
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        skip_expert_load_scatter_add = False
         if enable_eplb:
             assert expert_load_view is not None
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
             assert isinstance(layer, FusedMoE)
-            # if `skip_expert_load_scatter_add` is True,
-            # update `expert_load_view` in modular_kernel,
-            # skipping scatter_add_ in FusedMoE.select_experts.
-            if (self.fused_experts is not None and
-                    isinstance(self.fused_experts, FusedMoEModularKernel)):
-
-                # There is no `expert_num_tokens` in
-                # `expert_tokens_meta` of DeepEPHTPrepareAndFinalize,
-                # so DeepEPHTPrepareAndFinalize is not supported for now.
-                # TODO: Maybe it is better to support DeepEPHTPrepareAndFinalize.
-                if not isinstance(self.fused_experts.prepare_finalize,
-                                  DeepEPHTPrepareAndFinalize):
-                    skip_expert_load_scatter_add = True
 
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -486,7 +472,7 @@ def forward_cuda(
             expert_load_view=expert_load_view,
             logical_to_physical_map=logical_to_physical_map,
             logical_replica_count=logical_replica_count,
-            skip_expert_load_scatter_add=skip_expert_load_scatter_add
+            fused_experts_method=self.fused_experts
         )
 
         if self.rocm_aiter_moe_enabled:
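
This call-site change is the heart of the refactor: instead of computing skip_expert_load_scatter_add locally, forward_cuda now forwards its fused-experts object and lets select_experts make the decision. A minimal runnable sketch of that inversion, with stand-in classes; only the names select_experts and fused_experts_method mirror the real API, and the loose Optional[Callable] hint matches the diff:

    from typing import Callable, Optional

    class FusedMoEModularKernel:  # stand-in for the real modular kernel class
        pass

    def select_experts(fused_experts_method: Optional[Callable] = None) -> bool:
        # The callee now owns the policy: skip the scatter_add_ fallback only
        # when a modular kernel will record expert load itself.
        return isinstance(fused_experts_method, FusedMoEModularKernel)

    # Before this commit every call site computed the flag; now each one just
    # passes its kernel object (or None).
    assert select_experts(fused_experts_method=FusedMoEModularKernel())
    assert not select_experts(fused_experts_method=None)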
@@ -1408,7 +1394,7 @@ def select_experts(
         expert_load_view: Optional[torch.Tensor] = None,
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
-        skip_expert_load_scatter_add: bool = False,
+        fused_experts_method: Optional[Callable] = None
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Route the input hidden states to the top-k experts based on the
@@ -1489,12 +1475,25 @@ def select_experts(
         topk_ids = physical_ids
 
         # 2. Record expert load metrics
-        # Note: When using FusedMoEModularKernel, expert load statistics are handled
-        # directly in the kernel using ExpertTokensMetadata.expert_num_tokens for better performance.
-        # For other implementations or when metadata is not available, we fall back to scatter_add_.
+        # When using FusedMoEModularKernel,
+        # expert load statistics are handled directly in the kernel using
+        # ExpertTokensMetadata.expert_num_tokens for better performance.
+        # For other implementations or when metadata is not available,
+        # we fall back to scatter_add_.
 
-        # Check if we're using FusedMoEModularKernel and if it has already processed the load
+        # Check if we're using FusedMoEModularKernel and
+        # if it has already processed the load.
         # If not, use the traditional scatter_add_ approach.
+
+        # There is no expert_num_tokens in
+        # expert_tokens_meta of DeepEPHTPrepareAndFinalize,
+        # so DeepEPHTPrepareAndFinalize is not supported for now.
+        # TODO: Maybe it is better to support DeepEPHTPrepareAndFinalize.
+        skip_expert_load_scatter_add = (
+            isinstance(fused_experts_method, FusedMoEModularKernel) and
+            fused_experts_method.prepare_finalize.__class__.__name__ !=
+            "DeepEPHTPrepareAndFinalize")
+
         if not skip_expert_load_scatter_add:
             logger.debug("expert_load_view update from topk_ids through scatter_add_.")
             # Fallback to scatter_add_ for non-modular kernel implementations
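
A standalone sketch of the new skip decision under stand-in classes defined here (OtherPrepareAndFinalize is hypothetical); in the commit the check lives inline in select_experts. Matching on __class__.__name__ against a string, rather than on the imported class, spares layer.py a DeepEPHTPrepareAndFinalize import:

    class FusedMoEModularKernel:  # stand-in
        def __init__(self, prepare_finalize):
            self.prepare_finalize = prepare_finalize

    class DeepEPHTPrepareAndFinalize:  # stand-in: lacks expert_num_tokens
        pass

    class OtherPrepareAndFinalize:  # stand-in: provides expert_num_tokens
        pass

    def should_skip_scatter_add(fused_experts_method) -> bool:
        # isinstance(None, ...) is False, so no separate None check is needed.
        return (isinstance(fused_experts_method, FusedMoEModularKernel) and
                fused_experts_method.prepare_finalize.__class__.__name__ !=
                "DeepEPHTPrepareAndFinalize")

    assert should_skip_scatter_add(
        FusedMoEModularKernel(OtherPrepareAndFinalize()))
    assert not should_skip_scatter_add(
        FusedMoEModularKernel(DeepEPHTPrepareAndFinalize()))
    assert not should_skip_scatter_add(None)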
@@ -1512,6 +1511,8 @@
             expert_load_view.scatter_add_(dim=0,
                                           index=index.long(),
                                           src=src.to(expert_load_view))
+        else:
+            logger.debug("expert_load_view update in modular_kernel through add_.")
 
         topk_ids = topk_ids.to(dtype=indices_type)
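
The two debug branches correspond to two ways of arriving at the same per-expert counts. A runnable sketch with illustrative shapes: the fallback counts routed tokens from topk_ids via scatter_add_, while the modular-kernel path applies already-computed per-expert counts (what ExpertTokensMetadata.expert_num_tokens provides) with a single add_:

    import torch

    num_experts = 4
    # topk_ids: (num_tokens, top_k) physical expert ids chosen by routing.
    topk_ids = torch.tensor([[0, 2], [1, 2], [3, 0]])

    # Fallback path: one scatter_add_ over the flattened routing ids.
    expert_load_view = torch.zeros(num_experts, dtype=torch.int64)
    index = topk_ids.flatten()
    src = torch.ones_like(index)
    expert_load_view.scatter_add_(dim=0, index=index.long(),
                                  src=src.to(expert_load_view))

    # Modular-kernel path: per-expert counts already exist after prepare(),
    # so a plain add_ suffices; bincount stands in for them here.
    expert_num_tokens = torch.bincount(index, minlength=num_experts)
    modular_view = torch.zeros(num_experts, dtype=torch.int64)
    modular_view.add_(expert_num_tokens)

    assert torch.equal(expert_load_view, modular_view)  # tensor([2, 1, 2, 1])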

vllm/model_executor/layers/quantization/fp8.py

Lines changed: 1 addition & 17 deletions
@@ -926,27 +926,11 @@ def apply(
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        skip_expert_load_scatter_add = False
         if enable_eplb:
             assert expert_load_view is not None
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
             assert isinstance(layer, FusedMoE)
-            # if `skip_expert_load_scatter_add` is True,
-            # update `expert_load_view` in modular_kernel,
-            # skipping scatter_add_ in FusedMoE.select_experts.
-            if (self.fused_experts is not None and
-                    isinstance(self.fused_experts, FusedMoEModularKernel)):
-
-                # There is no `expert_num_tokens` in
-                # `expert_tokens_meta` of DeepEPHTPrepareAndFinalize,
-                # so DeepEPHTPrepareAndFinalize is not supported for now.
-                # TODO: Maybe it is better to support DeepEPHTPrepareAndFinalize.
-                from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize \
-                    import DeepEPHTPrepareAndFinalize
-                if not isinstance(self.fused_experts.prepare_finalize,
-                                  DeepEPHTPrepareAndFinalize):
-                    skip_expert_load_scatter_add = True
 
         if not self.flashinfer_moe_enabled:
             topk_weights, topk_ids = FusedMoE.select_experts(
@@ -966,7 +950,7 @@ def apply(
                 expert_load_view=expert_load_view,
                 logical_to_physical_map=logical_to_physical_map,
                 logical_replica_count=logical_replica_count,
-                skip_expert_load_scatter_add=skip_expert_load_scatter_add
+                fused_experts_method=self.fused_experts
             )
 
         if self.rocm_aiter_moe_enabled:
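
The deleted fp8.py block shows why moving the check paid off twice: apply() had to import DeepEPHTPrepareAndFinalize inside the method body, presumably to avoid an import cycle at module load time, and that deferred import goes away with the check. A tiny sketch of the deferred-import pattern, using a stdlib class purely so it runs:

    def is_ordered(obj) -> bool:
        # Deferred import: resolved on first call, not at module load time,
        # mirroring how fp8.py avoided a module-level dependency before
        # this commit.
        from collections import OrderedDict
        return isinstance(obj, OrderedDict)

    assert not is_ordered({})  # a plain dict is not an OrderedDict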
