Skip to content

Conversation

@jakevdp
Copy link
Collaborator

@jakevdp jakevdp commented Feb 7, 2024

This provides the same API as tree.traverse from https://github.com/google-deepmind/tree

cc/ @mtthss

Fixes #19631

@jakevdp jakevdp requested a review from froystig February 7, 2024 18:54
@jakevdp jakevdp self-assigned this Feb 7, 2024
@jakevdp jakevdp requested a review from superbobry February 7, 2024 18:56
@superbobry superbobry self-requested a review February 7, 2024 22:26
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Feb 7, 2024
This provides the same API as tree.traverse from google-deepmind/tree
The structured output from the traversal.
"""
def traverse_children():
children, treedef = tree_flatten(structure, is_leaf=lambda x: x is not structure)
Copy link
Contributor

@XuehaiPan XuehaiPan Mar 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be better to use itertools.count to handle self-referential cases.

import itertools

counter = itertools.count()
one_level_children, one_level_treedef = tree_flatten(structure, is_leaf=lambda x: next(counter) > 0)
In [1]: from jax.tree_util import tree_flatten

In [2]: d = {'a': 1, 'b': 2, 'c': 3}; d['d'] = d

In [3]: tree_flatten(d, is_leaf=lambda x: x is not d)
RecursionError: maximum recursion depth exceeded while calling a Python object

In [4]: import itertools

In [5]: counter = itertools.count()

In [6]: tree_flatten(d, is_leaf=lambda x: next(counter) > 0)
Out[6]: ([1, 2, 3, {'a': 1, 'b': 2, 'c': 3, 'd': ...}], PyTreeDef({'a': *, 'b': *, 'c': *, 'd': *}))

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should handle self-referential cases, since the underlying structure is not a tree.

:func:`jax.tree_util.tree_traverse`.
"""
def __repr__(self): return "jax.tree_util.MAP_TO_NONE"
__str__ = __repr__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@superbobry I think there is no need to add __str__ = __repr__ here, because str() can use __repr__ when __str__ is not available.

From Python documentation:

If object does not have a str() method, then str() falls back to returning repr(object).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, thanks for pointing this out, Ayaka!

"""
def __repr__(self): return "jax.tree_util.MAP_TO_NONE"
__str__ = __repr__
MAP_TO_NONE = _MapToNone()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe need to avoid double instantiation and subclassing, so that there is only one instance of class _MapToNone.

I've implemented similar things before: https://github.com/ayaka14732/einshard/blob/00277824cbf75eb2b01d5cffb1935ab7c6fdc1c6/src/einshard/parsec/Void.py

# See PEP 484 & https://github.com/google/jax/issues/7570

from jax._src.tree_util import (
MAP_TO_NONE as MAP_TO_NONE,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe export class _MapToNone as well since it may be used in writting type signatures.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

pull ready Ready for copybara import and testing

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add tree_traverse to jax.tree_util

5 participants