Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
update typing to fit the rest of thunder codebase
  • Loading branch information
riccardofelluga committed May 27, 2025
commit 094a9e76c68fae2969a08d85941b9453801ca587
97 changes: 47 additions & 50 deletions thunder/transforms/autodiff.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import thunder.core.transforms
from thunder.core.transforms import ForwardBackwardTraces
import time

from thunder.core import prims, utils

from thunder.core.pytree import tree_map, tree_iter, tree_flatten_with_dataclass
from thunder.core.proxies import TensorProxy, ProxyTag, Proxy, CollectionProxy, variableify
from thunder.core.symbol import BoundSymbol, BoundSymbolTag
from thunder.core.trace import TraceProvenance, tracectx, TraceCtx, from_trace, TraceTag
from thunder.core.trace_interpreter import TraceSubstitutionProcessor
from thunder.core.transforms import (
is_constant_for_vjp,
_get_gradfn_and_executor,
augmented_forward_impls,
backward_impls,
recompute_saved_for_backward,
ForwardBackwardTraces,
)
from thunder.core.proxies import ProxyTag
from thunder.core.symbol import BoundSymbol, BoundSymbolTag
from thunder.core.vjp_utils import make_aug_forward_and_backward
from thunder.core.pytree import tree_map
import thunder
import time
from thunder.core.transform_common import dce
import thunder.torch as ltorch


def _should_recompute_bsym_in_backward(bsym):
Expand All @@ -35,12 +36,12 @@ def grad_transform_on_trace(trace, /, *args, **kwargs):
# - if neither of the above apply, and the symbol has subsymbols, push the decomposition
# to the front of the queue
# - if none of the above apply and we have a prim, raise an error
class AugmentedForwardProcessor(thunder.core.trace_interpreter.TraceSubstitutionProcessor):
class AugmentedForwardProcessor(TraceSubstitutionProcessor):
def __init__(self, trace):
super().__init__(trace)
self.collected_bw_part_bsyms = []

def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
def process_bsym(self, bsym: BoundSymbol) -> None:
if bsym.sym is prims.python_return:
# BEGINNING of return handling (and putting the backward computation in the joint trace)
# This is big (and a bit messy):
Expand All @@ -56,8 +57,8 @@ def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
#
input_proxy_names = {p.name for p in bsym.args[0]["flat_args"] if isinstance(p, thunder.Proxy)}
output_proxy_names = set()
for o in thunder.core.pytree.tree_iter(bsym.args[0]["output"]):
if isinstance(o, thunder.Proxy):
for o in tree_iter(bsym.args[0]["output"]):
if isinstance(o, Proxy):
output_proxy_names.add(self.read(o).name)
grad_proxy_map = {}

Expand Down Expand Up @@ -141,7 +142,7 @@ def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
grad_flat_args = []
for p in bsym.args[0]["flat_args"]:
# or p = self.read(p) here?
if isinstance(p, thunder.TensorProxy) and p.requires_grad and p.name in grad_proxy_map:
if isinstance(p, TensorProxy) and p.requires_grad and p.name in grad_proxy_map:
# is it always OK if we don't have a gradient? (one case: unused input)
# result of put_grad???
grad_flat_args.append(grad_proxy_map[p.name])
Expand All @@ -164,8 +165,8 @@ def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
self.set_result(bsym.output)
return

# 2. Special case the thunder.torch.checkpoint higher order function
if bsym.sym == thunder.torch.checkpoint:
# 2. Special case the ltorch.checkpoint higher order function
if bsym.sym == ltorch.checkpoint:
# Tag all intermediate outputs as to be recomputed.
function_arg_names = {a.name for a in bsym.flat_proxy_args}

Expand All @@ -190,15 +191,15 @@ def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
# this is a bit of a hack in order to only replace the output,
# not the input
(a,) = bsym.args
a_inp = self.swap_map.get(thunder.core.proxies.variableify(a), a)
with thunder.core.trace.tracectx(self.new_trace):
a_inp = self.swap_map.get(variableify(a), a)
with tracectx(self.new_trace):
o = prims.shallow_copy(a_inp)
self.add_to_swap_map(a, o)
self.add_to_swap_map(a_inp, o)
self.write(a_inp, o)

