Skip to content

Better __getitem__ & __setitem__#26

Merged
samuela merged 1 commit intosamuela:mainfrom
mavenlin:get_set_item
Aug 22, 2025
Merged

Better __getitem__ & __setitem__#26
samuela merged 1 commit intosamuela:mainfrom
mavenlin:get_set_item

Conversation

@mavenlin
Copy link
Copy Markdown
Contributor

  • Existing implementation doesn't perform conversion on the key.
  • Added various tests to make sure it is working as expected.
  • In jax, jitting only works when the shape of __getitem__ can be inferred.
    Meaning that key has to be int or indice to be jittable.
    slice with start & end being a traced array would work only when it is not jitted.
    This behavior is consistent with jax, so we just turn off the jit test for dynamic slice object.

@implements(torch.tensor)
def tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False):
assert not requires_grad
return jnp.array(
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should preserve None as it is so that jax will infer the type from the data itself instead of torch.get_default_dtype(). The latter will make torch.tensor(1) a float tensor instead of integer.

Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah good catch! do we have a test for this?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added them

Copy link
Copy Markdown
Owner

@samuela samuela left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks @mavenlin ! and thanks for splitting up into smaller PRs, much faster to review this way. overall i think this looks good. would you mind also rebasing to one or a few commits on top of main? currently there's >30 commits that appear to be from #19

@implements(torch.tensor)
def tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False):
assert not requires_grad
return jnp.array(
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah good catch! do we have a test for this?

return jnp.array(
data.value if isinstance(data, Torchish) else data, dtype=t2j_dtype(dtype or torch.get_default_dtype())
)
return jnp.array(data.value if isinstance(data, Torchish) else data, dtype=t2j_dtype(dtype))
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

data.value if isinstance(data, Torchish) else data

hmm should we use _coerce here? pro: it lines up with the spirit of what's happening here. con: it doesn't currently support lists

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll keep as it is for now, for coerce, maybe we can consider passing everything that we don't recognize as it is?

@mavenlin mavenlin force-pushed the get_set_item branch 4 times, most recently from be1ca2d to 55f0931 Compare August 22, 2025 14:38
@samuela samuela merged commit aa844db into samuela:main Aug 22, 2025
1 check passed
@samuela
Copy link
Copy Markdown
Owner

samuela commented Aug 22, 2025

Thanks @mavenlin !

@mavenlin mavenlin deleted the get_set_item branch August 23, 2025 14:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants