Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
578d06e
Refactor connection to autograd with new joint trace creation
beverlylytle May 20, 2025
b04654a
apply update_fusion_call_ctx
beverlylytle May 21, 2025
0c19438
check bw for None
beverlylytle May 21, 2025
f0182ee
don't fuse get_grad
beverlylytle May 23, 2025
440bc96
group get_grads together for torch compile fusions
beverlylytle May 23, 2025
6bb4293
remove torchex impl of get_grad in favor of OpExProcessor exception
beverlylytle May 23, 2025
01dea9d
Merge branch 'main' into reautograd2
beverlylytle May 26, 2025
f45c92d
hide behind flag and clean up
beverlylytle May 26, 2025
eb32063
Xfail test_ddp_grad_bucketing
IvanYashchuk May 27, 2025
32b9675
Xfail test_limit_in_flight_allgathers with bucketing
IvanYashchuk May 27, 2025
675ea02
Xfail test_fsdp_with_no_sync_grad_accumulation
IvanYashchuk May 27, 2025
aafc899
Xfail test_fsdp_grad_parity_with_without_bucketing
IvanYashchuk May 27, 2025
9b4b252
Fix test_rematerialize_all_gather
IvanYashchuk May 27, 2025
a27f390
Restore test_torch_compile_cat_rope_single_fusion
IvanYashchuk May 27, 2025
56625f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 27, 2025
971c52e
remove extra rematerialization
beverlylytle May 27, 2025
d059353
Merge branch 'main' into reautograd2
beverlylytle May 28, 2025
9cc8c78
remove outdated change
beverlylytle May 28, 2025
67cde6e
Merge branch 'main' into reautograd2
beverlylytle Jun 3, 2025
e1308e1
clean up after merge
beverlylytle Jun 3, 2025
b73702a
more clean up
beverlylytle Jun 3, 2025
d39f270
Merge branch 'main' into reautograd2
beverlylytle Jun 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 36 additions & 12 deletions thunder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,16 +551,26 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
else:
requires_grad = False

delay_trace_split = compile_options.get("delay_trace_split", True)

if requires_grad:
# Currently split_forward_backward also includes
# transform_for_execution and various sorting of symbols,
# applying transform_for_execution after this would be
# breaking the order of operations
computation_trc, backward_trc = split_forward_backward(computation_trc, cd, cs, *computation_trc.args)
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward

if not requires_grad:
if delay_trace_split:

from thunder.transforms.autodiff import grad_transform_on_trace

computation_trc = grad_transform_on_trace(computation_trc)
else:
# Currently split_forward_backward also includes
# transform_for_execution and various sorting of symbols,
# applying transform_for_execution after this would be
# breaking the order of operations
computation_trc, backward_trc = split_forward_backward(
computation_trc, cd, cs, *computation_trc.args
)
# Note computation_trc and backward_trc have been appended to cs.last_(backward_)traces
# by split_forward_backward

if backward_trc is None:
from thunder.executors.passes import transform_for_execution as transform_for_execution_pass
from thunder.executors.passes import _transform_for_operator_executor_execution
from thunder.distributed.utils import maybe_sort_waits
Expand All @@ -576,9 +586,23 @@ def apply_transforms_and_build_cache_entry(cd, cs, cache_info, prologue_trc, com
executors_list=cd.executors_list,
use_del_last_used=False,
)
computation_traces.extend(extraces)
computation_trc = computation_traces[-1]
computation_trc = thunder.executors.passes.del_last_used(computation_trc)
computation_trc = extraces[-1]

if requires_grad and delay_trace_split:
from thunder.core.rematerialization import rematerialize
from thunder.executors.passes import update_fusion_call_ctx
from thunder.transforms.autodiff import split_into_forward_and_backward

computation_trc = rematerialize(computation_trc)
computation_trc = update_fusion_call_ctx(computation_trc)
computation_trc = dce(computation_trc)
computation_trc, backward_trc = split_into_forward_and_backward(computation_trc)

computation_trc = thunder.executors.passes.del_last_used(computation_trc)
computation_traces.append(computation_trc)
if backward_trc is not None:
backward_trc = thunder.executors.passes.del_last_used(backward_trc, clear_mutable_collections=True)
backward_traces.append(backward_trc)

if not compile_options.get("disable_inplace_copy_check", False):
thunder.core.transform_common._inplace_copy_sanity_check(computation_trc)
Expand Down
3 changes: 3 additions & 0 deletions thunder/core/transform_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,9 @@ def cse_single_bsym(
skip_output=True,
)

if bsym.sym.id == prims.PrimIDs.GET_GRAD:
return new_bsym

