Warn that SAC + Compile for MoE models is not yet supported#2052
Warn that SAC + Compile for MoE models is not yet supported#2052
Conversation
stack-info: PR: #2052, branch: xmfan/stack/4
|
|
||
| if ac_config.mode == "selective": | ||
| logger.warning( | ||
| "Selective Activation Checkpointing is not yet supported for MoE models, " |
There was a problem hiding this comment.
This is a little bit confusing, SAC works with eager for MoE models
stack-info: PR: #2052, branch: xmfan/stack/4
wwwjn
left a comment
There was a problem hiding this comment.
LGTM! Thanks for making this!
There was a problem hiding this comment.
sorry, didn't follow -- what's the issue between compile + SAC + MoE?
CheckpointWrapper is being applied to all submodules in SAC, but only at the block-level for Full AC. That breaks the logic of apply_compile ever since #1895.
What's the problem with full AC at block level? is it because we have full AC (compile)?
Also could you help make a central list on the composability issues among AC, compile, MoE?
I realized that
pytorch/pytorch#167844 fixes SAC around torch.compile region
| "Compile + Selective Activation Checkpointing is not yet supported for MoE models, " | ||
| "please use Full Activation Checkpointing instead. Turning off Compile." | ||
| ) | ||
| return |
SAC will wrap each submodule of TransformerBlock separately (_apply_op_sac_to_transformer_block_with_flex), which will make each submodule of TransformerBlock an instance of CheckpointWrapper. This will make the So #1895 only works with Full AC, not SAC. AC(compile(moe)) works, but SAC(compile(moe)) doesn't work. |
So everything should be fixed now, we just need to remove the hack in _apply_op_sac_to_transformer_block_with_flex and test |
So there are two cases here, depending on whether you care that compiling makes your graph opaque. The fix there primarily addresses one of the cases. |
|
To check my understanding:
So if only FlexAttn is compiled (not each transformer layers / or submodule of transformer layers), SAC works.
Say if we compile each transformer layers, do you mean we can only save / recompute all the ops within the transformer layer, can not specify which ops to save in SAC region? |
Is this full AC behavior? Or do you mean something else? Seems I was aware of this behavior before. |
|
@wwwjn @tianyu-l yeah I think your understanding is correct - either save all activations need for backward computed within the compiled region or recompute all ops, just like full AC.
Yes, but existing policy needs to be updated to handle the inductor HOP. |
Stacked PRs:
Warn that SAC + Compile for MoE models is not yet supported. Behavior should be identical for moe blocks, dense blocks are no longer compiled.
This also fixes another issue: CheckpointWrapper is being applied to all submodules in SAC, but only at the block-level for Full AC. That breaks the logic of
apply_compileever since #1895.