-
Notifications
You must be signed in to change notification settings - Fork 17
Better __getitem__ & __setitem__ #26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
- Loading branch information
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,8 +6,14 @@ | |
| import jax | ||
| import jax.dlpack | ||
| import jax.numpy as jnp | ||
| import numpy as np | ||
| import torch | ||
| from torch.overrides import TorchFunctionMode, resolve_name | ||
| from torch.utils._pytree import register_pytree_node as torch_register_pytree_node | ||
| from torch.utils._pytree import tree_map as torch_tree_map | ||
|
|
||
| # so that __getitem__ & __setitem__ with mixed keys of int / tensor could work | ||
| torch_register_pytree_node(slice, lambda s: ((s.start, s.stop, s.step), None), lambda values, ctx: slice(*values)) | ||
|
|
||
|
|
||
| class RngPooper: | ||
|
|
@@ -130,7 +136,8 @@ def expand(self, *sizes): | |
|
|
||
| # fmt: off | ||
| def __add__(self, other): return Torchish(self.value + _coerce(other)) | ||
| def __getitem__(self, key): return Torchish(self.value.__getitem__(key)) | ||
| def __getitem__(self, key): return Torchish(self.value.__getitem__(torch_tree_map(_coerce, key))) | ||
| def __int__(self): return int(self.value) | ||
| def __lt__(self, other): return Torchish(self.value < _coerce(other)) | ||
| def __le__(self, other): return Torchish(self.value <= _coerce(other)) | ||
| def __eq__(self, other): return Torchish(self.value == _coerce(other)) | ||
|
|
@@ -145,6 +152,8 @@ def __radd__(self, other): return Torchish(_coerce(other) + self.value) | |
| def __rmatmul__(self, other): return Torchish(_coerce(other) @ self.value) | ||
| def __rmul__(self, other): return Torchish(_coerce(other) * self.value) | ||
| def __rsub__(self, other): return Torchish(_coerce(other) - self.value) | ||
| def __setitem__(self, key, value): | ||
| self.value = self.value.at[torch_tree_map(_coerce, key)].set(_coerce(value)) | ||
| def __sub__(self, other): return Torchish(self.value - _coerce(other)) | ||
|
|
||
| # For some reason `foo = torch.foo` doesn't work on these | ||
|
|
@@ -190,8 +199,12 @@ def _coerce(x): | |
| JAX-compatible.""" | ||
| if isinstance(x, Torchish): | ||
| return x.value | ||
| elif isinstance(x, int) or isinstance(x, float): | ||
| elif isinstance(x, (int, float, np.ndarray, jnp.ndarray)): # jax compatible types | ||
| return x | ||
| elif x is None: | ||
| return None | ||
| elif isinstance(x, torch.dtype): | ||
| return t2j_dtype(x) | ||
| else: | ||
| raise NotImplementedError( | ||
| f"Attempted to _coerce with {x}, type {type(x)}. Don't know what to do with that. _coerce supports int, float, and Torchish values." | ||
|
|
@@ -523,7 +536,7 @@ def sum(input, dim=None, keepdim=False, dtype=None): | |
| def tensor(data, dtype=None, device=None, requires_grad=False, pin_memory=False): | ||
| assert not requires_grad | ||
| return jnp.array( | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We should preserve
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah good catch! do we have a test for this?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added them |
||
| data.value if isinstance(data, Torchish) else data, dtype=t2j_dtype(dtype or torch.get_default_dtype()) | ||
| data.value if isinstance(data, Torchish) else data, dtype=(t2j_dtype(dtype) if dtype is not None else None) | ||
| ) | ||
|
|
||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.