Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
better get/set item and tests
  • Loading branch information
mavenlin committed Aug 22, 2025
commit 3adeb5a518cf80c633ce26d85a2ac8c3b9d2b139
54 changes: 54 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import jax.numpy as jnp
import pytest
import torch
Expand Down Expand Up @@ -59,6 +61,10 @@ def test_tensor():
t2j_function_test(lambda: torch.tensor([1, 2, 3]), [], tests=tests)
t2j_function_test(lambda: torch.tensor([[1, 2, 3], [4, 5, 6]]), [], tests=tests)

# test that an integer/boolean input will have the correct dtype
# explicitly cast them to clearly show the intention
t2j_function_test(lambda: torch.tensor(int(1)), [], tests=tests)
t2j_function_test(lambda: torch.tensor(bool(True)), [], tests=tests)
# torch allows calling torch.tensor with a torch.Tensor. This gets a little tricky with Torchish.
t2j_function_test(lambda: torch.tensor(torch.arange(3)), [], tests=tests)

Expand All @@ -85,6 +91,54 @@ def test_unbind():
t2j_function_test(torch.unbind, [(2, 3)], kwargs={"dim": 1}, tests=tests)


def test_get_set_item():
    """Exercise Torchish __getitem__/__setitem__ with int, slice, and tensor keys."""
    # slice(torch.tensor(1), torch.tensor(3)) is not jittable in jax, because the start/end are dynamic,
    # and the shape can't be inferred. But it works when it is not jitted. This behavior is aligned with pure jax code,
    # where dynamic slices only work when not jitted. Therefore we turn off the jit test for those cases.

    # getitem
    tests = [forward_test, backward_test]
    tests_nojit = [partial(forward_test, test_jit=False), backward_test]
    t2j_function_test(lambda x: x[0, 1, 2], [(3, 4, 5)], tests=tests)
    t2j_function_test(lambda x: x[:, 1, 2], [(3, 4, 5)], tests=tests)
    t2j_function_test(lambda x: x[0, :, 2], [(3, 4, 5)], tests=tests)
    t2j_function_test(lambda x: x[0, 1, :], [(3, 4, 5)], tests=tests)
    t2j_function_test(lambda x: x[1:, 1:3, 2], [(3, 4, 5)], tests=tests)
    t2j_function_test(lambda x: x[torch.tensor([1, 2]), :, torch.tensor([2, 3])], [(3, 4, 5)], tests=tests)
    # when the slice object is dynamic, jax will refuse to run with jit.
    t2j_function_test(lambda x: x[1:, torch.tensor(1) : torch.tensor(3), 2], [(3, 4, 5)], tests=tests_nojit)

    # setitem
    def f(x, key, y):
        # NOTE: x[tuple(key)] = y rather than x[*key] = y — star-unpacking inside a
        # subscript is Python 3.11+ only (PEP 646); tuple(key) is identical in effect.
        x[tuple(key)] = y
        return x

    # in torch, __setitem__ is an in-place assignment; it is not differentiable in torch,
    # so backward_test only differentiates w.r.t. the assigned value (argnums=(1,)).
    tests = [forward_test, partial(backward_test, argnums=(1,))]
    tests_nojit = [partial(forward_test, test_jit=False), partial(backward_test, argnums=(1,))]
    t2j_function_test(lambda x, y: f(x, [0, 1, 2], y), [(3, 4, 5), ()], tests=tests)
    t2j_function_test(lambda x, y: f(x, [slice(None), 1, 2], y), [(3, 4, 5), (3,)], tests=tests)
    t2j_function_test(lambda x, y: f(x, [0, slice(None), 2], y), [(3, 4, 5), (4,)], tests=tests)
    t2j_function_test(lambda x, y: f(x, [0, 1, slice(None)], y), [(3, 4, 5), (5,)], tests=tests)
    t2j_function_test(lambda x, y: f(x, [slice(1, None), slice(1, 3), 2], y), [(3, 4, 5), (2, 2)], tests=tests)
    # broadcast of the assigned value against the indexed region
    t2j_function_test(lambda x, y: f(x, [slice(1, None), slice(1, 3), 2], y), [(3, 4, 5), (2, 1)], tests=tests)
    t2j_function_test(
        lambda x, y: f(x, [torch.tensor([1, 2]), slice(None), torch.tensor([2, 3])], y),
        [(3, 4, 5), (2, 4)],
        tests=tests,
    )
    t2j_function_test(
        lambda x, y: f(x, [slice(1, None), slice(torch.tensor(1), torch.tensor(3)), 2], y),
        [(3, 4, 5), (2, 2)],
        tests=tests_nojit,
    )
    t2j_function_test(
        lambda x, y: f(x, [slice(1, None), slice(torch.tensor(1), torch.tensor(3)), 2], y),
        [(3, 4, 5), ()],
        tests=tests_nojit,
    )


def test_oneliners():
f = [forward_test]
fb = f + [backward_test]
Expand Down
5 changes: 3 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,14 @@ def args_generator(shapes, samplers=None, rng=random.PRNGKey(123)):
yield args


