Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list

## jax 0.4.25

* New Features
* Added {func}`jax.tree_util.tree_traverse`

## jaxlib 0.4.25

## jax 0.4.24 (Feb 6, 2024)
Expand Down
2 changes: 2 additions & 0 deletions docs/jax.tree_util.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ List of Functions
:toctree: _autosummary

Partial
MAP_TO_NONE
all_leaves
build_tree
register_pytree_node
Expand All @@ -28,6 +29,7 @@ List of Functions
tree_reduce
tree_structure
tree_transpose
tree_traverse
tree_unflatten
treedef_children
treedef_is_leaf
Expand Down
70 changes: 70 additions & 0 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,3 +939,73 @@ def _prefix_error(
f"{prefix_tree_keys} and {full_tree_keys}")
for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children):
yield from _prefix_error((*key_path, k), t1, t2)


def tree_traverse(fn, structure, *, top_down=True):
"""Traverses the given nested structure, applying the given function.

The traversal is depth-first. If ``top_down`` is True (default), parents
are returned before their children (giving the option to avoid traversing
into a sub-tree).

>>> visited = []
>>> tree_traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=True)
[(1, 2), [3], {'a': 4}]
>>> visited
[[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4]

>>> visited = []
>>> tree_traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=False)
[(1, 2), [3], {'a': 4}]
>>> visited
[1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]]

Args:
fn: The function to be applied to each sub-nest of the structure.

When traversing top-down:
If ``fn(subtree) is None`` the traversal continues into the sub-tree.
If ``fn(subtree) is not None`` the traversal does not continue into
the sub-tree. The sub-tree will be replaced by ``fn(subtree)`` in the
returned structure (to replace the sub-tree with None, use the special
value :data:`MAP_TO_NONE`).

When traversing bottom-up:
If ``fn(subtree) is None`` the traversed sub-tree is returned unaltered.
If ``fn(subtree) is not None`` the sub-tree will be replaced by
``fn(subtree)`` in the returned structure (to replace the sub-tree
with None, use the special value :data:`MAP_TO_NONE`).

structure: The structure to traverse.
top_down: If True, parent structures will be visited before their children.

Returns:
The structured output from the traversal.
"""
def traverse_children():
children, treedef = tree_flatten(structure, is_leaf=lambda x: x is not structure)
Copy link
Contributor

@XuehaiPan XuehaiPan Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to use itertools.count to handle self-referential cases.

import itertools

counter = itertools.count()
one_level_children, one_level_treedef = tree_flatten(structure, is_leaf=lambda x: next(counter) > 0)
In [1]: from jax.tree_util import tree_flatten

In [2]: d = {'a': 1, 'b': 2, 'c': 3}; d['d'] = d

In [3]: tree_flatten(d, is_leaf=lambda x: x is not d)
RecursionError: maximum recursion depth exceeded while calling a Python object

In [4]: import itertools

In [5]: counter = itertools.count()

In [6]: tree_flatten(d, is_leaf=lambda x: next(counter) > 0)
Out[6]: ([1, 2, 3, {'a': 1, 'b': 2, 'c': 3, 'd': ...}], PyTreeDef({'a': *, 'b': *, 'c': *, 'd': *}))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should handle self-referential cases, since the underlying structure is not a tree.

if treedef_is_strict_leaf(treedef):
return structure
else:
return tree_unflatten(treedef, [tree_traverse(fn, c, top_down=top_down) for c in children])

if top_down:
ret = fn(structure)
if ret is None:
return traverse_children()
else:
traversed_structure = traverse_children()
ret = fn(traversed_structure)
if ret is None:
return traversed_structure
return None if ret is MAP_TO_NONE else ret


class _MapToNone:
"""
:data:`MAP_TO_NONE` is a special object used as a sentinel within
:func:`jax.tree_util.tree_traverse`.
"""
def __repr__(self): return "jax.tree_util.MAP_TO_NONE"
__str__ = __repr__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@superbobry I think there is no need to add __str__ = __repr__ here, because str() can use __repr__ when __str__ is not available.

From Python documentation:

If object does not have a str() method, then str() falls back to returning repr(object).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, thanks for pointing this out, Ayaka!

MAP_TO_NONE = _MapToNone()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe need to avoid double instantiation and subclassing, so that there is only one instance of class _MapToNone.

I've implemented similar things before: https://github.com/ayaka14732/einshard/blob/00277824cbf75eb2b01d5cffb1935ab7c6fdc1c6/src/einshard/parsec/Void.py

2 changes: 2 additions & 0 deletions jax/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.tree_util import (
MAP_TO_NONE as MAP_TO_NONE,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe export class _MapToNone as well since it may be used in writting type signatures.

Partial as Partial,
PyTreeDef as PyTreeDef,
all_leaves as all_leaves,
Expand All @@ -52,6 +53,7 @@
tree_reduce as tree_reduce,
tree_structure as tree_structure,
tree_transpose as tree_transpose,
tree_traverse as tree_traverse,
tree_unflatten as tree_unflatten,
treedef_children as treedef_children,
treedef_is_leaf as treedef_is_leaf,
Expand Down
35 changes: 35 additions & 0 deletions tests/tree_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,41 @@ def testNamedTupleRegisteredWithoutKeysIsntTreatedAsLeaf(self):
leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi'))
self.assertLen(leaves, 1)

def testTraverseListsToTuples(self):
structure = [(1, 2), [3], {"a": [4]}]
self.assertEqual(
((1, 2), (3,), {"a": (4,)}),
tree_util.tree_traverse(
lambda x: tuple(x) if isinstance(x, list) else x,
structure,
top_down=False))

def testTraverseEarlyTermination(self):
structure = [(1, [2]), [3, (4, 5, 6)]]
visited = []
def visit(x):
visited.append(x)
return "X" if isinstance(x, tuple) and len(x) > 2 else None

output = tree_util.tree_traverse(visit, structure)
self.assertEqual([(1, [2]), [3, "X"]], output)
self.assertEqual(
[[(1, [2]), [3, (4, 5, 6)]],
(1, [2]), 1, [2], 2, [3, (4, 5, 6)], 3, (4, 5, 6)],
visited)

def testTraverseMapToNone(self):
structure = [(1, 2), ['c'], {"a": [4, 'd']}]
def string_to_none(x):
if isinstance(x, str):
return tree_util.MAP_TO_NONE
return None
output = tree_util.tree_traverse(string_to_none, structure, top_down=True)
self.assertEqual(output, [(1, 2), [None], {'a': [4, None]}])
output = tree_util.tree_traverse(string_to_none, structure, top_down=False)
self.assertEqual(output, [(1, 2), [None], {'a': [4, None]}])



class StaticTest(parameterized.TestCase):

Expand Down