@@ -449,25 +449,11 @@ def forward_cuda(
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        skip_expert_load_scatter_add = False
         if enable_eplb:
             assert expert_load_view is not None
             assert logical_to_physical_map is not None
             assert logical_replica_count is not None
             assert isinstance(layer, FusedMoE)
-            # If `skip_expert_load_scatter_add` is True,
-            # `expert_load_view` is updated in the modular kernel,
-            # skipping scatter_add_ in FusedMoE.select_experts.
-            if (self.fused_experts is not None and
-                    isinstance(self.fused_experts, FusedMoEModularKernel)):
-
-                # There is no `expert_num_tokens` in the
-                # `expert_tokens_meta` of DeepEPHTPrepareAndFinalize,
-                # so DeepEPHTPrepareAndFinalize is not supported for now.
-                # TODO: Maybe it is better to support DeepEPHTPrepareAndFinalize.
-                if not isinstance(self.fused_experts.prepare_finalize,
-                                  DeepEPHTPrepareAndFinalize):
-                    skip_expert_load_scatter_add = True

         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=x,
@@ -486,7 +472,7 @@ def forward_cuda(
             expert_load_view=expert_load_view,
             logical_to_physical_map=logical_to_physical_map,
             logical_replica_count=logical_replica_count,
-            skip_expert_load_scatter_add=skip_expert_load_scatter_add
+            fused_experts_method=self.fused_experts
         )

         if self.rocm_aiter_moe_enabled:
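
The flag deleted above does not disappear: the same capability check moves into `FusedMoE.select_experts` (the next two hunks), so `forward_cuda` now passes the kernel object instead of a precomputed boolean. A minimal sketch of that pattern with stand-in classes (these are not vLLM's APIs; only the probe logic mirrors the diff):

```python
# Stand-ins for FusedMoEModularKernel / DeepEPHTPrepareAndFinalize, used only
# to illustrate the capability probe that moves into select_experts.
class HTPrepareFinalize:            # no expert_num_tokens metadata available
    pass

class ModularKernel:
    def __init__(self, prepare_finalize):
        self.prepare_finalize = prepare_finalize

def kernel_counts_loads(kernel) -> bool:
    """True when the kernel can update expert_load_view itself."""
    return (isinstance(kernel, ModularKernel)
            and not isinstance(kernel.prepare_finalize, HTPrepareFinalize))

assert kernel_counts_loads(ModularKernel(object()))             # generic PF
assert not kernel_counts_loads(ModularKernel(HTPrepareFinalize()))
assert not kernel_counts_loads(None)                            # no kernel
```

Passing the object keeps the call site stable if the skip rule changes again; only the callee needs updating.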
@@ -1408,7 +1394,7 @@ def select_experts(
         expert_load_view: Optional[torch.Tensor] = None,
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None,
-        skip_expert_load_scatter_add: bool = False,
+        fused_experts_method: Optional[Callable] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Route the input hidden states to the top-k experts based on the
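
`fused_experts_method` is annotated `Optional[Callable]` because the attribute handed in from `forward_cuda` can be `None`, a plain function, or a `FusedMoEModularKernel` instance (whose instances are callable). Note that `Callable` must be imported (e.g. `from typing import Callable`) if this module does not import it already. A sketch of the three shapes, with a hypothetical stand-in kernel class:

```python
from typing import Callable, Optional

class ModularKernelStandIn:               # instances are callable, like nn.Module
    def __call__(self, x: int) -> int:
        return x

def plain_fused_experts(x: int) -> int:   # plain-function shape
    return x

def select_experts_sketch(
        fused_experts_method: Optional[Callable] = None) -> str:
    # select_experts only inspects the object (isinstance); it never calls
    # it here, so Optional["FusedMoEModularKernel"] would be even tighter.
    if isinstance(fused_experts_method, ModularKernelStandIn):
        return "kernel counts loads"
    return "fall back to scatter_add_"

print(select_experts_sketch(None))                    # fall back
print(select_experts_sketch(plain_fused_experts))     # fall back
print(select_experts_sketch(ModularKernelStandIn()))  # kernel counts loads
```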
@@ -1489,12 +1475,25 @@ def select_experts(
             topk_ids = physical_ids

         # 2. Record expert load metrics
-        # Note: When using FusedMoEModularKernel, expert load statistics are handled
-        # directly in the kernel using ExpertTokensMetadata.expert_num_tokens for better performance.
-        # For other implementations or when metadata is not available, we fall back to scatter_add_.
+        # When using FusedMoEModularKernel,
+        # expert load statistics are handled directly in the kernel using
+        # ExpertTokensMetadata.expert_num_tokens for better performance.
+        # For other implementations, or when the metadata is not available,
+        # we fall back to scatter_add_.

-        # Check if we're using FusedMoEModularKernel and if it has already processed the load
+        # Check if we're using FusedMoEModularKernel and
+        # if it has already processed the load.
         # If not, use the traditional scatter_add_ approach.
+
+        # There is no `expert_num_tokens` in the `expert_tokens_meta` of
+        # DeepEPHTPrepareAndFinalize, so DeepEPHTPrepareAndFinalize
+        # is not supported for now.
+        # TODO: Maybe it is better to support DeepEPHTPrepareAndFinalize.
+        skip_expert_load_scatter_add = ((fused_experts_method is not None) and
+            isinstance(fused_experts_method, FusedMoEModularKernel) and
+            (fused_experts_method.prepare_finalize.__class__.__name__ !=
+             "DeepEPHTPrepareAndFinalize"))
+
         if not skip_expert_load_scatter_add:
             logger.debug("expert_load_view update from topk_ids through scatter_add_.")
             # Fallback to scatter_add_ for non-modular kernel implementations
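
One subtlety in the skip condition: the class must be compared by *name* (or via `isinstance`), never directly to a string. A class object never equals a `str`, so `prepare_finalize.__class__ != "DeepEPHTPrepareAndFinalize"` would always be `True`, and scatter_add_ would wrongly be skipped for the DeepEP high-throughput path as well. A quick demonstration with a locally defined stand-in class of the same name:

```python
class DeepEPHTPrepareAndFinalize:      # local stand-in, same name as vLLM's
    pass

pf = DeepEPHTPrepareAndFinalize()
print(pf.__class__ != "DeepEPHTPrepareAndFinalize")           # True: always unequal (bug)
print(pf.__class__.__name__ != "DeepEPHTPrepareAndFinalize")  # False: correct
print(not isinstance(pf, DeepEPHTPrepareAndFinalize))         # False: also correct,
                                                              # but needs the class imported
```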
@@ -1512,6 +1511,8 @@ def select_experts(
             expert_load_view.scatter_add_(dim=0,
                                           index=index.long(),
                                           src=src.to(expert_load_view))
+        else:
+            logger.debug("expert_load_view update in modular_kernel through add_.")

         topk_ids = topk_ids.to(dtype=indices_type)
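
For reference, the fallback path counts how many token slots each expert received by scattering ones into `expert_load_view` at the selected ids. A self-contained sketch of that bookkeeping (shapes assumed from the surrounding code: `topk_ids` is `[num_tokens, top_k]`, `expert_load_view` is `[num_physical_experts]`):

```python
import torch

num_experts, top_k = 4, 2
expert_load_view = torch.zeros(num_experts, dtype=torch.int64)

# three tokens, each routed to top_k experts
topk_ids = torch.tensor([[0, 2], [2, 3], [2, 0]])

index = topk_ids.flatten()            # one entry per (token, slot) pair
src = torch.ones_like(index)          # each pair contributes one unit of load
expert_load_view.scatter_add_(dim=0,
                              index=index.long(),
                              src=src.to(expert_load_view))
print(expert_load_view)  # tensor([2, 0, 3, 1]) -> expert 2 served 3 slots
```

On the modular-kernel path the same totals come from `ExpertTokensMetadata.expert_num_tokens` via `add_`, as the new debug message above notes.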