Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add insert alias tests
  • Loading branch information
riccardofelluga committed May 27, 2025
commit 9b9068a1100f6891261145de814f6120d9956e21
52 changes: 52 additions & 0 deletions thunder/tests/test_update_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,3 +370,55 @@ def f(a):
out = jitted_f(a)

torch.testing.assert_close(out, out_expected)


@instantiate(
dtypes=(dtypes.float32,),
)
def test_higher_order_inplace_alias_update(executor, device, dtype):
torch_dtype = dtypes.to_torch_dtype(dtype)

class Sin(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
ctx.save_for_backward(x)
y = x * 1
y.sin_()
return y

@staticmethod
def backward(ctx, g):
(x,) = ctx.saved_tensors
y = g * x
y.cos_()
return y

def foo(x):
return Sin.apply(x)

a = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True)
b = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True)
c = torch.ones(2, device=device, dtype=torch_dtype, requires_grad=True)

g = torch.rand_like(a)

jfoo = thunder.jit(foo, executors=executor.executors_list())
actual_jit = jfoo(a)

from thunder.dynamo import thunderfx

cfoo = thunderfx(foo, executors=executor.executors_list())
actual_fx = cfoo(b)

expected = foo(c)

# Checking both paths in jit_ext.py
torch.testing.assert_close(actual_jit, expected)
torch.testing.assert_close(actual_fx, expected)

actual_grad_jit = torch.autograd.grad(actual_jit, a, g)
actual_grad_fx = torch.autograd.grad(actual_fx, b, g)

expected_grad = torch.autograd.grad(expected, c, g)
torch.testing.assert_close(actual_grad_fx, expected_grad)
torch.testing.assert_close(actual_grad_jit, expected_grad)
Loading