Conversation
IvanYashchuk left a comment
In thunder/tests/distributed/test_ddp.py it's okay to skip the test for grad bucketing in Thunder's DDP. I only get two failures with TransformerEngine (the same error as in #2060):
FAILED thunder/tests/distributed/test_ddp.py::test_ddp_transformer_engine_torch_cuda_thunder.dtypes.float32 - ValueError: not enough values to unpack (expected 5, got 4)
FAILED thunder/tests/distributed/test_ddp.py::test_ddp_transformer_engine_llama_sanity_torch_cuda_thunder.dtypes.float32 - ValueError: not enough values to unpack (expected 5, got 4)
Similarly, in thunder/tests/distributed/test_fsdp.py it's okay to skip the failing tests for Thunder's FSDP bucketing and no_sync.
Pull Request Overview
This pull request updates several components of the autodiff and distributed transforms along with modifications to the executor implementations and corresponding tests to support new behaviors in the CI. Key changes include:
- Introducing a new utility function (_group_get_grad_bsyms) in the autodiff transform.
- Adjusting test expectations and xfail markers for distributed traces.
- Refining conditionals in passes and executors to appropriately handle get_grad operations.
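As a rough illustration of what a utility like `_group_get_grad_bsyms` might do, here is a toy sketch that partitions a trace's bound symbols by whether they are `get_grad` operations. The `BSym` class and the function's signature are invented for illustration; they are not Thunder's actual `BoundSymbol` API.

```python
from dataclasses import dataclass

# Toy stand-in for a Thunder bound symbol; the real BoundSymbol class
# is far richer. Names here are illustrative only.
@dataclass
class BSym:
    op: str
    out: str

def group_get_grad_bsyms(bsyms):
    """Partition bound symbols into get_grad ops and the rest,
    preserving order -- a guess at the grouping, not Thunder's code."""
    get_grads = [b for b in bsyms if b.op == "get_grad"]
    others = [b for b in bsyms if b.op != "get_grad"]
    return get_grads, others

trace = [BSym("mul", "t0"), BSym("get_grad", "g0"),
         BSym("add", "t1"), BSym("get_grad", "g1")]
grads, rest = group_get_grad_bsyms(trace)
print([b.out for b in grads])  # ['g0', 'g1']
print([b.out for b in rest])   # ['t0', 't1']
```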
Reviewed Changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| thunder/transforms/autodiff.py | Added _group_get_grad_bsyms and updated gradient grouping logic. |
| thunder/tests/test_examine_memory.py | Updated test expectations for memory estimates. |
| thunder/tests/distributed/test_fsdp.py | Updated unshard parameter names and corrected trace index usage. |
| thunder/tests/distributed/test_ddp.py | Added xfail marker for grad bucketing test. |
| thunder/executors/torchex.py | Minor whitespace addition before shallow_copy registration. |
| thunder/executors/torch_compile.py | Excluded GET_GRAD from implementation mapping in executor. |
| thunder/executors/torch_autograd.py | Added an early return when bw_trace is None. |
| thunder/executors/passes.py | Extended condition to pass through GET_GRAD symbols. |
| thunder/core/transform_common.py | Skip further processing for GET_GRAD symbols. |
| thunder/core/rematerialization.py | Enhanced filtering of parameter names during rematerialization. |
| thunder/__init__.py | Revised trace-split logic under the delay_trace_split branch. |
|
@IvanYashchuk about skipping bucketing with ddp and fsdp, we are actually counting on that to work properly for our distributed work. Do you think this can be tackled on your end?

Thanks for the clarification @IvanYashchuk. I agree we can forgo bucketing for now and circle back to it at a later stage.

I've benchmarked this change with the following baseline and result:

Does it work with thunder.jit? Could we also benchmark with "thunder" as the compiler?
@riccardofelluga has made an excellent point. The bound symbols which create the tensors to be recomputed in the backward pass are only inserted once the joint trace is split. This means that those operations do not get fused optimally.

@beverlylytle Thank you for running the benchmark. But isn't it then the fusion, or part of it, that gets duplicated for recomputation? The other question is whether everything between the autodiff and the split propagates the tag properly. We would probably need to take a good look. What we likely need to do is look for recompute tags in the fusion subsymbols, duplicate the fusion, and then DCE away the bits we don't need. WDYT?

Yes, it is the fusion or part of it that gets duplicated, but a fusion region may compute more than what is intended to be recomputed in the backward, right? Moreover, the fusion regions for the basic backward together with the fusion regions for the stuff to be recomputed in the backward may not be maximally fused.

Yes, but couldn't this be DCEd and the codegen rerun?
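The duplicate-then-DCE idea in this exchange could be sketched as a toy pass like the following. The data layout and names are invented for illustration; this is not Thunder's BoundSymbol or fusion-region API.

```python
# Sketch: duplicate a fusion's subsymbols into the backward, then
# dead-code-eliminate everything whose output the backward never uses.

def dce(ops, needed):
    """Keep only ops whose outputs are transitively needed.
    Each op is a (name, inputs, output, tags) tuple."""
    live = set(needed)
    kept = []
    for name, inputs, output, tags in reversed(ops):
        if output in live:
            kept.append((name, inputs, output, tags))
            live.update(inputs)
    return list(reversed(kept))

# A fusion region whose first two subsymbols are tagged for recomputation:
fusion_subsyms = [
    ("mul", ("a", "b"), "t0", frozenset({"recompute"})),
    ("exp", ("t0",), "t1", frozenset({"recompute"})),
    ("sum", ("t1",), "t2", frozenset()),  # the backward never needs t2
]

# Duplicate the whole fusion into the backward, then DCE away what the
# backward doesn't consume (here it only needs the recomputed t1):
recomputed = dce(fusion_subsyms, needed={"t1"})
print([name for name, *_ in recomputed])  # ['mul', 'exp']
```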
Indeed, and it is tricky. But our fusion algorithms all seem wild currently: the one used by the JIT by default fuses little (only adjacent regions), and the one used by thunderfx reorders operations in ways that can wreck memory consumption (we saw this for checkpointing in particular). But I seem to remember that the interaction of fusions with recomputation was one of the key things you wanted to get out of the joint trace logic, so it might be trickier yet. The other thing on my list of larger projects that might impact our choices is to have something like

```python
def fn(inp, target):
    out = model(inp)
    loss = lossfn(out, target)
    loss.backward()

jfn = jit(fn, models=(model,))
```

as one trace. This would let us not split the trace and is likely an important next step (followed by adding …)
Not only is it important that the duplicates of the bsyms that should be recomputed exist during the fusion pass; it is also important that they exist during rematerialization. I am going to try an approach where the duplication logic is moved from the splitting up to the creation of the joint trace. I will modify CSE so that no commonality is found across the forward-backward divide.
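A toy version of CSE that respects a forward/backward divide might look like the following. The sentinel name and data layout are invented for illustration; this is not Thunder's actual implementation.

```python
# CSE that deduplicates identical (op, inputs) pairs but clears its
# seen-table at the forward/backward boundary, so no commonality is
# found across the divide.

def cse(ops):
    """ops: (op, inputs, output) tuples plus the "BACKWARD_START" sentinel."""
    seen, alias, out = {}, {}, []
    for item in ops:
        if item == "BACKWARD_START":
            seen = {}  # forget forward expressions: no CSE across the divide
            out.append(item)
            continue
        op, inputs, output = item
        key = (op, tuple(alias.get(i, i) for i in inputs))
        if key in seen:
            alias[output] = seen[key]  # reuse the earlier result
        else:
            seen[key] = output
            out.append((op, key[1], output))
    return out

trace = [("mul", ("a", "b"), "t0"),
         ("mul", ("a", "b"), "t1"),  # duplicate within forward: CSE'd away
         "BACKWARD_START",
         ("mul", ("a", "b"), "t2")]  # duplicate across the divide: kept
print(len(cse(trace)))  # 3  (t0, the sentinel, and t2 survive)
```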
To my mind, we could merge this PR as is and do what is needed in a follow-up. WDYT? (also @IvanYashchuk)
t-vi left a comment
Exciting to have this go in!
Thank you @beverlylytle @riccardofelluga @IvanYashchuk @lantiga
This PR aims to use a joint forward-backward trace in transform_for_execution while jitting, instead of separately processing a forward trace and a backward trace. This change is behind the compile option flag delay_trace_split, which currently defaults to True. Provided no performance or memory issues appear, this will allow a follow-up PR to remove the flag and delete ~300 lines from torch_autograd.py and ~300 lines from rematerialization.py, along with the relevant tests.
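The behavior the flag gates could be sketched like this: with delay_trace_split, optimization runs once over the joint trace and the split happens afterwards; without it, the trace is split first and each half is optimized in isolation. All names below are illustrative toys, not Thunder's real pipeline.

```python
PASSES = []  # record the size of each trace handed to the optimizer

def optimize(ops):
    PASSES.append(len(ops))  # stand-in for a real optimization pass
    return ops

def split(ops):
    i = ops.index("BACKWARD_START")
    return ops[:i], ops[i + 1:]

def transform_for_execution(trace, *, delay_trace_split=True):
    if delay_trace_split:
        joint = optimize(trace)  # one pass sees forward and backward together
        return split(joint)
    fwd, bwd = split(trace)                  # old path: split first,
    return optimize(fwd), optimize(bwd)      # then optimize each half alone

trace = ["mul", "exp", "BACKWARD_START", "add"]
transform_for_execution(trace)                           # one joint pass (4 ops)
transform_for_execution(trace, delay_trace_split=False)  # two passes (2 and 1 ops)
print(PASSES)  # [4, 2, 1]
```

The point of the joint pass is that optimizations (fusion in particular) can see forward and backward operations together before the split.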