Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
alias fwd trace fx path earlier
  • Loading branch information
riccardofelluga committed May 27, 2025
commit 78ff43ca167b45787821042518ac2639792b5b85
20 changes: 10 additions & 10 deletions thunder/core/jit_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,25 +940,25 @@ def _generate_random_str_id() -> str:
aug_fwd_result = aug_fwd_trace.output
output, saved_values = unwrap(aug_fwd_result)

trace_of_forward = from_trace(aug_fwd_trace)
from thunder.core.update_aliases import insert_alias_updates

alias_tensor_indices = [[i] for i in range(len(aug_fwd_trace.args))]
aliased_aug_fwd_trace = insert_alias_updates(aug_fwd_trace, alias_tensor_indices)

trace_of_forward = from_trace(aliased_aug_fwd_trace)
for bsym in aug_fwd_trace.bound_symbols:
if bsym.sym.id == prims.PrimIDs.RETURN:
continue
trace_of_forward.bound_symbols.append(bsym.from_bsym())
with tracectx(trace_of_forward):
prims.python_return(*(sequencify(output)))

from thunder.core.update_aliases import insert_alias_updates

alias_tensor_indices = [[i] for i in range(len(trace_of_forward.args))]
aliased_trace_of_forward = insert_alias_updates(trace_of_forward, alias_tensor_indices)

# See NOTE: `autograd_function_apply` and `no_grad` interaction for details about
# `thunder.torch.call_higher_order_function_and_consider_outer_autograd_setting`
@wraps(aug_fwd_trace.python_callable())
@wraps(aliased_aug_fwd_trace.python_callable())
@thunder.torch.call_higher_order_function_and_consider_outer_autograd_setting
def forward(*args, **kwargs):
return interpret_trace(aliased_trace_of_forward, *args, **kwargs)
return interpret_trace(trace_of_forward, *args, **kwargs)

grads = sequencify(tree_map(lambda t: TensorProxy(like=t), sequencify(output)))
bwd_tensor_args = grads + tuple(saved_values)
Expand Down Expand Up @@ -988,15 +988,15 @@ def forward(*args, **kwargs):
def grad_transform(*args, **kwargs):
from thunder.core.transforms import get_grad, put_grads

primal, residuals = interpret_trace(aug_fwd_trace, *args, **kwargs)
primal, residuals = interpret_trace(aliased_aug_fwd_trace, *args, **kwargs)
grads = tree_map(lambda t: get_grad(t), sequencify(primal))
bwd_args = (None,) + tuple(grads) + tuple(sequencify(residuals))
result = interpret_trace(aliased_bwd_trace, *bwd_args)
put_grads(args[1:], result)

return primal

forward_op = get_jit_ctx().ad_hoc_executor.register_operator(aliased_trace_of_forward._siginfo.name, like=forward)
forward_op = get_jit_ctx().ad_hoc_executor.register_operator(trace_of_forward._siginfo.name, like=forward)
unwrapped_output = forward_op(*unwrapped_fwd_args)
output = wrap(
unwrapped_output, provenance=ProvenanceRecord(PseudoInst.LOOKASIDE, inputs=[fwd.provenance, aug_fwd_provenance])
Expand Down