
Commit ad6b898

nvpohanh authored and Cyrilvallez committed
Fix broken Llama4 accuracy in MoE part (#40609)
* Fix broken Llama4 accuracy in MoE part

Llama4 accuracy was broken by a bug in #39501, which omitted the transpose of router_scores before applying it to routed_in, causing Llama4 to generate garbage output. This PR fixes the issue by adding back the transpose() and adding comments explaining why the transpose() is needed.

Signed-off-by: Po-Han Huang <[email protected]>

* remove comment

---------

Signed-off-by: Po-Han Huang <[email protected]>
Co-authored-by: Cyril Vallez <[email protected]>
1 parent e7d351c · commit ad6b898

1 file changed, 1 insertion(+), 1 deletion(-)


src/transformers/models/llama4/modeling_llama4.py

@@ -158,7 +158,7 @@ def forward(self, hidden_states):
         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
         router_scores, router_logits = self.router(hidden_states)
         routed_in = hidden_states.repeat(router_scores.shape[1], 1)
-        routed_in = routed_in * router_scores.reshape(-1, 1)
+        routed_in = routed_in * router_scores.transpose(0, 1).reshape(-1, 1)
         routed_out = self.experts(routed_in)
         out = self.shared_expert(hidden_states)
         out.add_(routed_out.reshape(router_scores.shape[1], -1, routed_out.shape[-1]).sum(dim=0))
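For context, here is a minimal standalone sketch (toy shapes and variable names, not the actual Llama4 module) of why the transpose is needed: repeat() tiles the whole token block once per expert, so routed_in is laid out expert-major (all tokens for expert 0 first, then expert 1, and so on), whereas reshape(-1, 1) on the (tokens, experts) score matrix flattens row-major, i.e. token-major. Transposing to (experts, tokens) before flattening makes the score order match the activation layout.

import torch

# Toy sizes, for illustration only.
num_tokens, num_experts, hidden_dim = 4, 3, 8
hidden_states = torch.randn(num_tokens, hidden_dim)
router_scores = torch.rand(num_tokens, num_experts)   # (tokens, experts)

# repeat() tiles the whole token block once per expert, so routed_in
# is expert-major: rows [0, num_tokens) belong to expert 0, etc.
routed_in = hidden_states.repeat(num_experts, 1)      # (experts * tokens, hidden)

# Buggy order: row-major flattening of (tokens, experts) is token-major
# (token 0's scores for every expert come first), misaligned with routed_in.
buggy_scale = router_scores.reshape(-1, 1)

# Fixed order: transpose to (experts, tokens) first, then flatten,
# giving expert-major order that matches routed_in.
fixed_scale = router_scores.transpose(0, 1).reshape(-1, 1)

# Reference: scale each expert's copy of the tokens by that expert's column.
reference = torch.cat(
    [hidden_states * router_scores[:, e:e + 1] for e in range(num_experts)]
)

assert torch.allclose(routed_in * fixed_scale, reference)
assert not torch.allclose(routed_in * buggy_scale, reference)

The final routed_out.reshape(router_scores.shape[1], -1, ...).sum(dim=0) in the patched code relies on this same expert-major layout to sum each token's per-expert outputs, which is why the flattened score order must agree with it.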
