Optimizer-in-backward Fusion Implementation #1295
Conversation
Thanks! Just to double check, does this support gradient accumulation?
No, it uses PyTorch's `_apply_optimizer_in_backward`, which does not support gradient accumulation.
I guess this constraint limits the use of …
eric-haibin-lin left a comment
Thanks for the contribution! For RL I think it's very likely that we need gradient accumulation to train with large batch sizes. But this feature would still be useful for SFT. Do you think it makes sense if you only keep the changes for the SFT trainer, while keeping the RL trainer simple for now? Thx
I agree RL almost always needs gradient accumulation, so the immediate win for RL is limited. I've scoped the PR down to the SFT trainer for now and left the RL workers unchanged. PyTorch mentioned here that they are working on a more flexible post-backward hook; I'm not sure of the status of this, but once that lands we can revisit a gradient-accumulation-friendly version for RL.
This PR implements optimizer-in-backward fusion (suggested in this issue), which integrates optimizer updates directly into the backward pass to decrease peak memory usage: each gradient is applied and freed as soon as it is produced, instead of being stored until a separate optimizer step. It uses PyTorch's `_apply_optimizer_in_backward` to fuse optimizer steps with backpropagation. This feature only works with the SFT trainer.
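For reference, here is a minimal sketch of how PyTorch's `_apply_optimizer_in_backward` is typically wired up. The model, optimizer class, and hyperparameters below are illustrative, not taken from this PR:

```python
import torch
from torch.distributed.optim import _apply_optimizer_in_backward

model = torch.nn.Linear(1024, 1024)

# Register per-parameter hooks: each parameter's optimizer step runs as soon as
# its gradient is produced during backward, and the gradient is freed right away
# instead of being held until a separate optimizer.step() call.
_apply_optimizer_in_backward(
    optimizer_class=torch.optim.AdamW,
    params=model.parameters(),
    optimizer_kwargs={"lr": 1e-5},
)

x = torch.randn(8, 1024)
loss = model(x).sum()
loss.backward()  # parameters are updated inside backward; no optimizer.step() needed
```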
Optimizer-in-backward fusion is incompatible with gradient accumulation and is controlled by a `bwd_hook` flag in the optimizer config. Validation is added to ensure that training configurations that rely on gradient accumulation do not enable the optimizer fusion option.
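The validation is roughly of the following shape. This is only a hedged sketch; the config field names (`optim.bwd_hook`, `gradient_accumulation_steps`) are assumptions rather than the exact keys used in this PR:

```python
def validate_optim_in_bwd_config(config):
    """Reject configs that enable optimizer-in-backward together with gradient accumulation."""
    # Hypothetical field names for illustration; the PR's actual config keys may differ.
    bwd_hook = getattr(config.optim, "bwd_hook", False)
    grad_accum_steps = getattr(config, "gradient_accumulation_steps", 1)
    if bwd_hook and grad_accum_steps > 1:
        raise ValueError(
            "optim.bwd_hook=True fuses optimizer steps into backward and frees each "
            "gradient immediately, so it cannot be combined with gradient accumulation; "
            "set gradient_accumulation_steps=1 or disable bwd_hook."
        )
```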