Skip to content

Conversation

@ruisizhang123
Copy link
Contributor

@ruisizhang123 ruisizhang123 commented Oct 30, 2025

As titled, this is a follow up PR that avoid issuing additional bwd AG when activation checkpointing is enable. Prev, we only tested DSV3, which is not composable with AC.

The idea is quite simple: we add an additional checkpoint policy to AC when reshard_after_forward is False, which avoids recompute FSDP-related comms in ac.

reshard_after_fwd = False

  1. SAC + llama3 (trace)
Screenshot 2025-10-30 at 4 28 59 PM
  1. Full AC + llama3 (trace)
Screenshot 2025-10-30 at 4 30 53 PM
  1. No AC + llama3 [trace]
Screenshot 2025-10-30 at 4 32 05 PM

reshard_after_fwd = True

  1. SAC + llama3 (Trace)
Screenshot 2025-10-31 at 11 34 47 AM
  1. Full AC + llama3 (Trace)
Screenshot 2025-10-31 at 11 38 02 AM
  1. No AC + llama3 (Trace)
Screenshot 2025-10-31 at 11 43 04 AM

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 30, 2025
@ruisizhang123 ruisizhang123 marked this pull request as draft October 30, 2025 23:56
@ezyang ezyang requested a review from soulitzer October 31, 2025 12:44
@ruisizhang123 ruisizhang123 force-pushed the ruisi/zero2_fix branch 2 times, most recently from 275259e to 284695c Compare October 31, 2025 18:55
@ruisizhang123 ruisizhang123 marked this pull request as ready for review October 31, 2025 18:55
@soulitzer
Copy link
Contributor

Is the logic for wrapping mostly identical for wrapping the modules/ worth deduplicating?
Is the only difference between zero2 and zero3 style FSDP that the policy for the collectives are different?

@ruisizhang123
Copy link
Contributor Author

ruisizhang123 commented Oct 31, 2025

Is the logic for wrapping mostly identical for wrapping the modules/ worth deduplicating? Is the only difference between zero2 and zero3 style FSDP that the policy for the collectives are different?

yes, I actually think I should reuse some functions from general apply_ac.py. However, the addtional simplefsdp ac policy is scattered around several functions, which makes reuse a bit hard...

Copy link
Contributor

@tianyu-l tianyu-l left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The problem might be harder than it sounds. That's actually one of the reasons I hadn't implemented it by myself.

For now, I'm OK with erroring out when reshard_after_forward=False + SAC/full AC is used.

Comment on lines +26 to +30
_op_simple_fsdp_save_list = {
torch.ops._c10d_functional.all_gather_into_tensor.default,
torch.ops._c10d_functional.wait_tensor.default,
torch.ops.aten._to_copy.default,
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To achieve the same effect, can we just modify the the input to apply_ac?
https://github.com/pytorch/torchtitan/blob/main/torchtitan/distributed/activation_checkpoint.py#L292
Do we have to duplicate other parts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you will see the additional simplefsdp_checkpointing_context_fn function is applied to _apply_full_ac and _apply_op_sac function. This is where this checkpoint policy actually takes into effect and why reusing other parts won't work, if this makes sense.


# for avoid recomputing SimpleFSDP all_gather in zero2-style FSDP
# it enforces additional policy to always mark SimpleFSDP all_gather as PREFER_SAVE
_op_simple_fsdp_save_list = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid this won't give us the right semantics.

SimpleFSDP is not the only module that uses these ops (all-gather, wait, to_copy). When you specify these ops in the save list, the side effect is that it will save all other occurrences of these ops and cause memory regression compared with FSDP2 reshard_after_forward=False + SAC.

This may be worked around by using custom all-gather/wait/to_copy for SimpleFSDP, as suggested by @fmassa . But then the question is how do you substitute the DTensor built-in collectives to use these custom ops?

Besides, what happens if full AC is combined with reshard_after_forward=False? I don't think the latter will take effect. Using SAC with fsdp ops only policy + custom SimpleFSDP ops is a proxy workaround.

Copy link
Contributor Author

@ruisizhang123 ruisizhang123 Nov 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • I also thought about FSDP+TP case for reshard_after_forward=False. But seems there is just not a good way of handling this other than adding a customized ac policy.... I'm more leaning toward add a warning "in multi-parallelism setting, open reshard_after_forward may cause memory regression".

Adding a custom op can be a big change and might break things... One potential way is get FSDP process group from device mesh, and check the all_gather op's process group in ac annotation here. Then, we only add MUST_SAVE to FSDP AG node, if this makes sense.

  • For full AC, see my previous comment.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One potential way is get FSDP process group from device mesh, and check the all_gather op's process group in ac annotation here. Then, we only add MUST_SAVE to FSDP AG node, if this makes sense.

This sounds fine for AG. But I'm more worried about _to_copy (needed in SimpleFSDP to achieve mixed precision). There are just too many other _to_copies in a transformer.

For full AC, see my previous comment.

Oh, my bad that I missed that part. It seems you are indeed using SAC with FSDP policy to mimic full AC + reshard_after_forward=False. Because of the implementation difference of SAC and full AC, they may not be identical, but I think it's OK approximation.

Nevertheless, the coding style is made hacky because of API limitations. Let's take this chance to discuss with the team how people want to move forward.

cc @fmassa @xmfan @soulitzer @fegin

Copy link
Contributor Author

@ruisizhang123 ruisizhang123 Nov 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, it would be much easier to tag _to_copy ops in fx graph by looking at if the _to_copy is connected with FSDP AG. We have a chance to get things right in compile mode, but adding AC in eager mode correctly is hard...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From the other angle -- maybe it's also easier to make reshard_after_forward=False in compile.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Meta Open Source bot.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants