
Use joint trace in transform_for_execution #2102

Merged
t-vi merged 22 commits into main from reautograd2 on Jun 6, 2025

Conversation

@beverlylytle
Collaborator

@beverlylytle beverlylytle commented May 20, 2025

This PR aims to use a joint forward-backward trace in transform_for_execution while jitting, instead of separately processing a forward trace and a backward trace. This change is behind the compile option flag delay_trace_split, which currently defaults to True. Provided no performance or memory issues appear, this will allow for a follow-up PR which can remove the flag and delete ~300 lines from torch_autograd.py and ~300 lines from rematerialization.py along with the relevant tests.

Collaborator

@IvanYashchuk IvanYashchuk left a comment


In thunder/tests/distributed/test_ddp.py it's okay to skip the test for grad bucketing in Thunder's DDP. I only get two failures with TransformerEngine (the same error as in #2060):

FAILED thunder/tests/distributed/test_ddp.py::test_ddp_transformer_engine_torch_cuda_thunder.dtypes.float32 - ValueError: not enough values to unpack (expected 5, got 4)
FAILED thunder/tests/distributed/test_ddp.py::test_ddp_transformer_engine_llama_sanity_torch_cuda_thunder.dtypes.float32 - ValueError: not enough values to unpack (expected 5, got 4)

Similarly, in thunder/tests/distributed/test_fsdp.py it's okay to skip failing tests for Thunder's FSDP bucketing and no_sync.

@IvanYashchuk IvanYashchuk requested a review from Copilot May 27, 2025 12:30
Contributor

Copilot AI left a comment


Pull Request Overview

This pull request updates several components of the autodiff and distributed transforms along with modifications to the executor implementations and corresponding tests to support new behaviors in the CI. Key changes include:

  • Introducing a new utility function (_group_get_grad_bsyms) in the autodiff transform.
  • Adjusting test expectations and xfail markers for distributed traces.
  • Refining conditionals in passes and executors to appropriately handle get_grad operations.

Reviewed Changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| thunder/transforms/autodiff.py | Added _group_get_grad_bsyms and updated gradient grouping logic. |
| thunder/tests/test_examine_memory.py | Updated test expectations for memory estimates. |
| thunder/tests/distributed/test_fsdp.py | Updated unshard parameter names and corrected trace index usage. |
| thunder/tests/distributed/test_ddp.py | Added xfail marker for grad bucketing test. |
| thunder/executors/torchex.py | Minor whitespace addition before shallow_copy registration. |
| thunder/executors/torch_compile.py | Excluded GET_GRAD from implementation mapping in executor. |
| thunder/executors/torch_autograd.py | Added early return if bw_trace is None. |
| thunder/executors/passes.py | Extended condition to pass through GET_GRAD symbols. |
| thunder/core/transform_common.py | Skipped further processing for GET_GRAD symbols. |
| thunder/core/rematerialization.py | Enhanced filtering of parameter names during rematerialization. |
| thunder/__init__.py | Revised trace-split logic under the delay_trace_split branch. |

@lantiga
Contributor

lantiga commented May 27, 2025

@IvanYashchuk about skipping bucketing with ddp and fsdp, we are actually counting on that to work properly for our distributed work. Do you think this can be tackled on your end?

@lantiga
Contributor

lantiga commented May 27, 2025

Thanks for the clarification @IvanYashchuk. I agree we can forego bucketing for now and eventually circle back to it at a later stage.

@beverlylytle beverlylytle changed the title [WIP2] Use joint trace in transform_for_execution May 28, 2025
@beverlylytle beverlylytle marked this pull request as ready for review May 28, 2025 07:40
@beverlylytle beverlylytle mentioned this pull request May 28, 2025
@beverlylytle
Collaborator Author

beverlylytle commented Jun 3, 2025

I've benchmarked this change with the following baseline and result:

@main
Model name: Llama-3-8B
Seq Length: 8192
Micro BS: 1
Global BS: 8
Number of Layers: 32
Number of parameters: 1.00B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: block
Compiler: dynamo_thunder
Low Precision Mode: none
Average iter time: 782.73 ms
Memory used: 72.61 GB
Tokens/s: 83707.48
Tokens/s/GPU: 10463.43
TFLOP/s: 4847.74

@reautograd2
Model name: Llama-3-8B
Seq Length: 8192
Micro BS: 1
Global BS: 8
Number of Layers: 32
Number of parameters: 1.00B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: block
Compiler: dynamo_thunder
Low Precision Mode: none
Average iter time: 781.49 ms
Memory used: 72.61 GB
Tokens/s: 83855.93
Tokens/s/GPU: 10481.99
TFLOP/s: 4856.34

@t-vi
Collaborator

t-vi commented Jun 3, 2025

Does it work with the thunder.jit? Could we also benchmark with "thunder" as compiler?

@beverlylytle
Collaborator Author

beverlylytle commented Jun 3, 2025

@main
Model name: Llama-3-8B
Seq Length: 8192
Micro BS: 1
Global BS: 8
Number of Layers: 32
Number of parameters: 1.00B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: none
Compiler: thunder
Low Precision Mode: none
Average iter time: 799.59 ms
Memory used: 75.75 GB
Saved for backward size: 58448.60 MiB
Saved for backward number of tensors: 775
Tokens/s: 81988.63
Tokens/s/GPU: 10248.58
TFLOP/s: 4748.20

@reautograd2
Model name: Llama-3-8B
Seq Length: 8192
Micro BS: 1
Global BS: 8
Number of Layers: 32
Number of parameters: 1.00B
Distributed Mode: fsdp
Sharding Mode: zero2
Bucketing: none
Compiler: thunder
Low Precision Mode: none
Average iter time: 808.52 ms
Memory used: 71.53 GB
Saved for backward size: 62988.09 MiB
Saved for backward number of tensors: 712
Tokens/s: 81032.54
Tokens/s/GPU: 10129.07
TFLOP/s: 4692.83

@beverlylytle
Collaborator Author

beverlylytle commented Jun 4, 2025

@riccardofelluga has made an excellent point. The bound symbols which create tensors that are to be recomputed in the backward pass are only inserted once the joint trace is split. This means that those operations do not get fused optimally.

@t-vi
Collaborator

t-vi commented Jun 4, 2025

@beverlylytle Thank you for running the benchmark.

But isn't it then the fusion or part of it that gets duplicated for recomputation?
At least that should be our target.

The other question is whether everything between the autodiff and the split propagates the tag properly. We would probably need to take a good look.
For Transform for Operator Execution, the trace processor would (want to) do it.
For the fusion passes, this may be more tricky currently.

What we need to do is likely look for recompute tags in the fusion subsymbols and then duplicate the fusion and then dce the bits away that we don't need. WDYT?
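The duplicate-then-DCE idea could look roughly like this toy sketch. It is not Thunder's actual trace machinery: `Op`, its `recompute` flag, and the helper names are all hypothetical stand-ins for bound symbols and their tags.

```python
# Toy sketch of "look for recompute tags in the fusion subsymbols, duplicate
# the fusion, then DCE away the bits we don't need". Op, its recompute flag,
# and the helper names are illustrative, not Thunder's actual API.
from dataclasses import dataclass


@dataclass(frozen=True)
class Op:
    out: str
    name: str
    args: tuple
    recompute: bool = False  # hypothetical tag set by the autodiff transform


def dce(ops, needed_outputs):
    """Keep only ops whose outputs are (transitively) needed."""
    needed = set(needed_outputs)
    kept = []
    for op in reversed(ops):  # walk backwards so consumers mark their producers
        if op.out in needed:
            kept.append(op)
            needed.update(op.args)
    return list(reversed(kept))


def duplicate_for_backward(fusion_ops):
    # Duplicate the whole fusion, then slice the copy down to the tagged outputs.
    tagged = [op.out for op in fusion_ops if op.recompute]
    return dce(list(fusion_ops), tagged)


fusion = [
    Op("t0", "exp", ("x",), recompute=True),  # activation to recompute in backward
    Op("t1", "add", ("t0", "y")),             # forward-only consumer
    Op("t2", "mul", ("t1", "z")),             # forward-only output
]
bw_copy = duplicate_for_backward(fusion)  # only the recompute slice survives
```

In this sketch the backward copy keeps only `t0` and its (absent) producers; the forward-only `t1` and `t2` are eliminated, after which the codegen could be rerun on the sliced copy.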

@beverlylytle
Collaborator Author

Yes, it is the fusion or part of it that gets duplicated, but a fusion region may compute more than what is intended to be recomputed in backward, right? Moreover, the fusion regions for the basic backward together with the fusion regions for the stuff to be recomputed in the backward may not be maximally fused.

@t-vi
Collaborator

t-vi commented Jun 4, 2025

  • a fusion region may compute more than what is intended to be recomputed in backward, right?

Yes, but this could be DCEd and the codegen rerun?

  • Moreover, the fusion regions for the basic backward together with the fusion regions for the stuff to be recomputed in the backward may not be maximally fused.

Indeed, and it is tricky. But our fusion algorithms all seem wild currently. The one used by the JIT by default fuses little (only adjacent regions), and the one used by thunderfx reorders operations, which can wreck memory consumption (we saw this for checkpointing in particular).
I think if we can rerun the codegen, we could at least re-fuse adjacent regions.

But I seem to remember that the interaction of fusions with recomputation was one of the key things you wanted to get out of the joint trace logic, so it might be trickier yet.

The other thing on my list of larger projects that might impact our choices is to have something like

def fn(inp, target):
    out = model(inp)
    loss = lossfn(out, target)
    loss.backward()

jfn = jit(fn, models=(model,))

as one trace. This would let us not split the trace and likely is an important next step (followed by adding optimizer.step()) for optimizing training.

@beverlylytle
Collaborator Author

Not only is it important that the duplicates of the bsyms that should be recomputed exist during the fusion pass, it is also important that they exist during rematerialization. I am going to try an approach where the duplication logic is moved from the splitting up to the creation of the joint trace. I will modify CSE so that no commonality will be found across the forward-backward divide.
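The CSE change described above could be pictured with this toy sketch. It is not Thunder's implementation: `BSym` and the `region` tag are hypothetical stand-ins for bound symbols and a forward/backward marker. The point is that putting the region into the seen-expression key keeps identical expressions from being merged across the divide.

```python
# Toy sketch of a CSE that finds no commonality across the forward-backward
# divide: the region tag is part of the subexpression key, so identical
# expressions in fw and bw are deliberately not merged. BSym and the region
# tag are stand-ins, not Thunder's actual bound-symbol API.
from dataclasses import dataclass


@dataclass(frozen=True)
class BSym:
    out: str
    op: str
    args: tuple
    region: str  # "fw" or "bw" -- hypothetical marker of the divide


def cse_respecting_divide(bsyms):
    seen = {}    # (op, args, region) -> canonical output name
    rename = {}  # eliminated output -> canonical output
    result = []
    for b in bsyms:
        args = tuple(rename.get(a, a) for a in b.args)
        key = (b.op, args, b.region)  # region in the key blocks fw/bw merging
        if key in seen:
            rename[b.out] = seen[key]  # fold duplicate into earlier result
        else:
            seen[key] = b.out
            result.append(BSym(b.out, b.op, args, b.region))
    return result


joint = [
    BSym("t0", "mul", ("a", "b"), "fw"),
    BSym("t1", "mul", ("a", "b"), "fw"),  # duplicate within forward: eliminated
    BSym("t2", "mul", ("a", "b"), "bw"),  # same expr in backward: kept for recompute
]
deduped = cse_respecting_divide(joint)  # t1 folds into t0; t2 survives
```

With an ordinary CSE all three multiplies would collapse into one; here the backward copy survives, so the duplicated bsyms are still present for the fusion pass and rematerialization to work with.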

@t-vi
Collaborator

t-vi commented Jun 6, 2025

To my mind, we could merge this PR as is and do what is needed in a follow up. WDYT? (also @IvanYashchuk )

Collaborator

@t-vi t-vi left a comment


Exciting to have this go in!
Thank you @beverlylytle @riccardofelluga @IvanYashchuk @lantiga

@t-vi t-vi enabled auto-merge (squash) June 6, 2025 13:27
@t-vi t-vi merged commit 290a52e into main Jun 6, 2025
49 checks passed
@t-vi t-vi deleted the reautograd2 branch June 6, 2025 13:28