self.new_trace.push_scope([])
with thunder.core.trace.tracectx(self.new_trace):
with tracectx(self.new_trace):
prims.put_grad(a_inp, prims.get_grad(o))
backward_part_bsyms = self.new_trace.pop_scope()
self.collected_bw_part_bsyms.insert(0, backward_part_bsyms)
Expand Down Expand Up @@ -230,34 +231,34 @@ def joint_forward_backward(*args, **kwargs):
# we need to shallow copy inputs that are returned for "get_grad" and "put_grad" to properly work
# (this shallow copy is the equivalent of an "edge" in the PyTorch autograd graph)
def shallow_copy_if_input(p):
if isinstance(p, thunder.TensorProxy) and p.name in arg_proxy_names:
return thunder.core.prims.shallow_copy(p)
if isinstance(p, TensorProxy) and p.name in arg_proxy_names:
return prims.shallow_copy(p)
return p

res = tree_map(shallow_copy_if_input, res)

# now we need the backward. it starts by getting the grad_outs
grad_outs = []
for r in thunder.core.pytree.tree_iter(res):
if isinstance(r, thunder.TensorProxy):
for r in tree_iter(res):
if isinstance(r, TensorProxy):
grad_outs.append(prims.get_grad(r))

# The backward computes the grad_inps of the bsym from the grad_outs
# TODO: non-grad outputs of bwd?
grad_inps = bwd_impl(*saved_for_backward, *grad_outs)
if isinstance(grad_inps, thunder.Proxy):
if isinstance(grad_inps, Proxy):
grad_inps = [grad_inps]

# match the grad_inps to the inputs of the bound symbol and put the grads

flat_inps = args
# for autograd_function_apply, skip the function args
# TODO: fix the returned gradients to include two None values?
if bsym.sym == thunder.torch.autograd_function_apply:
if bsym.sym == ltorch.autograd_function_apply:
flat_inps = args[2:]

# there may be non-gradient requiring additional args (todo: maybe only support this for non-tensor ones?)
num_flat_tensor_inps = sum(isinstance(i, thunder.TensorProxy) for i in flat_inps)
num_flat_tensor_inps = sum(isinstance(i, TensorProxy) for i in flat_inps)
utils.check(
num_flat_tensor_inps <= len(grad_inps),
lambda: f"Backward for {bsym.sym.id} returned {len(grad_inps)} value(s), but expected {num_flat_tensor_inps}",
Expand All @@ -266,7 +267,7 @@ def shallow_copy_if_input(p):
assert len(grad_inps) <= len(flat_inps)
for i, gi in zip(flat_inps, grad_inps):
# for integer proxies etc. we expect gi to be None
if isinstance(i, thunder.TensorProxy) and gi is not None:
if isinstance(i, TensorProxy) and gi is not None:
prims.put_grad(i, gi)
return res

Expand All @@ -293,9 +294,7 @@ def shallow_copy_if_input(p):
nbsym.tags |= bsym.tags

# simple splitting: only compute in forward what is needed for the output
forward_part_proxy_names = {
o.name for o in thunder.core.pytree.tree_iter(result) if isinstance(o, thunder.Proxy)
}
forward_part_proxy_names = {o.name for o in tree_iter(result) if isinstance(o, Proxy)}
forward_part_bsyms = []
backward_part_bsyms = []
for nbsym in reversed(new_bsyms):
Expand Down Expand Up @@ -338,18 +337,16 @@ def shallow_copy_if_input(p):
# run the trace through the processor
trace, _ = AugmentedForwardProcessor(trace)()
# run through DCE in case some of the gradients of intermediates are not needed.
trace = thunder.core.transform_common.dce(trace)
trace = dce(trace)

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
elapsed_time_millis = elapsed_time_ns // 1000000
trace.set_provenance(
thunder.core.trace.TraceProvenance(f"Grad transform pass (took {elapsed_time_millis} milliseconds)")
)
trace.set_provenance(TraceProvenance(f"Grad transform pass (took {elapsed_time_millis} milliseconds)"))
return trace


def split_into_forward_and_backward(joint_trace):
def split_into_forward_and_backward(joint_trace: TraceCtx):
"""split a joint trace for forward and backward into separate ones, including recomputation (aka activation checkpointing)"""

# the joint trace will have the forward computation at the beginning and then the backward computation
Expand All @@ -376,10 +373,10 @@ def split_into_forward_and_backward(joint_trace):
assert isinstance(fw_output, tuple)

grad_outs = [None for _ in fw_output]
output_pos = {o.name: i for i, o in enumerate(fw_output) if isinstance(o, thunder.TensorProxy)}
output_pos = {o.name: i for i, o in enumerate(fw_output) if isinstance(o, TensorProxy)}

# the proxies we need to compute in the forward - we start with the outputs of the forward
forward_proxy_names = {o.name for o in thunder.core.pytree.tree_iter(fw_output) if isinstance(o, thunder.Proxy)}
forward_proxy_names = {o.name for o in tree_iter(fw_output) if isinstance(o, Proxy)}
# we also have the inputs available, so we add flat_args.
# for inplace, we need to update this (or have flat args be the right thing?...)
forward_proxy_names.update(a.name for a in return_bsym.args[0]["flat_args"] if isinstance(a, thunder.Proxy))
Expand Down Expand Up @@ -451,8 +448,8 @@ def split_into_forward_and_backward(joint_trace):
for a in bsym.flat_proxy_args
if a.name in forward_proxy_names and a.name not in backward_recomputed_proxy_names
)
saved_for_backward_tensors = [p for p in saved_for_backward.values() if isinstance(p, thunder.TensorProxy)]
saved_for_backward_other = [p for p in saved_for_backward.values() if not isinstance(p, thunder.TensorProxy)]
saved_for_backward_tensors = [p for p in saved_for_backward.values() if isinstance(p, TensorProxy)]
saved_for_backward_other = [p for p in saved_for_backward.values() if not isinstance(p, TensorProxy)]

