
Gradient accumulation broken in PP #1733

@jdinalt

Description


Bug description

Gradient accumulation is incompatible with the `PipelineSchedule(..., scale_grads=True)` option, which defaults to `True`.

When this option is set, the gradients are divided by the number of micro-batches at every schedule step. This works fine for a single gradient-accumulation step, but with multiple steps the running total is rescaled at every step, not just once at the end of accumulation.

The result is that the accumulated gradient is an exponentially decaying average rather than a sum: earlier steps' contributions are repeatedly divided down. Overall, the resulting gradients are much smaller than they should be, and using gradient accumulation with PP is not equivalent to using it without PP -- the loss curves diverge substantially, and the gradient norms are way off.

A secondary consequence is that dividing every gradient by `n_microbatches` at every step is computationally expensive when applied to a large model.
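The effect described above can be sketched numerically. This is a toy illustration (the gradient values and variable names are made up, not torchtitan code): dividing the running total by `n_microbatches` after every accumulation step collapses the sum into a geometrically decaying average.

```python
# Toy illustration of per-step gradient scaling vs. correct accumulation.
# All values are hypothetical; this is not torchtitan code.
n_microbatches = 4
step_grads = [1.0, 1.0, 1.0, 1.0]  # per-step gradient contributions

# Correct: sum the gradients, scale once at the end of accumulation.
correct = sum(step_grads) / n_microbatches  # 1.0

# Buggy: the running total is rescaled after every schedule step,
# so earlier contributions are divided by n_microbatches repeatedly.
buggy = 0.0
for g in step_grads:
    buggy = (buggy + g) / n_microbatches
# buggy == 0.33203125, far smaller than the correct 1.0

print(correct, buggy)
```

Each earlier step's gradient is attenuated by an extra factor of `1 / n_microbatches`, which is why the resulting gradients (and gradient norms) come out far too small.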

I identified the same issue in my own pipeline trainer implementation a week or two ago. While checking how Torch Titan handles it, I suspected that Titan had the same bug.

I had time to confirm the issue today and have submitted #1732 to resolve it.

Versions

torch 2.10.0.dev20250915+cu126

For anyone interested, I have added Torch Titan support to my configuration framework, which is what I used to reproduce the issue:

https://github.com/jdinalt/forgather/tree/main/examples/torchtitan