def forward_test(f, args, kwargs={}, **assert_kwargs):
def forward_test(f, args, kwargs=None, test_jit=True, **assert_kwargs):
    """Assert that t2j(f) reproduces f's torch output on JAX inputs `args`.

    Converts JAX arrays/dtypes in `args`/`kwargs` to torch, runs f to get the
    reference output, then checks the t2j-converted function against it — both
    eagerly and, when `test_jit` is True, under jax.jit. Pass test_jit=False
    for ops whose shapes are dynamic and therefore not jittable.
    `assert_kwargs` are forwarded to the assertion helper (e.g. tolerances).
    """
    # None default instead of a shared mutable {} default.
    kwargs = {} if kwargs is None else kwargs
    torch_args = jax.tree.map(lambda x: j2t(x.copy()) if isinstance(x, (jnp.ndarray, jnp.dtype)) else x, args)
    torch_kwargs = jax.tree.map(lambda x: j2t(x) if isinstance(x, jnp.dtype) else x, kwargs)
    # NOTE: f_ deliberately ignores per-call kwargs and always closes over
    # torch_kwargs, so the same kwargs are used on both the torch and jax paths.
    f_ = lambda *args, **kwargs: f(*args, **torch_kwargs)
    torch_output = f_(*torch_args)
    aac(t2j(f_)(*args), torch_output, **assert_kwargs)
    if test_jit:
        aac(jit(t2j(f_))(*args), torch_output, **assert_kwargs)


def backward_test(f, args, kwargs={}, argnums=None, **assert_kwargs):
Expand Down
19 changes: 16 additions & 3 deletions torch2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,14 @@
import jax
import jax.dlpack
import jax.numpy as jnp
import numpy as np
import torch
from torch.overrides import TorchFunctionMode, resolve_name
from torch.utils._pytree import register_pytree_node as torch_register_pytree_node
from torch.utils._pytree import tree_map as torch_tree_map

# Register slice as a pytree node so tree_map can recurse into slice objects.
# This lets __getitem__ / __setitem__ coerce mixed int / tensor keys — e.g.
# x[1:, torch.tensor(1):torch.tensor(3)] — element by element.
torch_register_pytree_node(slice, lambda s: ((s.start, s.stop, s.step), None), lambda values, ctx: slice(*values))


class RngPooper:
Expand Down Expand Up @@ -130,7 +136,8 @@ def expand(self, *sizes):

# fmt: off
def __add__(self, other): return Torchish(self.value + _coerce(other))
def __getitem__(self, key): return Torchish(self.value.__getitem__(key))
# tree_map over the key so Torchish tensors nested in tuples/slices are coerced to JAX values before indexing.
def __getitem__(self, key): return Torchish(self.value.__getitem__(torch_tree_map(_coerce, key)))
def __int__(self): return int(self.value)
def __lt__(self, other): return Torchish(self.value < _coerce(other))
def __le__(self, other): return Torchish(self.value <= _coerce(other))
def __eq__(self, other): return Torchish(self.value == _coerce(other))
Expand All @@ -145,6 +152,8 @@ def __radd__(self, other): return Torchish(_coerce(other) + self.value)
def __rmatmul__(self, other): return Torchish(_coerce(other) @ self.value)
def __rmul__(self, other): return Torchish(_coerce(other) * self.value)
def __rsub__(self, other): return Torchish(_coerce(other) - self.value)
# torch's __setitem__ mutates in place; JAX arrays are immutable, so emulate it
# by rebinding self.value to a new array built with JAX's .at[key].set(value).
def __setitem__(self, key, value):
    self.value = self.value.at[torch_tree_map(_coerce, key)].set(_coerce(value))
def __sub__(self, other): return Torchish(self.value - _coerce(other))

# For some reason `foo = torch.foo` doesn't work on these
Expand Down Expand Up @@ -190,8 +199,12 @@ def _coerce(x):
JAX-compatible."""
if isinstance(x, Torchish):
return x.value
elif isinstance(x, int) or isinstance(x, float):
elif isinstance(x, (int, float, np.ndarray, jnp.ndarray)): # jax compatible types
return x
elif x is None:
return None
elif isinstance(x, torch.dtype):
return t2j_dtype(x)
else:
raise NotImplementedError(
f"Attempted to _coerce with {x}, type {type(x)}. Don't know what to do with that. _coerce supports int, float, and Torchish values."
Expand Down Expand Up @@ -523,7 +536,7 @@ def sum(input, dim=None, keepdim=False, dtype=None):
def tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False):
assert not requires_grad
return jnp.array(
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should preserve None as-is so that JAX infers the dtype from the data itself rather than from torch.get_default_dtype(). The latter would make torch.tensor(1) a float tensor instead of an integer one.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah good catch! do we have a test for this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added them

data.value if isinstance(data, Torchish) else data, dtype=t2j_dtype(dtype or torch.get_default_dtype())
data.value if isinstance(data, Torchish) else data, dtype=(t2j_dtype(dtype) if dtype is not None else None)
)


Expand Down