Workaround AC HOP mutation issue when tracing token dispatch #1984
Conversation
stack-info: PR: #1984, branch: xmfan/stack/2
Reviewed lines (token dispatch arguments):

    input_shape,
    permuted_indices,
    input_splits,
    output_splits,
These shouldn't be exposed to single-device model code. Plus, I don't think it will work if EP is not used.
If it's getting too hard, maybe we should use local_map / to_local to re-implement MoE.
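For reference, a minimal sketch of the local_map approach, assuming recent PyTorch's torch.distributed.tensor.experimental.local_map; the token_dispatch helper and the placements are hypothetical, not from this PR:

```python
import torch
from torch.distributed.tensor import Replicate, Shard
from torch.distributed.tensor.experimental import local_map

def token_dispatch(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    # Plain single-device dispatch logic on local shards; no DTensor
    # bookkeeping (input_splits / output_splits / ...) leaks into model code.
    return x[indices]

# local_map unwraps DTensor inputs to their local shards, runs the function,
# and re-wraps the output with the declared placements (the mesh is inferred
# from the inputs at call time).
dispatch = local_map(
    token_dispatch,
    out_placements=(Shard(0),),                    # assumed output sharding
    in_placements=((Shard(0),), (Replicate(),)),   # assumed input shardings
)
```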
Thank you for the fix! Do you think it would require fewer user-side changes if we reimplemented apply_ac as a graph pass?
      max_norm = 1.0  # grad norm clipping
      steps = 1000
    - dataset = "c4"  # supported datasets: c4_test (2K), c4 (177M)
    + dataset = "c4_test"  # supported datasets: c4_test (2K), c4 (177M)
minor: the toml config file was changed (dataset c4 → c4_test).
This change is only needed if you use torch.compile over torch.utils.checkpoint, so a graph pass wouldn't need it. But if you use both eager and graph-based AC, you will need it again.
Yes, what I meant is that if we're going for a compiler-based approach to distributed parallelism in SimpleFSDP, it would make sense to have a specialized apply_ac function that's also compiler-based (and users would not be allowed to use eager checkpoint to implement AC).
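For context, a minimal sketch of the eager-checkpoint-under-compile pattern referred to above, i.e. torch.compile tracing torch.utils.checkpoint into the AC HOP; the Block module is illustrative, not torchtitan's apply_ac:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

class Block(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Eager AC: recompute this region during backward instead of saving
        # activations. Under torch.compile this traces into the AC HOP.
        return checkpoint(self._inner, x, use_reentrant=False)

    def _inner(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.lin(x))

model = torch.compile(Block(), fullgraph=True)
out = model(torch.randn(4, 16, requires_grad=True))
out.sum().backward()
```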
FIXES #1935
Stacked PRs:
tlparse: https://fburl.com/sqxd6c0w
Workaround AC HOP mutation issue when tracing token dispatch
Repro:

TORCH_COMPILE_FORCE_DISABLE_CACHES=1 HF_TOKEN=<token> HF_HUB_DISABLE_XET=1 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/deepseek_v3_16b.toml" with-proxy ./run_train.sh --model.name simple_fsdp.deepseek_v3

This is a problem for SimpleFSDP, where we want to fullgraph the entire model; these "mutations" cause graph breaks.
It is less of a problem outside SimpleFSDP, because we don't currently compile token dispatch.
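A minimal repro sketch of the failure mode described above, assuming the offending mutation is an in-place write to a tensor captured from outside the checkpointed region (the actual code path here is the token-dispatch bookkeeping, and the exact error varies by PyTorch version):

```python
import torch
from torch.utils.checkpoint import checkpoint

buf = torch.zeros(4, 16)  # state that lives outside the checkpointed region

def region(x: torch.Tensor) -> torch.Tensor:
    buf.copy_(x)  # in-place mutation of a captured tensor inside the AC HOP
    return torch.relu(x)

@torch.compile(fullgraph=True)
def fn(x: torch.Tensor) -> torch.Tensor:
    return checkpoint(region, x, use_reentrant=False)

try:
    fn(torch.randn(4, 16, requires_grad=True))
except Exception as e:
    # Dynamo cannot trace side effects inside a higher-order op; with
    # fullgraph=True the would-be graph break surfaces as an error.
    print(type(e).__name__, ":", e)
```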