Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix aux loss
  • Loading branch information
rakkit committed Dec 10, 2025
commit 2830bcb073e0ec6b86b5244c63e1ad031b26a4c4
13 changes: 11 additions & 2 deletions torchtitan/models/moe/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,13 @@ def forward(
_, selected_experts_indices = torch.topk(
scores + expert_bias, k=self.top_k, dim=1
)
expert_indices_for_load_balance = torch.topk(scores, k=self.top_k, dim=1)[1]
top_scores = scores.gather(dim=1, index=selected_experts_indices)
else:
top_scores, selected_experts_indices = torch.topk(
scores, k=self.top_k, dim=1
)
expert_indices_for_load_balance = torch.topk(scores, k=self.top_k, dim=1)[1]

# debug override: balanced round-robin routing
if self._debug_force_load_balance:
Expand All @@ -288,7 +290,13 @@ def forward(
max=self.num_experts,
)

return top_scores, scores, selected_experts_indices, num_tokens_per_expert
return (
top_scores,
scores,
selected_experts_indices,
num_tokens_per_expert,
expert_indices_for_load_balance,
)

def init_weights(self, init_std: float):
    """Initialize the router's gate projection weights.

    Draws the gate weight matrix from a truncated normal distribution
    centered at zero.

    Args:
        init_std: Standard deviation of the truncated normal initializer.
    """
    gate_weight = self.gate.weight
    nn.init.trunc_normal_(gate_weight, mean=0.0, std=init_std)
Expand Down Expand Up @@ -425,6 +433,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
scores,
selected_experts_indices,
num_tokens_per_expert,
expert_indices_for_load_balance,
) = self.router(x, self.expert_bias)

# tokens_per_expert will be used to update the expert bias for load balancing.
Expand All @@ -439,7 +448,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.load_balance_loss_type == "sequence_wise":
load_balance_loss = MoE.sequence_wise_aux_loss(
scores,
selected_experts_indices.long(),
expert_indices_for_load_balance.long(),
bs,
slen,
self.top_k,
Expand Down