[simple_fsdp] apply bucketing ag/rs passes, reordering collectives, sink #1464

IvanKobzarev wants to merge 1 commit into main from
Conversation
a4b99bf to 56196de
```python
logger.info("Applied Data Parallel (dp mode=%s) to the model", dp_mode)

if job_config.training.compile:
    from torch._inductor.comms import (
```
nit: you don't need to import. you can pass string names of the passes instead, e.g. passes = ["sink_waits_iterative", ...]
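A minimal sketch of this suggestion, assuming (as the nit states) that the inductor config accepts pass names as strings; the pass names mirror the ones used elsewhere in this diff:

```python
# Sketch of the reviewer's nit: refer to the passes by string name so the
# explicit import from torch._inductor.comms becomes unnecessary.
# Assumes inductor resolves these string names, per the comment above.
import torch._inductor.config as inductor_config

inductor_config.reorder_for_compute_comm_overlap_passes = [
    "sink_waits_iterative",
    "reorder_communication_preserving_peak_memory",
]
```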
```python
        sink_waits_iterative,
    )

    torch._inductor.config.allow_buffer_reuse = False
```
i think since this will change the behavior of compile for non-simplefsdp cases, we need to understand the impact before we can land. Or, we can introduce a new option for compile_and_optimize_comms that bundles all these things together.
In my autoparallel branch I preferred to expose the manual controls directly, mostly to allow experimentation. But for general consumption i think it is nicer to enable the right 'recipe' with a simple switch, so i like your approach.
> i think since this will change the behavior of compile for non-simplefsdp cases
@wconstab This is in the simple_fsdp folder, for the SimpleFSDP case alone. Why do you think it would change the behavior of non-simplefsdp cases?
oops, i was just not paying attention to which file this was, i assumed it was the core parallelize file. disregard my comment.
```python
    torch._inductor.config.bucket_all_gathers_fx = "fsdp"
    torch._inductor.config.bucket_reduce_scatters_fx = "fsdp"
    torch._inductor.config.reorder_for_compute_comm_overlap = True
```
Thank you for adding this. It would be nice if you could also update the README and detail what these configs are doing!
tianyu-l left a comment
Please see inline comments. I think it'd be good if we can keep the passes configurable.
```python
    torch._inductor.config.allow_buffer_reuse = False
    torch._inductor.config.reorder_for_peak_memory = False
    torch._inductor.config.reorder_for_compute_comm_overlap = False

    torch._inductor.config.bucket_all_gathers_fx = "fsdp"
    torch._inductor.config.bucket_reduce_scatters_fx = "fsdp"
    torch._inductor.config.reorder_for_compute_comm_overlap = True
    torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
        sink_waits_iterative,
        reorder_communication_preserving_peak_memory,
    ]
```
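Since a reviewer asked for documentation of these configs, here is an annotated restatement of the block above. The descriptions are my reading of these inductor knobs and should be treated as a sketch, not as authoritative documentation:

```python
# Annotated copy of the settings in this diff (descriptions are best-effort):

# Disable buffer reuse: reused buffers add dependencies between otherwise
# independent ops, which can block the reordering passes below.
torch._inductor.config.allow_buffer_reuse = False

# Turn off inductor's default peak-memory reordering so it does not
# conflict with the communication-reordering passes enabled here.
torch._inductor.config.reorder_for_peak_memory = False

# Bucket many small FSDP all-gathers / reduce-scatters into fewer,
# larger collectives ("fsdp" selects the FSDP-style bucketing strategy).
torch._inductor.config.bucket_all_gathers_fx = "fsdp"
torch._inductor.config.bucket_reduce_scatters_fx = "fsdp"

# Run scheduling passes that overlap communication with compute:
torch._inductor.config.reorder_for_compute_comm_overlap = True
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
    sink_waits_iterative,  # sink wait ops as late as possible
    reorder_communication_preserving_peak_memory,  # hoist collectives without raising peak memory
]
```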
I think we should make this configurable. Although I agree they can be turned on by default, I think it's still valuable if users can turn these passes off:
- Some researchers may appreciate a plain graph without reordering. This was one of the original motivations for adding this experiment here.
- The passes added here may not always be stable. If they cause failures, we should have a fallback to not enable them.
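One possible shape for the configurability requested here (a sketch only; `enable_comm_optimizations` is a hypothetical option name, not something that exists in the config manager):

```python
# Hypothetical sketch: bundle the whole recipe behind one switch so users
# can fall back to a plain, unreordered graph.
# "enable_comm_optimizations" is an invented option name, and the string
# pass names assume inductor accepts them (per the nit earlier in this PR).
if job_config.training.compile and getattr(
    job_config.training, "enable_comm_optimizations", True
):
    torch._inductor.config.allow_buffer_reuse = False
    torch._inductor.config.reorder_for_peak_memory = False
    torch._inductor.config.bucket_all_gathers_fx = "fsdp"
    torch._inductor.config.bucket_reduce_scatters_fx = "fsdp"
    torch._inductor.config.reorder_for_compute_comm_overlap = True
    torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
        "sink_waits_iterative",
        "reorder_communication_preserving_peak_memory",
    ]
```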
Will has already added options in the config manager, so I will close this PR.
Could you elaborate on this a bit? I haven't seen @wconstab's change in the
Also for future reference, for such changes please consider presenting a performance report in the PR summary.
Sorry, that was not to the
So for simplefsdp, porting this to
Depends on landing PyTorch PR pytorch/pytorch#158663