-
Notifications
You must be signed in to change notification settings - Fork 3.3k
Add jax.tree_util.tree_traverse #19695
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -939,3 +939,73 @@ def _prefix_error( | |
| f"{prefix_tree_keys} and {full_tree_keys}") | ||
| for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children): | ||
| yield from _prefix_error((*key_path, k), t1, t2) | ||
|
|
||
|
|
||
| def tree_traverse(fn, structure, *, top_down=True): | ||
| """Traverses the given nested structure, applying the given function. | ||
|
|
||
| The traversal is depth-first. If ``top_down`` is True (default), parents | ||
| are returned before their children (giving the option to avoid traversing | ||
| into a sub-tree). | ||
|
|
||
| >>> visited = [] | ||
| >>> tree_traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=True) | ||
| [(1, 2), [3], {'a': 4}] | ||
| >>> visited | ||
| [[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4] | ||
|
|
||
| >>> visited = [] | ||
| >>> tree_traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=False) | ||
| [(1, 2), [3], {'a': 4}] | ||
| >>> visited | ||
| [1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]] | ||
|
|
||
| Args: | ||
| fn: The function to be applied to each sub-nest of the structure. | ||
|
|
||
| When traversing top-down: | ||
| If ``fn(subtree) is None`` the traversal continues into the sub-tree. | ||
| If ``fn(subtree) is not None`` the traversal does not continue into | ||
| the sub-tree. The sub-tree will be replaced by ``fn(subtree)`` in the | ||
| returned structure (to replace the sub-tree with None, use the special | ||
| value :data:`MAP_TO_NONE`). | ||
|
|
||
| When traversing bottom-up: | ||
| If ``fn(subtree) is None`` the traversed sub-tree is returned unaltered. | ||
| If ``fn(subtree) is not None`` the sub-tree will be replaced by | ||
| ``fn(subtree)`` in the returned structure (to replace the sub-tree | ||
| with None, use the special value :data:`MAP_TO_NONE`). | ||
|
|
||
| structure: The structure to traverse. | ||
| top_down: If True, parent structures will be visited before their children. | ||
|
|
||
| Returns: | ||
| The structured output from the traversal. | ||
| """ | ||
| def traverse_children(): | ||
| children, treedef = tree_flatten(structure, is_leaf=lambda x: x is not structure) | ||
| if treedef_is_strict_leaf(treedef): | ||
| return structure | ||
| else: | ||
| return tree_unflatten(treedef, [tree_traverse(fn, c, top_down=top_down) for c in children]) | ||
|
|
||
| if top_down: | ||
| ret = fn(structure) | ||
| if ret is None: | ||
| return traverse_children() | ||
| else: | ||
| traversed_structure = traverse_children() | ||
| ret = fn(traversed_structure) | ||
| if ret is None: | ||
| return traversed_structure | ||
| return None if ret is MAP_TO_NONE else ret | ||
|
|
||
|
|
||
| class _MapToNone: | ||
| """ | ||
| :data:`MAP_TO_NONE` is a special object used as a sentinel within | ||
| :func:`jax.tree_util.tree_traverse`. | ||
| """ | ||
| def __repr__(self): return "jax.tree_util.MAP_TO_NONE" | ||
| __str__ = __repr__ | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @superbobry I think there is no need to add From Python documentation:
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good call, thanks for pointing this out, Ayaka! |
||
| MAP_TO_NONE = _MapToNone() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe need to avoid double instantiation and subclassing, so that there is only one instance of class I've implemented similar things before: https://github.com/ayaka14732/einshard/blob/00277824cbf75eb2b01d5cffb1935ab7c6fdc1c6/src/einshard/parsec/Void.py |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -39,6 +39,7 @@ | |
| # See PEP 484 & https://github.com/google/jax/issues/7570 | ||
|
|
||
| from jax._src.tree_util import ( | ||
| MAP_TO_NONE as MAP_TO_NONE, | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe export class |
||
| Partial as Partial, | ||
| PyTreeDef as PyTreeDef, | ||
| all_leaves as all_leaves, | ||
|
|
@@ -52,6 +53,7 @@ | |
| tree_reduce as tree_reduce, | ||
| tree_structure as tree_structure, | ||
| tree_transpose as tree_transpose, | ||
| tree_traverse as tree_traverse, | ||
| tree_unflatten as tree_unflatten, | ||
| treedef_children as treedef_children, | ||
| treedef_is_leaf as treedef_is_leaf, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It would be better to use
itertools.countto handle self-referential cases.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we should handle self-referential cases, since the underlying structure is not a tree.