[precompile] add ability to precompile torchtitan models #2092
bobrenjc93 wants to merge 2 commits into gh/bobrenjc93/4/base
Conversation
for simplefsdp dsv3 we see the time taken to get through the first batch go down from 17.99 => 1.73 seconds. For posterity, the command used for testing was:

```
TORCH_LOGS="all" NGPU=2 CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" cache_tlp ./run_train.sh --model.name simple_fsdp.deepseek_v3 --compile.enable --activation_checkpoint.mode "none"
```

For this to work you'll need a pytorch checkout later than pytorch/pytorch#169242. This has currently only been tested with dsv3 and simplefsdp. Notably, the current implementation does not yet support PP; this will be added at a later time.
self.job_config = job_config

if job_config.compile.enable_precompilation:
qq. Is this for simplefsdp-only or also works for fsdp2+block-level compile?
maybe you want to add this config to apply_compile here for fsdp2:
torchtitan/torchtitan/models/llama3/infra/parallelize.py
Lines 236 to 247 in cbdb311
Currently only simplefsdp, but this should work with fsdp2 + block-level compile with some additional work.
# Create a unique filename based on model configuration and rank
filename = f"compiled_fn_{model_name}_{model_flavor}_rank_{rank}.pt"
return os.path.join("/tmp", filename)
This isn't a realistic file path for training on FB infra, as /tmp is cleared if you restart training.
Agreed. For FB infra, we would either package the artifact into the conda or fbpkg build, or place it in oilfs and keep a reference to it. For Torchtitan, using /tmp seemed acceptable, though I can make the location configurable through an environment variable. Did you have a different approach in mind?
}
module_cls = type(
-    f"SimpleFSDP{module.__class__.__name__}",
+    f"SimpleFSDP{module.__class__.__name__}_{_wrap_class_counter}",
@bobrenjc93 Will this new precompile option also work with the compiler toolkit experiment?
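The diff above appends a counter to the dynamically created class name so that two wraps of the same underlying class no longer collide on name. A self-contained sketch of that naming scheme (the `wrap` helper and `Linear` stand-in are illustrative, not the PR's actual code):

```python
import itertools

# Module-level counter, mirroring _wrap_class_counter in the diff above.
_wrap_class_counter = itertools.count()


class Linear:  # stand-in for an nn.Module subclass
    pass


def wrap(module):
    # Dynamically create a subclass whose name carries a unique suffix, so
    # each wrapped instance gets a distinct, stable class name.
    module_cls = type(
        f"SimpleFSDP{module.__class__.__name__}_{next(_wrap_class_counter)}",
        (module.__class__,),
        {},
    )
    module.__class__ = module_cls
    return module


a, b = wrap(Linear()), wrap(Linear())
# a and b now have distinct class names (SimpleFSDPLinear_0, SimpleFSDPLinear_1),
# which matters once compiled artifacts are serialized and must resolve classes
# unambiguously on reload.
```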
tianyu-l
left a comment
Sorry, not sure if this is a draft or ready for review, so I'm putting a hold so that it's not accidentally merged as-is.
If it's ready for review: the change seems quite intrusive; please consider simplifying it or putting it in the compiler_toolkit experiment folder.
@tianyu-l @aditvenk this PR served as a proof of concept to demonstrate an end-to-end flow where precompilation works with simplefsdp. I'll abandon it for now and shift focus to a more narrowly scoped PR that integrates precompile with the compiler toolkit (which does need some work since …)
Stack from ghstack (oldest at bottom):
For context, for folks who don't know: precompile is a new technology that allows
us to serialize a torch.compile'd model as a file on disk that we can load in the future
to avoid recompilation. It doesn't help with cold starts, but it is quite useful for warm
starts and preemptions where the underlying model doesn't change.
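The warm-start flow described above can be illustrated with a small conceptual sketch. This uses `pickle` and a config-hash key purely as a stand-in for the real torch.compile artifact serialization that this PR plugs into; `get_or_build` and `expensive_compile` are hypothetical names:

```python
import hashlib
import os
import pickle
import tempfile


def get_or_build(config: dict, expensive_compile, cache_dir: str):
    """Return (artifact, cache_hit): load from disk on warm start,
    otherwise pay the compile cost once and persist the result."""
    key = hashlib.sha256(repr(sorted(config.items())).encode()).hexdigest()
    path = os.path.join(cache_dir, f"artifact_{key}.pkl")
    if os.path.exists(path):  # warm start: load, skip recompilation
        with open(path, "rb") as f:
            return pickle.load(f), True
    artifact = expensive_compile(config)  # cold start: compile and persist
    with open(path, "wb") as f:
        pickle.dump(artifact, f)
    return artifact, False


with tempfile.TemporaryDirectory() as d:
    cfg = {"model": "deepseek_v3", "flavor": "debug"}
    art1, hit1 = get_or_build(cfg, lambda c: {"compiled_for": c["model"]}, d)
    art2, hit2 = get_or_build(cfg, lambda c: {"compiled_for": c["model"]}, d)
    # first call is a cold start (hit1 False), second loads the artifact (hit2 True)
```

The 17.99 => 1.73 seconds first-batch improvement quoted above is exactly this second, warm-start path taken by the real artifact.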