Make from_dict more flexible, and add from_pytree
ColCarroll committed Mar 13, 2024
commit c68c913a4c09111bb81250794fb56e51b230545c
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -3,6 +3,7 @@
## v0.x.x Unreleased

### New features
- Add support for `pytree`s and robust handling of nested dictionaries. (2291)

### Maintenance and fixes
- Fix deprecations introduced in latest pandas and xarray versions, and prepare for numpy 2.0 ones ([2315](https://github.com/arviz-devs/arviz/pull/2315))
3 changes: 2 additions & 1 deletion arviz/data/__init__.py
@@ -7,7 +7,7 @@
from .io_cmdstan import from_cmdstan
from .io_cmdstanpy import from_cmdstanpy
from .io_datatree import from_datatree, to_datatree
from .io_dict import from_dict
from .io_dict import from_dict, from_pytree
from .io_emcee import from_emcee
from .io_json import from_json, to_json
from .io_netcdf import from_netcdf, to_netcdf
@@ -38,6 +38,7 @@
"from_cmdstanpy",
"from_datatree",
"from_dict",
"from_pytree",
"from_json",
"from_pyro",
"from_numpyro",
79 changes: 76 additions & 3 deletions arviz/data/base.py
@@ -8,6 +8,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union

import numpy as np
import tree
import xarray as xr

try:
@@ -67,6 +68,48 @@ def wrapped(cls: RequiresArgTypeT) -> Optional[RequiresReturnTypeT]:
return wrapped


def _yield_flat_up_to(shallow_tree, input_tree, path=()):
"""Yields (path, value) pairs of input_tree flattened up to shallow_tree.

Adapted from dm-tree (https://github.com/google-deepmind/tree) to allow
lists as leaves.

Args:
shallow_tree: Nested structure. Traverse no further than its leaf nodes.
input_tree: Nested structure. Return the paths and values from this tree.
Must have the same upper structure as shallow_tree.
path: Tuple. Optional argument, only used when recursing. The path from the
root of the original shallow_tree, down to the root of the shallow_tree
arg of this recursive call.

Yields:
Pairs of (path, value), where path is the tuple path of a leaf node in
shallow_tree, and value is the value of the corresponding node in
input_tree.
"""
# pylint: disable=protected-access
if isinstance(shallow_tree, tree._TEXT_OR_BYTES) or not (
isinstance(shallow_tree, tree.collections_abc.Mapping)
or tree._is_namedtuple(shallow_tree)
or tree._is_attrs(shallow_tree)
):
yield (path, input_tree)
else:
input_tree = dict(tree._yield_sorted_items(input_tree))
for shallow_key, shallow_subtree in tree._yield_sorted_items(shallow_tree):
subpath = path + (shallow_key,)
input_subtree = input_tree[shallow_key]
for leaf_path, leaf_value in _yield_flat_up_to(
shallow_subtree, input_subtree, path=subpath
):
yield (leaf_path, leaf_value)
# pylint: enable=protected-access


def _flatten_with_path(structure):
return list(_yield_flat_up_to(structure, structure))


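As a quick illustration of the helpers above (not part of the diff, and assuming `_flatten_with_path` is importable from `arviz.data.base` with dm-tree installed): nested keys become tuple paths, and lists are kept as leaves rather than flattened further.

```python
from arviz.data.base import _flatten_with_path  # private helper, shown for illustration

nested = {"top": {"a": [1.0, 2.0], "b": 3.0}, "d": 4.0}

# Keys are traversed in sorted order; lists stay intact as leaf values.
print(_flatten_with_path(nested))
# expected: [(('d',), 4.0), (('top', 'a'), [1.0, 2.0]), (('top', 'b'), 3.0)]
```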
def generate_dims_coords(
shape,
var_name,
@@ -255,7 +298,7 @@ def numpy_to_data_array(
return xr.DataArray(ary, coords=coords, dims=dims)


def dict_to_dataset(
def pytree_to_dataset(
data,
*,
attrs=None,
@@ -266,11 +309,34 @@
index_origin=None,
skip_event_dims=None,
):
"""Convert a dictionary of numpy arrays to an xarray.Dataset.
"""Convert a pytree of numpy arrays to an xarray.Dataset.

See https://jax.readthedocs.io/en/latest/pytrees.html for what a pytree is;
at a minimum this includes dictionaries and tuple types.

In case of nested pytrees, the variable name will be a tuple of individual names.

For example,

pytree_to_dataset({'top': {'second': 1.}, 'top2': 1.})

will have `var_names` `('top', 'second')` and `'top2'`.

Dimensions and coordinates can be defined as usual:

datadict = {
"top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
"d": np.random.randn(100),
}
dataset = convert_to_dataset(datadict,
coords={"c": np.arange(10)},
dims={("top", "b"): ["c"]})

Then `dataset.data_vars` will be `('top', 'a'), ('top', 'b'), 'd'`.

Parameters
----------
data : dict[str] -> ndarray
data : pytree
Data to convert. Keys are variable names.
attrs : dict
Json serializable metadata to attach to the dataset, in addition to defaults.
@@ -302,6 +368,10 @@ def dict_to_dataset(
"""
if dims is None:
dims = {}
try:
data = {k[0] if len(k) == 1 else k: v for k, v in _flatten_with_path(data)}
except TypeError: # probably unsortable keys -- the function will still work if
pass # it is an honest dictionary.

data_vars = {
key: numpy_to_data_array(
@@ -318,6 +388,9 @@
return xr.Dataset(data_vars=data_vars, attrs=make_attrs(attrs=attrs, library=library))


dict_to_dataset = pytree_to_dataset


def make_attrs(attrs=None, library=None):
"""Make standard attributes to attach to xarray datasets.

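For reference, a minimal sketch of calling the renamed `pytree_to_dataset` directly, following the docstring example above and the new test added below (assumes a build of ArviZ that includes this change):

```python
import numpy as np
from arviz.data.base import pytree_to_dataset

datadict = {
    "top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
    "d": np.random.randn(100),
}
# Dims for nested variables are keyed by their flattened tuple name.
dataset = pytree_to_dataset(
    datadict,
    coords={"c": np.arange(10)},
    dims={("top", "b"): ["c"]},
)
print(set(dataset.data_vars))  # expected: {('top', 'a'), ('top', 'b'), 'd'}
```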
4 changes: 4 additions & 0 deletions arviz/data/converters.py
@@ -1,5 +1,6 @@
"""High level conversion functions."""
import numpy as np
import tree
import xarray as xr

from .base import dict_to_dataset
@@ -105,6 +106,8 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
dataset = obj.to_dataset()
elif isinstance(obj, dict):
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
elif tree.is_nested(obj) and not isinstance(obj, (list, tuple)):
dataset = dict_to_dataset(obj, coords=coords, dims=dims)
elif isinstance(obj, np.ndarray):
dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
elif isinstance(obj, (list, tuple)) and isinstance(obj[0], str) and obj[0].endswith(".csv"):
@@ -118,6 +121,7 @@ def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None,
"xarray dataarray",
"xarray dataset",
"dict",
"pytree",
"netcdf filename",
"numpy array",
"pystan fit",
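To show the converter path end to end, here is a sketch with illustrative variable names: a plain nested dictionary still enters the existing `dict` branch, which now flattens nested keys, while the new `tree.is_nested` branch extends the same handling to other non-list pytrees.

```python
import numpy as np
import arviz as az

posterior = {
    "mu": np.random.randn(4, 500),                   # (chain, draw)
    "group": {"theta": np.random.randn(4, 500, 3)},  # nested -> ('group', 'theta')
}
idata = az.convert_to_inference_data(
    posterior,
    coords={"school": np.arange(3)},
    dims={("group", "theta"): ["school"]},
)
print(set(idata.posterior.data_vars))  # expected: {'mu', ('group', 'theta')}
```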
3 changes: 3 additions & 0 deletions arviz/data/io_dict.py
@@ -458,3 +458,6 @@ def from_dict(
attrs=attrs,
**kwargs,
).to_inference_data()


from_pytree = from_dict
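Because `from_pytree` is an alias of `from_dict`, the usual keyword groups apply unchanged; a short sketch with illustrative names, where nested entries become tuple-named variables:

```python
import numpy as np
import arviz as az

idata = az.from_pytree(
    posterior={
        "level1": {"a": np.random.randn(4, 100)},  # -> variable ('level1', 'a')
        "b": np.random.randn(4, 100),
    },
    posterior_predictive={"y": np.random.randn(4, 100)},
)
print(idata.posterior.data_vars)
```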
2 changes: 1 addition & 1 deletion arviz/plots/backends/matplotlib/pairplot.py
@@ -333,7 +333,7 @@ def plot_pair(
if reference_values:
x_name = flat_var_names[i]
y_name = flat_var_names[j + not_marginals]
if x_name and y_name not in difference:
if (x_name not in difference) and (y_name not in difference):
ax[j, i].plot(
reference_values_copy[x_name],
reference_values_copy[y_name],
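The pairplot change fixes an operator-precedence bug: `x_name and y_name not in difference` parses as `x_name and (y_name not in difference)`, so membership of `x_name` in `difference` was never checked. A small illustration with made-up values:

```python
difference = {"x"}           # variables that should not get reference lines
x_name, y_name = "x", "y"

buggy = x_name and y_name not in difference  # only truth-tests x_name -> True
fixed = (x_name not in difference) and (y_name not in difference)  # -> False
print(bool(buggy), fixed)  # True False
```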
14 changes: 14 additions & 0 deletions arviz/tests/base_tests/test_data.py
@@ -1077,6 +1077,20 @@ def test_dict_to_dataset():
assert set(dataset.b.coords) == {"chain", "draw", "c"}


def test_nested_dict_to_dataset():
datadict = {
"top": {"a": np.random.randn(100), "b": np.random.randn(1, 100, 10)},
"d": np.random.randn(100),
}
dataset = convert_to_dataset(datadict, coords={"c": np.arange(10)}, dims={("top", "b"): ["c"]})
assert set(dataset.data_vars) == {("top", "a"), ("top", "b"), "d"}
assert set(dataset.coords) == {"chain", "draw", "c"}

assert set(dataset[("top", "a")].coords) == {"chain", "draw"}
assert set(dataset[("top", "b")].coords) == {"chain", "draw", "c"}
assert set(dataset.d.coords) == {"chain", "draw"}


def test_dict_to_dataset_event_dims_error():
datadict = {"a": np.random.randn(1, 100, 10)}
coords = {"b": np.arange(10), "c": ["x", "y", "z"]}
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ numpy>=1.22.0,<2.0
scipy>=1.8.0
packaging
pandas>=1.4.0
dm-tree>=0.1.8
xarray>=0.21.0
h5netcdf>=1.0.2
typing_extensions>=4.1.0