
Conversation

@tinuademargaret tinuademargaret commented Apr 28, 2025

This PR implements optimizer-in-backward fusion (suggested in this issue), which integrates optimizer updates directly into the backward pass to decrease peak memory usage by eliminating the need to keep all gradients in memory until the optimizer step. It uses PyTorch's _apply_optimizer_in_backward to fuse optimizer steps with backpropagation.
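For reviewers unfamiliar with the PyTorch helper, here is a minimal sketch of how _apply_optimizer_in_backward is typically wired up; the model, optimizer class, and hyperparameters below are illustrative, not the exact ones used in this PR:

```python
import torch
from torch.distributed.optim import _apply_optimizer_in_backward

model = torch.nn.Linear(1024, 1024)

# Register a per-parameter optimizer that fires as each gradient is
# produced during backward, so a gradient can be released right after
# its parameter is updated instead of being held for the whole model.
_apply_optimizer_in_backward(
    torch.optim.AdamW,
    model.parameters(),
    optimizer_kwargs={"lr": 1e-5},
)

x = torch.randn(8, 1024)
loss = model(x).sum()
loss.backward()  # optimizer steps run inside backward; no separate optimizer.step()
```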

This feature only works with the SFT trainer.

Optimizer-in-backward fusion is incompatible with gradient accumulation and is controlled by a bwd_hook flag in the optimizer config.

Validation has been added to ensure that training configurations that rely on gradient accumulation cannot enable the optimizer fusion option.
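A rough sketch of what that guard could look like; bwd_hook is the flag added in this PR, while the gradient-accumulation field name and config layout below are placeholders for whatever the SFT trainer config actually uses:

```python
from types import SimpleNamespace

def validate_optimizer_in_backward(config):
    """Reject configs that combine optimizer-in-backward with gradient accumulation.

    `bwd_hook` is the flag introduced in this PR; the other field names
    here are illustrative placeholders.
    """
    grad_accum = getattr(config, "gradient_accumulation_steps", 1)
    if getattr(config.optim, "bwd_hook", False) and grad_accum > 1:
        raise ValueError(
            "optim.bwd_hook fuses optimizer steps into backward and cannot be "
            "combined with gradient accumulation; set "
            "gradient_accumulation_steps=1 or disable bwd_hook."
        )

# Example: this configuration would be rejected.
cfg = SimpleNamespace(
    gradient_accumulation_steps=4,
    optim=SimpleNamespace(bwd_hook=True),
)
validate_optimizer_in_backward(cfg)  # raises ValueError
```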

CLAassistant commented Apr 28, 2025

CLA assistant check
All committers have signed the CLA.

@tinuademargaret tinuademargaret marked this pull request as ready for review April 28, 2025 13:52
@eric-haibin-lin
Collaborator

Thanks! Just to double check, does this support gradient accumulation?

@tinuademargaret
Author

Thanks! Just to double check, does this support gradient accumulation?

No, it uses PyTorch's _apply_optimizer_in_backward, which does not support gradient accumulation.

@vermouth1992
Collaborator

Thanks! Just to double check, does this support gradient accumulation?

No, it uses PyTorch's _apply_optimizer_in_backward, which does not support gradient accumulation.

I guess this constraint limits the use of _apply_optimizer_in_backward, because in most cases we need gradient accumulation to avoid OOM :(

Collaborator

@eric-haibin-lin eric-haibin-lin left a comment


Thanks for the contribution! For RL I think it's very likely that we need gradient accumulation to train with large batch sizes, but this feature would still be useful for SFT. Do you think it makes sense to keep only the changes for the SFT trainer, while keeping the RL trainer simple for now? Thx

@tinuademargaret
Author

Thanks for the contribution! For RL I think it's very likely that we need gradient accumulation to train with large batch sizes, but this feature would still be useful for SFT. Do you think it makes sense to keep only the changes for the SFT trainer, while keeping the RL trainer simple for now? Thx

I agree RL almost always needs gradient accumulation, so the immediate win for RL is limited. I've scoped the PR down to the SFT trainer for now and left the RL workers unchanged. PyTorch mentioned here that they are working on a more flexible post-backward hook; I'm not sure of its current status, but once that lands we can revisit a gradient-accumulation-friendly version for RL.
