@@ -81,7 +81,9 @@ def apply(
81
81
use_grouped_topk : bool ,
82
82
topk_group : Optional [int ] = None ,
83
83
num_expert_group : Optional [int ] = None ,
84
- custom_routing_function : Optional [Callable ] = None
84
+ custom_routing_function : Optional [Callable ] = None ,
85
+ scoring_func : str = "softmax" ,
86
+ e_score_correction_bias : Optional [torch .Tensor ] = None
85
87
) -> torch .Tensor :
86
88
return self .forward (x = x ,
87
89
layer = layer ,
@@ -91,7 +93,9 @@ def apply(
91
93
use_grouped_topk = use_grouped_topk ,
92
94
topk_group = topk_group ,
93
95
num_expert_group = num_expert_group ,
94
- custom_routing_function = custom_routing_function )
96
+ custom_routing_function = custom_routing_function ,
97
+ scoring_func = scoring_func ,
98
+ e_score_correction_bias = e_score_correction_bias )
95
99
96
100
def forward_cuda (
97
101
self ,
@@ -103,7 +107,9 @@ def forward_cuda(
103
107
renormalize : bool ,
104
108
topk_group : Optional [int ] = None ,
105
109
num_expert_group : Optional [int ] = None ,
106
- custom_routing_function : Optional [Callable ] = None
110
+ custom_routing_function : Optional [Callable ] = None ,
111
+ scoring_func : str = "softmax" ,
112
+ e_score_correction_bias : Optional [torch .Tensor ] = None
107
113
) -> torch .Tensor :
108
114
topk_weights , topk_ids = FusedMoE .select_experts (
109
115
hidden_states = x ,
@@ -113,7 +119,9 @@ def forward_cuda(
113
119
renormalize = renormalize ,
114
120
topk_group = topk_group ,
115
121
num_expert_group = num_expert_group ,
116
- custom_routing_function = custom_routing_function )
122
+ custom_routing_function = custom_routing_function ,
123
+ scoring_func = scoring_func ,
124
+ e_score_correction_bias = e_score_correction_bias )
117
125
118
126
return fused_experts (hidden_states = x ,
119
127
w1 = layer .w13_weight ,
@@ -136,7 +144,8 @@ def forward_tpu(
136
144
renormalize : bool ,
137
145
topk_group : Optional [int ] = None ,
138
146
num_expert_group : Optional [int ] = None ,
139
- custom_routing_function : Optional [Callable ] = None
147
+ custom_routing_function : Optional [Callable ] = None ,
148
+ ** kwargs ,
140
149
) -> torch .Tensor :
141
150
assert not use_grouped_topk
142
151
assert num_expert_group is None
@@ -190,6 +199,7 @@ def __init__(
190
199
prefix : str = "" ,
191
200
custom_routing_function : Optional [Callable ] = None ,
192
201
scoring_func : str = "softmax" ,
202
+ e_score_correction_bias : Optional [torch .Tensor ] = None ,
193
203
):
194
204
super ().__init__ ()
195
205
@@ -210,9 +220,12 @@ def __init__(
210
220
self .topk_group = topk_group
211
221
self .custom_routing_function = custom_routing_function
212
222
self .scoring_func = scoring_func
223
+ self .e_score_correction_bias = e_score_correction_bias
213
224
214
225
if self .scoring_func != "softmax" and not self .use_grouped_topk :
215
- raise ValueError ("Only softmax scoring function is supported for non-grouped topk." )
226
+ raise ValueError (
227
+ "Only softmax scoring function is supported for non-grouped topk."
228
+ )
216
229
217
230
if quant_config is None :
218
231
self .quant_method : Optional [QuantizeMethodBase ] = (
@@ -447,7 +460,8 @@ def select_experts(hidden_states: torch.Tensor,
447
460
topk_group : Optional [int ] = None ,
448
461
num_expert_group : Optional [int ] = None ,
449
462
custom_routing_function : Optional [Callable ] = None ,
450
- scoring_func : str = "softmax" ):
463
+ scoring_func : str = "softmax" ,
464
+ e_score_correction_bias : Optional [torch .Tensor ] = None ):
451
465
from vllm .model_executor .layers .fused_moe .fused_moe import (
452
466
fused_topk , grouped_topk )
453
467
@@ -462,7 +476,8 @@ def select_experts(hidden_states: torch.Tensor,
462
476
renormalize = renormalize ,
463
477
num_expert_group = num_expert_group ,
464
478
topk_group = topk_group ,
465
- scoring_func = scoring_func )
479
+ scoring_func = scoring_func ,
480
+ e_score_correction_bias = e_score_correction_bias )
466
481
elif custom_routing_function is None :
467
482
topk_weights , topk_ids = fused_topk (hidden_states = hidden_states ,
468
483
gating_output = router_logits ,
@@ -491,7 +506,9 @@ def forward(self, hidden_states: torch.Tensor,
491
506
use_grouped_topk = self .use_grouped_topk ,
492
507
topk_group = self .topk_group ,
493
508
num_expert_group = self .num_expert_group ,
494
- custom_routing_function = self .custom_routing_function )
509
+ custom_routing_function = self .custom_routing_function ,
510
+ scoring_func = self .scoring_func ,
511
+ e_score_correction_bias = self .e_score_correction_bias )
495
512
496
513
if self .reduce_results and self .tp_size > 1 :
497
514
final_hidden_states = tensor_model_parallel_all_reduce (
0 commit comments