
Conversation

@rakkit (Contributor) commented on Nov 19, 2025:

This is a draft PR for:

  1. Making the MoE's load_balance_coeff configurable.
  2. Adding batch-wise and sequence-wise aux losses for load balancing [ref: DeepSeek-V3, eqs. 17-20]; a sketch of the sequence-wise variant is included below.

For now, it only applies to the DeepSeek model, but I can add it to all other MoE models at the end.
(Also, we don't log the aux loss yet, but I can add an optimizer hook to do this if you want.)

The main concern is that the aux loss does not work well with PP. From what I have tested, it only works with 1F1B; it is broken for ZBV and interleaved 1F1B.
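For reference, here is a minimal sketch of the sequence-wise balance loss from DeepSeek-V3 eqs. 17-20. This is not the code in this PR; the tensor shapes, argument names, and batch-first layout are assumptions.

```python
import torch
import torch.nn.functional as F

def sequence_wise_aux_loss(
    scores: torch.Tensor,   # [bs, seq_len, n_experts], raw sigmoid gate scores
    indices: torch.Tensor,  # [bs, seq_len, top_k], experts chosen by unbiased topk(scores)
    n_experts: int,
    top_k: int,
    alpha: float,
) -> torch.Tensor:
    bs, seq_len, _ = scores.shape
    # Eq. 20: normalize the affinities per token so they sum to 1 over experts.
    norm_scores = scores / scores.sum(dim=-1, keepdim=True)
    # Eq. 18: f_i = N_r / (K_r * T) * (# tokens in the sequence routed to expert i).
    one_hot = F.one_hot(indices, n_experts).sum(dim=2)                # [bs, seq_len, n_experts]
    f_i = one_hot.float().sum(dim=1) * n_experts / (top_k * seq_len)  # [bs, n_experts]
    # Eq. 19: P_i = mean normalized affinity of expert i over the sequence.
    p_i = norm_scores.mean(dim=1)                                     # [bs, n_experts]
    # Eq. 17: alpha * sum_i f_i * P_i, averaged over the batch.
    return alpha * (f_i * p_i).sum(dim=-1).mean()
```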

To test it:
[sequence_wise, by default]
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --training.extra_losses.load_balance_loss_weight=0.001
[screenshot of the training log]

[batch_wise; needs to be selected in ModelArgs]
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh --training.extra_losses.load_balance_loss_weight=0.001
[screenshot of the training log]

[aux loss turned off]
CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" NGPU=4 ./run_train.sh
[screenshot of the training log]

The meta-cla bot added the CLA Signed label on Nov 19, 2025.
job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
)

self.loss_fn = functools.partial(
@rakkit (Contributor Author):

We can add a condition here to decide whether or not to wrap the loss for MoE. For now, all models in torchtitan return a single output, so this is fine as is.

Contributor:

If we subsume this MoE loss wrapper into build_loss_fn, we can avoid adding the logic here.

@wwwjn (Contributor) left a comment:

Thank you! @shuhuayu is working on a more formal review, and I have some housekeeping comments:



@dataclass
class ExtraLosses:
Contributor:

This section is specifically for the MoE load-balancing loss for now. Do you foresee any other loss-related params being used in this section? If not, let's make the name more descriptive and specific.

Contributor:

Follow-up here: should we merge these configs into the Model dataclass?
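For context, a hypothetical sketch of what this config section exposes, matching the --training.extra_losses.load_balance_loss_weight flag from the PR description; the field name, default, and docstring are illustrative rather than the merged torchtitan config.

```python
from dataclasses import dataclass

@dataclass
class ExtraLosses:
    # Hypothetical sketch, not the merged torchtitan config.
    load_balance_loss_weight: float = 0.0
    """Weight of the MoE load-balancing aux loss; 0.0 disables it.

    Per the PR description, the batch-wise vs. sequence-wise variant is
    selected in ModelArgs rather than in this section."""
```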

@shuhuayu (Contributor) left a comment:

Thanks a lot for the PR, @rakkit! I made some comments here.

@rakkit (Contributor Author) commented on Dec 9, 2025:

Thanks a lot for the feedback, @wwwjn @shuhuayu (sorry for the late update)!

Summary of new changes:

  • Made the MoE loss a wrapper, so we can now do
    build_loss_fn = moe_loss_wrap(build_cross_entropy_loss)
    when defining the model in TrainSpec (a sketch of this wrapper pattern follows below).

  • Moved ExtraLosses to the Training scope.
    The main purpose is to decouple this from the model definition.

  • Renamed load_balance_coeff to moe_aux_loss_free_bias_coeff — a bit longer, but clearer.

  • Now applied to the MoE models in the models folder (dpskv3, llama4, qwen).

  • Other refactors, thanks again @shuhuayu.

And be aware that PP with the aux loss still does not work.
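To illustrate the wrapper pattern above, here is a hypothetical sketch; the actual moe_loss_wrap and build_cross_entropy_loss signatures in torchtitan may differ, and the (output, aux_loss) tuple convention is an assumption.

```python
# Hypothetical sketch of the wrapper pattern; not the PR's exact implementation.
def moe_loss_wrap(build_loss_fn):
    def build_wrapped_loss_fn(*args, **kwargs):
        base_loss_fn = build_loss_fn(*args, **kwargs)

        def loss_fn(pred, labels):
            # MoE models are assumed to return (logits, aux_loss);
            # dense models return logits only.
            if isinstance(pred, tuple):
                logits, aux_loss = pred
                return base_loss_fn(logits, labels) + aux_loss
            return base_loss_fn(pred, labels)

        return loss_fn

    return build_wrapped_loss_fn


# Usage when defining the model in TrainSpec (from the summary above):
# build_loss_fn = moe_loss_wrap(build_cross_entropy_loss)
```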

self.load_balance_loss_weight,
)
else:
load_balance_loss = torch.tensor(0.0, device=out.device, dtype=out.dtype)
Contributor:

As far as I can see, out is not defined in this scope yet.

@rakkit (Contributor Author):

Fixed, thanks. :)

@staticmethod
def sequence_wise_aux_loss(
    scores: torch.Tensor,
    indices: torch.Tensor,
Contributor:

This will use the biased topk(scores + expert_bias) instead of the unbiased topk(scores) from DSv3 eq. 18.

@rakkit (Contributor Author):

Nope, that's top_scores.

@lckr (Contributor) commented on Dec 10, 2025:

Ah yeah, scores is the raw sigmoid output. But isn't indices (= selected_experts_indices) derived as topk(scores + expert_bias)?

@rakkit (Contributor Author):

Hmm, good question. I need to think about this.

@rakkit (Contributor Author) commented on Dec 10, 2025:

I think you might be right; in eq. 18 the topk doesn't include the bias.

@rakkit (Contributor Author):

Thanks. I fixed this and reran both aux-loss types and the no-aux-loss case in the PR description.
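A minimal sketch of the distinction resolved in this thread (tensor shapes and variable names are assumptions, not the PR's router code): the biased top-k decides which experts process each token, while the aux loss should count selections from the unbiased top-k, per DSv3 eq. 18.

```python
import torch

n_tokens, n_experts, top_k = 8, 64, 4
scores = torch.rand(n_tokens, n_experts)  # raw sigmoid gate scores
expert_bias = torch.zeros(n_experts)      # aux-loss-free load-balancing bias

# Biased top-k: selects which experts actually process each token.
_, routed_indices = torch.topk(scores + expert_bias, top_k, dim=-1)

# Unbiased top-k: the indices that should feed f_i in the balance loss (eq. 18).
_, aux_loss_indices = torch.topk(scores, top_k, dim=-1)
```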
