update typing to conform to the rest of thunder code base
riccardofelluga committed May 28, 2025
commit 0493ed85e7c2e9690e485f3cf250b718b8771af1
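The commit is a mechanical style change: fully qualified references such as thunder.core.proxies.variableify, thunder.core.trace.tracectx, and thunder.torch.add are replaced by module-level imports (variableify, tracectx, ltorch.add), matching the import conventions used elsewhere in the thunder code base. A minimal before/after sketch of the pattern, separate from the diff itself; the helper names (is_tensor_old, is_tensor_new, add_old, add_new) are illustrative only:

# Illustrative sketch (not part of the commit): the import convention being adopted.
# Before — fully qualified attribute access at each use site:
import thunder
import thunder.torch

def is_tensor_old(p):
    return isinstance(p, thunder.TensorProxy)

def add_old(a, b):
    return thunder.torch.add(a, b)

# After — names imported once at module level, as in the rest of the code base:
from thunder.core.proxies import TensorProxy
import thunder.torch as ltorch  # the "ltorch" alias used throughout thunder

def is_tensor_new(p):
    return isinstance(p, TensorProxy)

def add_new(a, b):
    return ltorch.add(a, b)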
107 changes: 51 additions & 56 deletions thunder/transforms/autodiff.py
@@ -1,24 +1,21 @@
import thunder.core.proxies
import thunder.core.pytree
import thunder.core.transform_common
import thunder.core.transforms
from thunder.core.transforms import ForwardBackwardTraces
import time

from thunder.core import prims, utils

from thunder.core.pytree import tree_map, tree_iter, tree_flatten_with_dataclass
from thunder.core.proxies import TensorProxy, ProxyTag, Proxy, CollectionProxy, variableify
from thunder.core.symbol import BoundSymbol, BoundSymbolTag
from thunder.core.trace import TraceProvenance, tracectx, TraceCtx, from_trace, TraceTag
from thunder.core.trace_interpreter import TraceSubstitutionProcessor
from thunder.core.transforms import (
is_constant_for_vjp,
_get_gradfn_and_executor,
augmented_forward_impls,
backward_impls,
recompute_saved_for_backward,
ForwardBackwardTraces,
)
from thunder.core.proxies import ProxyTag
from thunder.core.symbol import BoundSymbol, BoundSymbolTag
import thunder.core.utils
from thunder.core.vjp_utils import make_aug_forward_and_backward
from thunder.core.pytree import tree_map
import thunder
import time
from thunder.core.transform_common import dce
import thunder.torch as ltorch


def _should_recompute_bsym_in_backward(bsym):
@@ -39,12 +36,12 @@ def grad_transform_on_trace(trace, /, *args, **kwargs):
# - if neither of the above apply, and the symbol has subsymbols, push the decomposition
# to the front of the queue
# - if none of the above apply and we have a prim, raise an error
class AugmentedForwardProcessor(thunder.core.trace_interpreter.TraceSubstitutionProcessor):
class AugmentedForwardProcessor(TraceSubstitutionProcessor):
def __init__(self, trace):
super().__init__(trace)
self.collected_bw_part_bsyms = []

def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
def process_bsym(self, bsym: BoundSymbol) -> None:
if bsym.sym is prims.python_return:
# BEGINNING of return handling (and putting the backward computation in the joint trace)
# This is big (and a bit messy):
@@ -59,8 +56,8 @@ def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
# backward computation.

output_proxy_names = set()
for o in thunder.core.pytree.tree_iter(bsym.args[0]["output"]):
if isinstance(o, thunder.Proxy):
for o in tree_iter(bsym.args[0]["output"]):
if isinstance(o, Proxy):
output_proxy_names.add(self.read(o).name)
grad_proxy_map = {}

@@ -94,7 +91,7 @@ def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
self.add_processed_bsyms(self.new_trace.pop_scope())

if current_grad is not None:
new_grad = self.add_bsyms_from_function(thunder.torch.add, current_grad, new_grad)
new_grad = self.add_bsyms_from_function(ltorch.add, current_grad, new_grad)

grad_proxy_map[p.name] = new_grad
self.write(new_grad, new_grad)
@@ -121,15 +118,15 @@ def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
# do we also need to map?
self.write(nbsym.output, current_grad)
# replace_output_with_current_grad
self.swap_map[thunder.core.proxies.variableify(nbsym.output)] = current_grad
self.swap_map[variableify(nbsym.output)] = current_grad
elif name in output_proxy_names:
# output here???
new_bsym = nbsym.from_bsym()
self.add_processed_bsyms([new_bsym])
grad_proxy_map[name] = new_bsym.output
else:
# TODO: mark this? if all inputs to the backward formula are unused, we would not want to compute it.
new_bsym = thunder.torch.zeros.bind(
new_bsym = ltorch.zeros.bind(
*p.shape, device=p.device, dtype=p.dtype, output=nbsym.output
)
self.add_processed_bsyms([new_bsym])
@@ -145,7 +142,7 @@ def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
grad_flat_args = []
for p in bsym.args[0]["flat_args"]:
# or p = self.read(p) here?
if isinstance(p, thunder.TensorProxy) and p.requires_grad and p.name in grad_proxy_map:
if isinstance(p, TensorProxy) and p.requires_grad and p.name in grad_proxy_map:
# is it always OK if we don't have a gradient? (one case: unused input)
# result of put_grad???
grad_flat_args.append(grad_proxy_map[p.name])
@@ -177,8 +174,8 @@ def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
self.set_result(bsym.output)
return

# 2. Special case the thunder.torch.checkpoint higher order function
if bsym.sym == thunder.torch.checkpoint:
# 2. Special case the ltorch.checkpoint higher order function
if bsym.sym == ltorch.checkpoint:
# Tag all intermediate outputs as to be recomputed.
function_arg_names = {a.name for a in bsym.flat_proxy_args}

Expand All @@ -203,15 +200,15 @@ def process_bsym(self, bsym: thunder.core.symbol.BoundSymbol) -> None:
# this is a bit of a hack in order to only replace the output,
# not the input
(a,) = bsym.args
a_inp = self.swap_map.get(thunder.core.proxies.variableify(a), a)
with thunder.core.trace.tracectx(self.new_trace):
a_inp = self.swap_map.get(variableify(a), a)
with tracectx(self.new_trace):
o = prims.shallow_copy(a_inp)
self.add_to_swap_map(a, o)
self.add_to_swap_map(a_inp, o)
self.write(a_inp, o)

self.new_trace.push_scope([])
with thunder.core.trace.tracectx(self.new_trace):
with tracectx(self.new_trace):
prims.put_grad(a_inp, prims.get_grad(o))
backward_part_bsyms = self.new_trace.pop_scope()
self.collected_bw_part_bsyms.insert(0, backward_part_bsyms)
@@ -243,34 +240,34 @@ def joint_forward_backward(*args, **kwargs):
# we need to shallow copy inputs that are returned for "get_grad" and "put_grad" to properly work
# (this shallow copy is the equivalent of an "edge" in the PyTorch autograd graph)
def shallow_copy_if_input(p):
if isinstance(p, thunder.TensorProxy) and p.name in arg_proxy_names:
return thunder.core.prims.shallow_copy(p)
if isinstance(p, TensorProxy) and p.name in arg_proxy_names:
return prims.shallow_copy(p)
return p

res = tree_map(shallow_copy_if_input, res)

# now we need the backward. it starts by getting the grad_outs
grad_outs = []
for r in thunder.core.pytree.tree_iter(res):
if isinstance(r, thunder.TensorProxy):
for r in tree_iter(res):
if isinstance(r, TensorProxy):
grad_outs.append(prims.get_grad(r))

# The backward computes the grad_inps of the bsym from the grad_outs
# TODO: non-grad outputs of bwd?
grad_inps = bwd_impl(*saved_for_backward, *grad_outs)
if isinstance(grad_inps, thunder.Proxy):
if isinstance(grad_inps, Proxy):
grad_inps = [grad_inps]

# match the grad_inps to the inputs of the bound symbol and put the grads

flat_inps = args
# for autograd_function_apply, skip the function args
# TODO: fix the returned gradients to include two None?.
if bsym.sym == thunder.torch.autograd_function_apply:
if bsym.sym == ltorch.autograd_function_apply:
flat_inps = args[2:]

# there may be non-gradient requiring additional args (todo: maybe only support this for non-tensor ones?)
num_flat_tensor_inps = sum(isinstance(i, thunder.TensorProxy) for i in flat_inps)
num_flat_tensor_inps = sum(isinstance(i, TensorProxy) for i in flat_inps)
utils.check(
num_flat_tensor_inps <= len(grad_inps),
lambda: f"Backward for {bsym.sym.id} returned {len(grad_inps)} value(s), but expected {num_flat_tensor_inps}",
@@ -279,7 +276,7 @@ def shallow_copy_if_input(p):
assert len(grad_inps) <= len(flat_inps)
for i, gi in zip(flat_inps, grad_inps):
# for integer proxies etc. we expect gi to be None
if isinstance(i, thunder.TensorProxy) and gi is not None:
if isinstance(i, TensorProxy) and gi is not None:
prims.put_grad(i, gi)
return res

@@ -306,9 +303,7 @@ def shallow_copy_if_input(p):
nbsym.tags |= bsym.tags

# simple splitting: only compute in forward what is needed for the output
forward_part_proxy_names = {
o.name for o in thunder.core.pytree.tree_iter(result) if isinstance(o, thunder.Proxy)
}
forward_part_proxy_names = {o.name for o in tree_iter(result) if isinstance(o, Proxy)}
forward_part_bsyms = []
backward_part_bsyms = []
for nbsym in reversed(new_bsyms):
@@ -352,16 +347,16 @@ def shallow_copy_if_input(p):
joint_trace, _ = AugmentedForwardProcessor(trace)()

# run through DCE in case some of the gradients of intermediates are not needed.
joint_trace = thunder.core.transform_common.dce(joint_trace)
joint_trace = dce(joint_trace)

end_time_ns = time.perf_counter_ns()
elapsed_time_ns = end_time_ns - start_time_ns
elapsed_time_millis = elapsed_time_ns // 1000000
joint_trace.set_provenance(thunder.core.trace.TraceProvenance(f"Grad transform pass (took {elapsed_time_millis} milliseconds)"))
joint_trace.set_provenance(TraceProvenance(f"Grad transform pass (took {elapsed_time_millis} milliseconds)"))
return joint_trace


def split_into_forward_and_backward(joint_trace):
def split_into_forward_and_backward(joint_trace: TraceCtx):
"""split a joint trace for forward and backward into separate ones, including recomputation (aka activation checkpointing)"""

# the joint trace will have the forward computation at the beginning and then the backward computation
@@ -388,10 +383,10 @@ def split_into_forward_and_backward(joint_trace):
assert isinstance(fw_output, tuple)

grad_outs = [None for _ in fw_output]
output_pos = {o.name: i for i, o in enumerate(fw_output) if isinstance(o, thunder.TensorProxy)}
output_pos = {o.name: i for i, o in enumerate(fw_output) if isinstance(o, TensorProxy)}

# the proxies we need to compute in the forward - we start with the outputs of the forward
forward_proxy_names = {o.name for o in thunder.core.pytree.tree_iter(fw_output) if isinstance(o, thunder.Proxy)}
forward_proxy_names = {o.name for o in tree_iter(fw_output) if isinstance(o, Proxy)}
# we also have the inputs available, so we add flat_args.
# for inplace, we need to update this (or have flat args be the right thing?...)
forward_proxy_names.update(a.name for a in return_bsym.args[0]["flat_args"] if isinstance(a, thunder.Proxy))
@@ -463,12 +458,12 @@ def split_into_forward_and_backward(joint_trace):
for a in bsym.flat_proxy_args
if a.name in forward_proxy_names and a.name not in backward_recomputed_proxy_names
)
saved_for_backward_tensors = [p for p in saved_for_backward.values() if isinstance(p, thunder.TensorProxy)]
saved_for_backward_other = [p for p in saved_for_backward.values() if not isinstance(p, thunder.TensorProxy)]
saved_for_backward_tensors = [p for p in saved_for_backward.values() if isinstance(p, TensorProxy)]
saved_for_backward_other = [p for p in saved_for_backward.values() if not isinstance(p, TensorProxy)]

# we build the forward trace
forward_trace = thunder.core.trace.from_trace(joint_trace)
forward_trace.tags.add(thunder.core.trace.TraceTag.AUGMENTED_FORWARD)
forward_trace = from_trace(joint_trace)
forward_trace.tags.add(TraceTag.AUGMENTED_FORWARD)
forward_trace.names = forward_trace.names.copy() ## ehem
forward_trace.bound_symbols += forward_part_bsyms

@@ -477,16 +472,16 @@ def shallow_copy_if_input(p):
# replace the backward output with the forward one for the forward trace
fw_output_dict.update({"output": return_bsym.args[0]["fw_flat_out"]})

flat_output, _ = thunder.core.pytree.tree_flatten_with_dataclass(fw_output)
flat_output, _ = tree_flatten_with_dataclass(fw_output)
fw_output_dict["flat_output"] = tuple(flat_output)
with thunder.core.trace.tracectx(forward_trace):
with tracectx(forward_trace):
prims.python_return(fw_output_dict, (saved_for_backward_tensors, saved_for_backward_other))

# then we construct the backward trace, unpacking saved_for_backward and cotangents lists
def backward_fn(saved_for_backward, cotangents):
pass

backward_trace = thunder.core.trace.TraceCtx(fn=backward_fn)
backward_trace = TraceCtx(fn=backward_fn)
backward_trace.names = forward_trace.names
backward_trace.name_ctr = forward_trace.name_ctr

@@ -495,11 +490,11 @@ def backward_fn(saved_for_backward, cotangents):
backward_trace.names.discard("cotangents")

# set up the inputs of the backward properly (args and unpacking)
with thunder.core.trace.tracectx(backward_trace):
p_C0 = thunder.core.proxies.CollectionProxy(list(saved_for_backward_tensors), name="C0")
p_C1 = thunder.core.proxies.CollectionProxy(list(saved_for_backward_other), name="C1")
p_saved_for_backward = thunder.core.proxies.CollectionProxy([p_C0, p_C1], name="saved_for_backward")
p_cotangents = thunder.core.proxies.CollectionProxy(grad_outs, name="cotangents")
with tracectx(backward_trace):
p_C0 = CollectionProxy(list(saved_for_backward_tensors), name="C0")
p_C1 = CollectionProxy(list(saved_for_backward_other), name="C1")
p_saved_for_backward = CollectionProxy([p_C0, p_C1], name="saved_for_backward")
p_cotangents = CollectionProxy(grad_outs, name="cotangents")

# set the args (which currently don't use the collection proxies but the collections directly)
saved_for_backward_tuple = [p_C0.collection(), p_C1.collection()]
@@ -517,13 +512,13 @@ def backward_fn(saved_for_backward, cotangents):
backward_trace.bound_symbols += backward_part_bsyms

# and finally the backward return statement
with thunder.core.trace.tracectx(backward_trace):
with tracectx(backward_trace):
prims.python_return(tuple(return_bsym.args[0]["output"]))

return forward_trace, backward_trace


def forward_and_backward_from_trace(trace: thunder.core.trace.TraceCtx, torch_autograd=False) -> ForwardBackwardTraces:
def forward_and_backward_from_trace(trace: TraceCtx, torch_autograd=False) -> ForwardBackwardTraces:
if not torch_autograd:
from thunder.core.transforms import forward_and_backward_from_trace as legacy_autograd
