-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Add jax.tree_util.tree_traverse #19695
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
2bec030 to
0eec4e7
Compare
This provides the same API as tree.traverse from google-deepmind/tree
0eec4e7 to
28dc694
Compare
| The structured output from the traversal. | ||
| """ | ||
| def traverse_children(): | ||
| children, treedef = tree_flatten(structure, is_leaf=lambda x: x is not structure) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better to use itertools.count to handle self-referential cases.
import itertools
counter = itertools.count()
one_level_children, one_level_treedef = tree_flatten(structure, is_leaf=lambda x: next(counter) > 0)In [1]: from jax.tree_util import tree_flatten
In [2]: d = {'a': 1, 'b': 2, 'c': 3}; d['d'] = d
In [3]: tree_flatten(d, is_leaf=lambda x: x is not d)
RecursionError: maximum recursion depth exceeded while calling a Python object
In [4]: import itertools
In [5]: counter = itertools.count()
In [6]: tree_flatten(d, is_leaf=lambda x: next(counter) > 0)
Out[6]: ([1, 2, 3, {'a': 1, 'b': 2, 'c': 3, 'd': ...}], PyTreeDef({'a': *, 'b': *, 'c': *, 'd': *}))There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should handle self-referential cases, since the underlying structure is not a tree.
| :func:`jax.tree_util.tree_traverse`. | ||
| """ | ||
| def __repr__(self): return "jax.tree_util.MAP_TO_NONE" | ||
| __str__ = __repr__ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@superbobry I think there is no need to add __str__ = __repr__ here, because str() can use __repr__ when __str__ is not available.
From Python documentation:
If object does not have a str() method, then str() falls back to returning repr(object).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good call, thanks for pointing this out, Ayaka!
| """ | ||
| def __repr__(self): return "jax.tree_util.MAP_TO_NONE" | ||
| __str__ = __repr__ | ||
| MAP_TO_NONE = _MapToNone() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe need to avoid double instantiation and subclassing, so that there is only one instance of class _MapToNone.
I've implemented similar things before: https://github.com/ayaka14732/einshard/blob/00277824cbf75eb2b01d5cffb1935ab7c6fdc1c6/src/einshard/parsec/Void.py
| # See PEP 484 & https://github.com/google/jax/issues/7570 | ||
|
|
||
| from jax._src.tree_util import ( | ||
| MAP_TO_NONE as MAP_TO_NONE, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe export class _MapToNone as well since it may be used in writting type signatures.
This provides the same API as
tree.traversefrom https://github.com/google-deepmind/treecc/ @mtthss
Fixes #19631