[simplefsdp] fix region ac in zero2-style FSDP #1970
Is the logic for wrapping the modules mostly identical? Is it worth deduplicating?
Yes, I actually think I should reuse some functions from general
The problem might be harder than it sounds. That's actually one of the reasons I hadn't implemented it by myself.
For now, I'm OK with erroring out when reshard_after_forward=False + SAC/full AC is used.
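A minimal sketch of the "error out" fallback mentioned above. The function name `validate_ac_config` and the mode strings `"selective"` and `"full"` are assumptions made for illustration, not code from this PR.

```python
def validate_ac_config(reshard_after_forward: bool, ac_mode: str) -> None:
    """Reject the combination the reviewer proposes to disallow for now.

    `ac_mode` is assumed to be one of "none", "selective", or "full".
    """
    if not reshard_after_forward and ac_mode in ("selective", "full"):
        raise ValueError(
            "reshard_after_forward=False is not supported together with "
            f"activation checkpointing mode '{ac_mode}'"
        )


# ZeRO-3 style (reshard after forward) with AC passes the check.
validate_ac_config(reshard_after_forward=True, ac_mode="full")
# ZeRO-2 style without AC passes the check as well.
validate_ac_config(reshard_after_forward=False, ac_mode="none")
```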
_op_simple_fsdp_save_list = {
    torch.ops._c10d_functional.all_gather_into_tensor.default,
    torch.ops._c10d_functional.wait_tensor.default,
    torch.ops.aten._to_copy.default,
}
To achieve the same effect, can we just modify the input to apply_ac?
https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/activation_checkpoint.py#L292
Do we have to duplicate other parts?
You will see that the additional simplefsdp_checkpointing_context_fn function is applied in the _apply_full_ac and _apply_op_sac functions. This is where the checkpoint policy actually takes effect, and why reusing the other parts won't work, if this makes sense.
# for avoid recomputing SimpleFSDP all_gather in zero2-style FSDP
# it enforces additional policy to always mark SimpleFSDP all_gather as PREFER_SAVE
_op_simple_fsdp_save_list = {
I'm afraid this won't give us the right semantics.
SimpleFSDP is not the only module that uses these ops (all-gather, wait, to_copy). When you specify these ops in the save list, the side effect is that it will save all other occurrences of these ops and cause memory regression compared with FSDP2 reshard_after_forward=False + SAC.
This may be worked around by using custom all-gather/wait/to_copy ops for SimpleFSDP, as suggested by @fmassa. But then the question is: how do you substitute the DTensor built-in collectives with these custom ops?
Besides, what happens if full AC is combined with reshard_after_forward=False? I don't think the latter will take effect. Using SAC with fsdp ops only policy + custom SimpleFSDP ops is a proxy workaround.
- I also thought about the FSDP+TP case for reshard_after_forward=False, but there seems to be no good way of handling it other than adding a customized AC policy. I'm leaning toward adding a warning: "in multi-parallelism setting, open reshard_after_forward may cause memory regression".
Adding a custom op can be a big change and might break things. One potential way is to get the FSDP process group from the device mesh and check the all_gather op's process group in the AC annotation here. Then we only add MUST_SAVE to the FSDP AG node, if this makes sense.
- For full AC, see my previous comment.
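To make the trade-off in this exchange concrete, here is a toy model (plain Python, no torch) comparing an op-only save list with one that also checks the process group. The `CommCall` record and the group labels `"dp_shard"` and `"tp"` are hypothetical stand-ins; how the process group would actually be read from an op's arguments inside the AC policy is not shown in this thread.

```python
from dataclasses import dataclass

ALL_GATHER = "_c10d_functional.all_gather_into_tensor.default"


@dataclass(frozen=True)
class CommCall:
    op: str     # op name as a plain string (a torch.ops overload in real code)
    group: str  # hypothetical process-group label; "" for non-collective ops


def saved_by_op_only(calls):
    # Op-only save list: every all-gather is saved, whoever issued it.
    return [c for c in calls if c.op == ALL_GATHER]


def saved_by_op_and_group(calls, fsdp_group):
    # Group-aware policy: save only all-gathers issued on the FSDP group.
    return [c for c in calls if c.op == ALL_GATHER and c.group == fsdp_group]


calls = [
    CommCall(ALL_GATHER, "dp_shard"),  # FSDP parameter all-gather
    CommCall(ALL_GATHER, "tp"),        # tensor-parallel all-gather
    CommCall("aten.mm.default", ""),
    CommCall(ALL_GATHER, "dp_shard"),
]

print(len(saved_by_op_only(calls)))                   # 3
print(len(saved_by_op_and_group(calls, "dp_shard")))  # 2
```

In this toy trace the op-only list also keeps the TP all-gather output alive, which is the memory regression described above. The group check removes that case but does nothing for `_to_copy`, which carries no process group.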
> One potential way is get FSDP process group from device mesh, and check the all_gather op's process group in ac annotation here. Then, we only add MUST_SAVE to FSDP AG node, if this makes sense.
This sounds fine for AG. But I'm more worried about _to_copy (needed in SimpleFSDP to achieve mixed precision). There are just too many other _to_copies in a transformer.
> For full AC, see my previous comment.
Oh, my bad, I missed that part. It seems you are indeed using SAC with an FSDP policy to mimic full AC + reshard_after_forward=False. Because of the implementation differences between SAC and full AC, they may not be identical, but I think it's an OK approximation.
Nevertheless, the coding style is made hacky because of API limitations. Let's take this chance to discuss with the team how people want to move forward.
Yes, it would be much easier to tag _to_copy ops in the FX graph by checking whether the _to_copy is connected to an FSDP AG. We have a chance to get things right in compile mode, but adding AC correctly in eager mode is hard...
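A rough illustration of the graph-based idea, using a toy node class rather than torch.fx. Everything here is an assumption for illustration: the `Node` type, the `"save"`/`"recompute"` tags, and treating "connected" as "a direct producer or consumer is an FSDP all-gather". The wait op is omitted for brevity.

```python
from dataclasses import dataclass, field


@dataclass
class Node:
    name: str
    op: str
    inputs: list = field(default_factory=list)
    group: str = ""         # hypothetical process-group label for collectives
    tag: str = "recompute"  # default AC decision in this toy model


def tag_fsdp_nodes(nodes, fsdp_group):
    # Build a reverse map from each node to the nodes that consume it.
    users = {n.name: [] for n in nodes}
    for n in nodes:
        for inp in n.inputs:
            users[inp.name].append(n)

    def is_fsdp_ag(n):
        return n.op == "all_gather" and n.group == fsdp_group

    for n in nodes:
        if is_fsdp_ag(n):
            n.tag = "save"
        elif n.op == "_to_copy":
            # Save a cast only if it directly feeds, or is fed by, an FSDP AG.
            if any(is_fsdp_ag(m) for m in n.inputs + users[n.name]):
                n.tag = "save"
    return nodes


param = Node("param", "placeholder")
cast = Node("cast", "_to_copy", [param])                # mixed-precision cast
ag = Node("ag", "all_gather", [cast], group="dp_shard")
act = Node("act", "placeholder")
act_cast = Node("act_cast", "_to_copy", [act])          # unrelated cast
mm = Node("mm", "mm", [ag, act_cast])

nodes = tag_fsdp_nodes([param, cast, ag, act, act_cast, mm], "dp_shard")
print({n.name: n.tag for n in nodes if n.op in ("_to_copy", "all_gather")})
# {'cast': 'save', 'ag': 'save', 'act_cast': 'recompute'}
```

The unrelated activation cast keeps its default tag, which is the distinction an op-only save list cannot make.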
From the other angle: maybe it's also easier to make reshard_after_forward=False work in compile mode.
As titled, this is a follow-up PR that avoids issuing additional backward all-gathers (bwd AG) when activation checkpointing is enabled. Previously, we only tested DSV3, which is not composable with AC.
The idea is quite simple: when reshard_after_forward is False, we add an additional checkpoint policy to AC, which avoids recomputing FSDP-related comms in AC.
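The idea can be sketched without torch as a policy wrapper. This is a simplified model, not the PR's implementation: `Policy` stands in for `torch.utils.checkpoint.CheckpointPolicy`, ops are plain strings instead of `torch.ops` overloads, and `with_fsdp_save` and `full_ac_policy` are hypothetical names.

```python
from enum import Enum, auto


class Policy(Enum):
    # Stand-in for torch.utils.checkpoint.CheckpointPolicy.
    PREFER_SAVE = auto()
    PREFER_RECOMPUTE = auto()


# String versions of the ops in _op_simple_fsdp_save_list.
FSDP_SAVE_OPS = {
    "_c10d_functional.all_gather_into_tensor.default",
    "_c10d_functional.wait_tensor.default",
    "aten._to_copy.default",
}


def with_fsdp_save(base_policy, reshard_after_forward):
    """Layer the extra FSDP policy on top of an existing AC policy."""
    if reshard_after_forward:
        # ZeRO-3 style: parameters are re-gathered in backward anyway.
        return base_policy

    def policy(op):
        if op in FSDP_SAVE_OPS:
            return Policy.PREFER_SAVE  # keep gathered params, skip bwd AG
        return base_policy(op)

    return policy


def full_ac_policy(op):
    # Full AC modelled as "recompute everything".
    return Policy.PREFER_RECOMPUTE


zero2 = with_fsdp_save(full_ac_policy, reshard_after_forward=False)
print(zero2("aten._to_copy.default").name)  # PREFER_SAVE
print(zero2("aten.mm.default").name)        # PREFER_RECOMPUTE
```

Because the wrapper matches on op identity alone, it shares the limitation raised in review: any other module issuing the same ops would have its outputs saved too.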
reshard_after_fwd = False
reshard_after_fwd = True