diff --git a/CHANGELOG.md b/CHANGELOG.md index 83be5f478ec9..9c569d066baa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.25 +* New Features + * Added {func}`jax.tree_util.tree_traverse` + ## jaxlib 0.4.25 ## jax 0.4.24 (Feb 6, 2024) diff --git a/docs/jax.tree_util.rst b/docs/jax.tree_util.rst index da52d87868e0..cb5ef7215b37 100644 --- a/docs/jax.tree_util.rst +++ b/docs/jax.tree_util.rst @@ -12,6 +12,7 @@ List of Functions :toctree: _autosummary Partial + MAP_TO_NONE all_leaves build_tree register_pytree_node @@ -28,6 +29,7 @@ List of Functions tree_reduce tree_structure tree_transpose + tree_traverse tree_unflatten treedef_children treedef_is_leaf diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index c8ce4fe7a598..9adb2b54410c 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -939,3 +939,73 @@ def _prefix_error( f"{prefix_tree_keys} and {full_tree_keys}") for k, t1, t2 in zip(prefix_tree_keys, prefix_tree_children, full_tree_children): yield from _prefix_error((*key_path, k), t1, t2) + + +def tree_traverse(fn, structure, *, top_down=True): + """Traverses the given nested structure, applying the given function. + + The traversal is depth-first. If ``top_down`` is True (default), parents + are returned before their children (giving the option to avoid traversing + into a sub-tree). + + >>> visited = [] + >>> tree_traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=True) + [(1, 2), [3], {'a': 4}] + >>> visited + [[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4] + + >>> visited = [] + >>> tree_traverse(visited.append, [(1, 2), [3], {"a": 4}], top_down=False) + [(1, 2), [3], {'a': 4}] + >>> visited + [1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]] + + Args: + fn: The function to be applied to each sub-nest of the structure. + + When traversing top-down: + If ``fn(subtree) is None`` the traversal continues into the sub-tree. + If ``fn(subtree) is not None`` the traversal does not continue into + the sub-tree. The sub-tree will be replaced by ``fn(subtree)`` in the + returned structure (to replace the sub-tree with None, use the special + value :data:`MAP_TO_NONE`). + + When traversing bottom-up: + If ``fn(subtree) is None`` the traversed sub-tree is returned unaltered. + If ``fn(subtree) is not None`` the sub-tree will be replaced by + ``fn(subtree)`` in the returned structure (to replace the sub-tree + with None, use the special value :data:`MAP_TO_NONE`). + + structure: The structure to traverse. + top_down: If True, parent structures will be visited before their children. + + Returns: + The structured output from the traversal. + """ + def traverse_children(): + children, treedef = tree_flatten(structure, is_leaf=lambda x: x is not structure) + if treedef_is_strict_leaf(treedef): + return structure + else: + return tree_unflatten(treedef, [tree_traverse(fn, c, top_down=top_down) for c in children]) + + if top_down: + ret = fn(structure) + if ret is None: + return traverse_children() + else: + traversed_structure = traverse_children() + ret = fn(traversed_structure) + if ret is None: + return traversed_structure + return None if ret is MAP_TO_NONE else ret + + +class _MapToNone: + """ + :data:`MAP_TO_NONE` is a special object used as a sentinel within + :func:`jax.tree_util.tree_traverse`. + """ + def __repr__(self): return "jax.tree_util.MAP_TO_NONE" + __str__ = __repr__ +MAP_TO_NONE = _MapToNone() diff --git a/jax/tree_util.py b/jax/tree_util.py index eaa8247735a4..356767469879 100644 --- a/jax/tree_util.py +++ b/jax/tree_util.py @@ -39,6 +39,7 @@ # See PEP 484 & https://github.com/google/jax/issues/7570 from jax._src.tree_util import ( + MAP_TO_NONE as MAP_TO_NONE, Partial as Partial, PyTreeDef as PyTreeDef, all_leaves as all_leaves, @@ -52,6 +53,7 @@ tree_reduce as tree_reduce, tree_structure as tree_structure, tree_transpose as tree_transpose, + tree_traverse as tree_traverse, tree_unflatten as tree_unflatten, treedef_children as treedef_children, treedef_is_leaf as treedef_is_leaf, diff --git a/tests/tree_util_test.py b/tests/tree_util_test.py index cc2c21fc4981..959704fb953e 100644 --- a/tests/tree_util_test.py +++ b/tests/tree_util_test.py @@ -743,6 +743,41 @@ def testNamedTupleRegisteredWithoutKeysIsntTreatedAsLeaf(self): leaves, _ = tree_util.tree_flatten_with_path(ATuple2(1, 'hi')) self.assertLen(leaves, 1) + def testTraverseListsToTuples(self): + structure = [(1, 2), [3], {"a": [4]}] + self.assertEqual( + ((1, 2), (3,), {"a": (4,)}), + tree_util.tree_traverse( + lambda x: tuple(x) if isinstance(x, list) else x, + structure, + top_down=False)) + + def testTraverseEarlyTermination(self): + structure = [(1, [2]), [3, (4, 5, 6)]] + visited = [] + def visit(x): + visited.append(x) + return "X" if isinstance(x, tuple) and len(x) > 2 else None + + output = tree_util.tree_traverse(visit, structure) + self.assertEqual([(1, [2]), [3, "X"]], output) + self.assertEqual( + [[(1, [2]), [3, (4, 5, 6)]], + (1, [2]), 1, [2], 2, [3, (4, 5, 6)], 3, (4, 5, 6)], + visited) + + def testTraverseMapToNone(self): + structure = [(1, 2), ['c'], {"a": [4, 'd']}] + def string_to_none(x): + if isinstance(x, str): + return tree_util.MAP_TO_NONE + return None + output = tree_util.tree_traverse(string_to_none, structure, top_down=True) + self.assertEqual(output, [(1, 2), [None], {'a': [4, None]}]) + output = tree_util.tree_traverse(string_to_none, structure, top_down=False) + self.assertEqual(output, [(1, 2), [None], {'a': [4, None]}]) + + class StaticTest(parameterized.TestCase):