Skip to content
Merged
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions thunder/transforms/autodiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,9 +520,8 @@ def backward_fn(saved_for_backward, cotangents):

def forward_and_backward_from_trace(trace: TraceCtx, torch_autograd=False) -> ForwardBackwardTraces:
if not torch_autograd:
from thunder.core.transforms import forward_and_backward_from_trace as legacy_autograd
return thunder.core.transforms.forward_and_backward_from_trace(trace, torch_autograd=torch_autograd)

return legacy_autograd(trace, torch_autograd=torch_autograd)
joint_trace = grad_transform_on_trace(trace)

forward_trace, backward_trace = split_into_forward_and_backward(joint_trace)
Expand Down
Loading