# Skip appending this bsym to the new bound symbols due to its rhs being a common subexpression.
rhs = new_bsym.rhs
if (prior_bsym := rhs_to_bsym_map.get(rhs)) is not None and bsym._executor is prior_bsym._executor:
Expand Down
2 changes: 1 addition & 1 deletion thunder/executors/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _transform_for_operator_executor_execution(trace: TraceCtx, executors_list:
# - if none of the above apply and we have a prim, raise an error
class OpExProcessor(TraceSubstitutionProcessor):
def process_bsym(self, bsym: BoundSymbol) -> None:
if bsym.sym.python_impl is not None:
if bsym.sym.python_impl is not None or bsym.sym.id == prims.PrimIDs.GET_GRAD:
# keep the bound symbol and use the python impl
self.add_processed_bsyms([bsym])
self.set_result(bsym.output)
Expand Down
3 changes: 3 additions & 0 deletions thunder/executors/torch_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,9 @@ def split_forward_backward(computation_trc: TraceCtx, compile_data, compile_stat
# the forward trace and inputs of the backward trace.
fw_trace, bw_trace = forward_and_backward_from_trace(primal_trace, torch_autograd=True)

if bw_trace is None:
return fw_trace, None

fw_traces = [fw_trace]
bw_traces = [bw_trace]

Expand Down
2 changes: 1 addition & 1 deletion thunder/executors/torch_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,4 @@ def cuda_device_checker(*args, **kwargs):

torch_compile_ex = TorchCompileExecutor(name="torchcompile")
register_executor(torch_compile_ex)
torch_compile_ex._implmap = {op: ImplInfo() for op in pytorch_ex.implmap}
torch_compile_ex._implmap = {op: ImplInfo() for op in pytorch_ex.implmap if op != prims.PrimIDs.GET_GRAD}
1 change: 1 addition & 0 deletions thunder/executors/torchex.py
Original file line number Diff line number Diff line change
Expand Up @@ -2339,6 +2339,7 @@ def _shape_impl(t):
shape = ex.register_operator("shape", meta=prims.shape_meta, fn=_shape_impl)
_register_implementation(prims.shape, shape, checker=_always_executable)


shallow_copy = ex.register_operator("shallow_copy", meta=prims.shallow_copy, fn=lambda x: x)
_register_implementation(prims.shallow_copy, shallow_copy, checker=_always_executable)

Expand Down
4 changes: 2 additions & 2 deletions thunder/tests/test_examine_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,5 +115,5 @@ def test_nanogpt_block():
# We are checking the estimated memory against a fixed value for consistency.
assert max_mem_fw[0] == 381754368
assert sum(max_mem_fw[1].values()) == 375462912
assert max_mem_bw[0] == 437292032
assert sum(max_mem_bw[1].values()) == 34642944
assert max_mem_bw[0] == 641097728
assert sum(max_mem_bw[1].values()) == 440474624
2 changes: 1 addition & 1 deletion thunder/tests/test_torch_compile_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def test_torch_compile_cat_rope_single_fusion():

backward_execution_trace = thunder.last_backward_traces(jfn)[-1]
assert len(get_fusions(backward_execution_trace)) == 1
assert len(backward_execution_trace.bound_symbols) == 14
assert len(backward_execution_trace.bound_symbols) == 17


@pytest.mark.skipif(not is_inductor_supported() or platform.system() == "Windows", reason="inductor unsupported")
Expand Down
30 changes: 27 additions & 3 deletions thunder/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,9 @@ def shallow_copy_if_input(p):
trace, _ = AugmentedForwardProcessor(trace)()
# run through DCE in case some of the gradients of intermediates are not needed.
trace = thunder.core.transform_common.dce(trace)
# group get_grad symbols together for torch compile fusions
# !!! is it preferrable to do this here or in the torch compile fusion pass?
_group_get_grad_bsyms(trace)

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
Expand All @@ -349,6 +352,19 @@ def shallow_copy_if_input(p):
return trace


def _group_get_grad_bsyms(trace):
i = 0
n = len(trace.bound_symbols)
while i < n and trace.bound_symbols[i].sym != prims.get_grad:
i += 1
if i == n:
return
get_grad_bsyms = list(filter(lambda bsym: bsym.sym == prims.get_grad, trace.bound_symbols))
bsyms = list(filter(lambda bsym: bsym.sym != prims.get_grad, trace.bound_symbols))
bsyms = bsyms[:i] + list(get_grad_bsyms) + bsyms[i:]
trace.bound_symbols = bsyms


def split_into_forward_and_backward(joint_trace):
"""split a joint trace for forward and backward into separate ones, including recomputation (aka activation checkpointing)"""

Expand Down Expand Up @@ -376,7 +392,10 @@ def split_into_forward_and_backward(joint_trace):
assert isinstance(fw_output, tuple)

grad_outs = [None for _ in fw_output]
output_pos = {o.name: i for i, o in enumerate(fw_output) if isinstance(o, thunder.TensorProxy)}
output_pos = {}
for i, o in enumerate(fw_output):
if isinstance(o, thunder.TensorProxy):
output_pos.setdefault(o.name, []).append(i)

# the proxies we need to compute in the forward - we start with the outputs of the forward
forward_proxy_names = {o.name for o in thunder.core.pytree.tree_iter(fw_output) if isinstance(o, thunder.Proxy)}
Expand Down Expand Up @@ -412,11 +431,11 @@ def split_into_forward_and_backward(joint_trace):

# get grad is always part of the input, record the grad_out (will be part of the "cotangents" list)
if bsym.sym == prims.get_grad:
grad_outs[output_pos[bsym.args[0].name]] = bsym.output
grad_outs[output_pos[bsym.args[0].name].pop(0)] = bsym.output
continue

# copy_ updating a forward proxy is special regardless of the output
if bsym.sym == prims.copy_ and bsym.args[1].name in forward_proxy_names:
if (bsym.sym == prims.copy_ or bsym.sym.name == "copy_") and bsym.args[1].name in forward_proxy_names:
# todo: should we also handle ltorch.copy_ ?
forward_part_bsyms.insert(0, bsym.from_bsym())
forward_proxy_names.update(a.name for a in bsym.flat_proxy_args)
Expand Down Expand Up @@ -467,6 +486,11 @@ def split_into_forward_and_backward(joint_trace):
with thunder.core.trace.tracectx(forward_trace):
prims.python_return(fw_output_dict, (saved_for_backward_tensors, saved_for_backward_other))

if len(backward_part_bsyms) == 0 and not any(
[True if arg is not None else False for arg in return_bsym.args[0]["grad_flat_args"]]
):
return forward_trace, None

# then we construct the backward trace, unpacking saved_for_backward and cotangents lists
def backward_fn(saved_for_backward, cotangents):
pass
Expand Down
Loading