
Fix MLM loss masking in PyTorch BERT pretraining #2688

Open

Chessing234 wants to merge 1 commit into d2l-ai:master from Chessing234:fix/bert-mlm-loss-reduction

Conversation

@Chessing234

Summary

  • nn.CrossEntropyLoss() defaults to reduction='mean', returning a scalar
  • Multiplying a scalar by mlm_weights_X.reshape(-1, 1) makes the padding mask a no-op — every position contributes equally regardless of whether it was a real prediction or padding
  • Changing to reduction='none' returns per-element losses so the weight mask correctly zeros out padded positions before the manual sum/normalize (a toy demonstration follows this list)
  • The NSP loss now calls .mean() explicitly since reduction='none' applies to the shared loss instance
  • The MXNet version is already correct because gluon.loss.SoftmaxCELoss() returns per-element losses by default
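As a rough illustration of the first three points, here is a toy snippet (not taken from the notebook; shapes, values, and variable names are made up) showing how the mask goes from no-op to effective:

```python
import torch
from torch import nn

# Toy setup: 4 MLM prediction slots over a vocab of 5 tokens; the last two
# slots are padding, marked by zero weights.
logits = torch.randn(4, 5)
labels = torch.tensor([1, 3, 0, 0])
weights = torch.tensor([1., 1., 0., 0.])

# Current behavior: reduction='mean' collapses the loss to a scalar, so
# multiplying by the weights no longer removes the padded slots; the result
# equals the plain mean over all 4 slots, padding included.
scalar_loss = nn.CrossEntropyLoss()(logits, labels)
broken = (scalar_loss * weights).sum() / weights.sum()

# Proposed behavior: reduction='none' keeps one loss per slot, so the weights
# zero out padding before the manual sum/normalize.
per_slot = nn.CrossEntropyLoss(reduction='none')(logits, labels)  # shape (4,)
fixed = (per_slot * weights).sum() / weights.sum()
```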

Fixes #2582

Changes

  • chapter_natural-language-processing-pretraining/bert-pretraining.md:
    • Line 58: nn.CrossEntropyLoss() → nn.CrossEntropyLoss(reduction='none')
    • Line 118: loss(nsp_Y_hat, nsp_y) → loss(nsp_Y_hat, nsp_y).mean()
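Put together, the per-batch loss computation would reduce to something like the sketch below. This is not the notebook's code verbatim: the function name get_batch_loss is a stand-in for _get_batch_loss_bert, the argument names (mlm_Y_hat, mlm_Y, vocab_size, etc.) follow the notebook's conventions but are assumed here, and the weights are flattened with reshape(-1) so the element-wise multiply is shape-correct.

```python
from torch import nn

def get_batch_loss(mlm_Y_hat, mlm_Y, mlm_weights_X, nsp_Y_hat, nsp_y, vocab_size):
    """Sketch of the combined MLM + NSP loss after the change."""
    loss = nn.CrossEntropyLoss(reduction='none')
    # Masked LM: one loss per prediction slot, zeroed where mlm_weights_X
    # marks padding, then normalized by the number of real predictions
    # (the small epsilon guards against an all-padding batch).
    mlm_l = loss(mlm_Y_hat.reshape(-1, vocab_size), mlm_Y.reshape(-1))
    mlm_l = (mlm_l * mlm_weights_X.reshape(-1)).sum() / (mlm_weights_X.sum() + 1e-8)
    # Next sentence prediction: reduce explicitly, matching the old
    # reduction='mean' default.
    nsp_l = loss(nsp_Y_hat, nsp_y).mean()
    return mlm_l, nsp_l, mlm_l + nsp_l
```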

Test plan

  • Run the PyTorch BERT pretraining notebook end-to-end
  • Verify MLM loss values differ from before (mask now has effect)
  • Verify NSP loss values are unchanged (.mean() matches prior default behavior); a toy check is sketched below
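For the last check, a toy comparison (made-up logits and labels, not part of the notebook) confirms that .mean() over per-element losses reproduces the prior reduction='mean' default for the NSP term:

```python
import torch
from torch import nn

nsp_Y_hat = torch.randn(8, 2)          # hypothetical batch of NSP logits
nsp_y = torch.randint(0, 2, (8,))      # hypothetical NSP labels

old = nn.CrossEntropyLoss()(nsp_Y_hat, nsp_y)                          # prior default
new = nn.CrossEntropyLoss(reduction='none')(nsp_Y_hat, nsp_y).mean()   # after the change
assert torch.allclose(old, new)
```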

🤖 Generated with Claude Code

nn.CrossEntropyLoss() defaults to reduction='mean', returning a scalar.
Multiplying a scalar by mlm_weights_X makes the padding mask a no-op —
every position contributes equally regardless of the mask. Changing to
reduction='none' returns per-element losses so the mask correctly zeros
out padded prediction positions before summing.

The NSP loss now calls .mean() explicitly since reduction='none' applies
globally to the loss instance.

Fixes d2l-ai#2582

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>


Development

Successfully merging this pull request may close these issues.

The mlm loss computation in the function _get_batch_loss_bert seems wrong in d2l pytorch code