# we build the forward trace
forward_trace = thunder.core.trace.from_trace(joint_trace)
Expand All @@ -461,10 +458,9 @@ def split_into_forward_and_backward(joint_trace):
forward_trace.bound_symbols += forward_part_bsyms

# now we create the return value and return bound symbol for the forward
fw_output_dict = {k: v for k, v in return_bsym.args[0].items() if k != "grad_flat_args"}
flat_output, _ = thunder.core.pytree.tree_flatten_with_dataclass(fw_output)
flat_output, _ = tree_flatten_with_dataclass(fw_output)
fw_output_dict["flat_output"] = tuple(flat_output)
with thunder.core.trace.tracectx(forward_trace):
with tracectx(forward_trace):
prims.python_return(fw_output_dict, (saved_for_backward_tensors, saved_for_backward_other))

# then we construct the backward trace, unpacking saved_for_backward and cotangents lists
Expand All @@ -480,11 +476,11 @@ def backward_fn(saved_for_backward, cotangents):
backward_trace.names.discard("cotangents")

# set up the inputs of the backward properly (args and unpacking)
with thunder.core.trace.tracectx(backward_trace):
p_C0 = thunder.core.proxies.CollectionProxy(list(saved_for_backward_tensors), name="C0")
p_C1 = thunder.core.proxies.CollectionProxy(list(saved_for_backward_other), name="C1")
p_saved_for_backward = thunder.core.proxies.CollectionProxy([p_C0, p_C1], name="saved_for_backward")
p_cotangents = thunder.core.proxies.CollectionProxy(grad_outs, name="cotangents")
with tracectx(backward_trace):
p_C0 = CollectionProxy(list(saved_for_backward_tensors), name="C0")
p_C1 = CollectionProxy(list(saved_for_backward_other), name="C1")
p_saved_for_backward = CollectionProxy([p_C0, p_C1], name="saved_for_backward")
p_cotangents = CollectionProxy(grad_outs, name="cotangents")

# set the args (which currently don't use the collection proxies but the collections directly)
saved_for_backward_tuple = [p_C0.collection(), p_C1.collection()]
Expand All @@ -502,15 +498,16 @@ def backward_fn(saved_for_backward, cotangents):
backward_trace.bound_symbols += backward_part_bsyms

# and finally the backward return statement
with thunder.core.trace.tracectx(backward_trace):
with tracectx(backward_trace):
prims.python_return(tuple(return_bsym.args[0]["grad_flat_args"]))

return forward_trace, backward_trace


def forward_and_backward_from_trace(trace: thunder.core.trace.TraceCtx, torch_autograd=False) -> ForwardBackwardTraces:
def forward_and_backward_from_trace(trace: TraceCtx, torch_autograd=False) -> ForwardBackwardTraces:
if not torch_autograd:
return thunder.core.transforms.forward_and_backward_from_trace(trace, torch_autograd=torch_autograd)
from thunder.core.transforms import forward_and_backward_from_trace as legacy_autograd
return legacy_autograd(trace, torch_autograd=torch_autograd)
joint_trace = grad_transform_on_trace(trace)

forward_trace, backward_trace = split_into_forward_and_backward(joint_trace)
Expand Down