Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
handle overwrite kwarg correctly in to_zarr
  • Loading branch information
d-v-b committed Feb 28, 2024
commit c013c17e4ba5780979e83c37f7c852f314d65aba
81 changes: 58 additions & 23 deletions src/pydantic_zarr/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@
from typing_extensions import Annotated
from pydantic import AfterValidator, model_validator
from pydantic.functional_validators import BeforeValidator
from zarr.storage import init_group, BaseStore, contains_group
from zarr.storage import init_group, BaseStore, contains_group, contains_array
import numcodecs
import zarr
import os
import numpy as np
import numpy.typing as npt
from numcodecs.abc import Codec
from zarr.errors import ContainsGroupError
from zarr.errors import ContainsGroupError, ContainsArrayError
from pydantic_zarr.core import (
IncEx,
StrictBase,
Expand Down Expand Up @@ -266,25 +266,45 @@ def to_zarr(
The storage backend that will manifest the array.
path : str
The location of the array inside the store.
overwrite : bool
Whether to overwrite an existing array or group at the path. If overwrite is
`False` and an array or group already exists at `path`, an exception will be
raised. The default is `False`.

**kwargs : Any
Additional keyword arguments are passed to `zarr.create`.
Returns
-------
zarr.Array
A Zarr array that is structurally identical to `self`.
"""
spec_dict = self.model_dump()
attrs = spec_dict.pop("attributes")
overwrite = kwargs.pop("overwrite", False)
if kwargs.get("mode", "r").startswith("w"):
raise ValueError("Mode='w' is not supported yet")
if self.compressor is not None:
spec_dict["compressor"] = numcodecs.get_codec(spec_dict["compressor"])
if self.filters is not None:
spec_dict["filters"] = [
numcodecs.get_codec(f) for f in spec_dict["filters"]
]
result = zarr.create(store=store, path=path, **spec_dict, **kwargs)
if contains_array(store, path):
extant_array = zarr.open_array(store, path=path, mode="r")

if not self.like(extant_array):
if not overwrite:
msg = (
f"An array already exists at path {path}. "
"That array is structurally dissimilar to the array you are trying to "
"store. Call to_zarr with overwrite=True to overwrite that array."
)
raise ContainsArrayError(msg)
else:
if not overwrite:
# extant_array is read-only, so we make a new array handle that
# takes **kwargs
return zarr.open_array(
store=extant_array.store, path=extant_array.path, **kwargs
)
result = zarr.create(
store=store, path=path, overwrite=overwrite, **spec_dict, **kwargs
)
result.attrs.put(attrs)
return result

Expand Down Expand Up @@ -441,36 +461,51 @@ def to_zarr(self, store: BaseStore, path: str, **kwargs):
The storage backend that will manifest the group and its contents.
path : str
The location of the group inside the store.
overwrite : bool
Whether to overwrite an existing array or group at the path. If overwrite is
False and an array or group already exists at the path, an exception will be
raised. Defaults to False.
**kwargs : Any
Additional keyword arguments that will be passed to `zarr.create` for creating
sub-arrays.

Returns
-------
zarr.Group
A zarr group that is structurally identical to the GroupSpec.
A zarr group that is structurally identical to `self`.

"""
spec_dict = self.model_dump(exclude={"members": True})
attrs = spec_dict.pop("attributes")

overwrite = kwargs.pop("overwrite", False)
if contains_group(store, path):
if not kwargs.get("overwrite", False):
msg = (
f"A group already exists at path {path}. "
"Call to_zarr with overwrite=True to delete the existing group."
)
raise ContainsGroupError(msg)
extant_group = zarr.group(store, path=path)
if not self.like(extant_group):
if not overwrite:
msg = (
f"A group already exists at path {path}. "
"That group is structurally dissimilar to the group you are trying to store."
"Call to_zarr with overwrite=True to overwrite that group."
)
raise ContainsGroupError(msg)
else:
if not overwrite:
# if the extant group is structurally identical to self, and overwrite is false,
# then just return the extant group
return extant_group

elif contains_array(store, path) and not overwrite:
msg = (
f"An array already exists at path {path}. "
"Call to_zarr with overwrite=True to overwrite the array."
)
raise ContainsArrayError(msg)
else:
init_group(store=store, overwrite=kwargs.get("overwrite", False), path=path)
init_group(store=store, overwrite=overwrite, path=path)

result = zarr.group(store=store, path=path, **kwargs)
result = zarr.group(store=store, path=path, overwrite=overwrite)
result.attrs.put(attrs)
# consider raising an exception if a partial GroupSpec is provided
if self.members is not None:
for name, member in self.members.items():
subpath = os.path.join(path, name)
member.to_zarr(store, subpath, **kwargs)
member.to_zarr(store, subpath, overwrite=overwrite, **kwargs)

return result

Expand Down
47 changes: 43 additions & 4 deletions tests/test_v2.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from pydantic import ValidationError
import pytest
import zarr
from zarr.errors import ContainsGroupError
from zarr.errors import ContainsGroupError, ContainsArrayError
from typing import Any, Literal, Union, Optional
import numcodecs
from numcodecs.abc import Codec
Expand Down Expand Up @@ -62,7 +62,8 @@ def test_array_spec(
compressor=compressor,
filters=_filters,
)
array.attrs.put({"foo": [100, 200, 300], "bar": "hello"})
attributes = {"foo": [100, 200, 300], "bar": "hello"}
array.attrs.put(attributes)
spec = ArraySpec.from_zarr(array)

assert spec.zarr_version == array._version
Expand Down Expand Up @@ -108,6 +109,35 @@ def test_array_spec(
assert spec.shape == array2.shape
assert spec.fill_value == array2.fill_value

# test serialization
store = zarr.MemoryStore()
stored = spec.to_zarr(store, path="foo")
assert ArraySpec.from_zarr(stored) == spec

# test that to_zarr is idempotent
assert spec.to_zarr(store, path="foo") == stored

# test that to_zarr raises if the extant array is different
spec_2 = spec.model_copy(update={"attributes": {"baz": 10}})
with pytest.raises(ContainsArrayError):
spec_2.to_zarr(store, path="foo")

# test that we can overwrite the dissimilar array
stored_2 = spec_2.to_zarr(store, path="foo", overwrite=True)
assert ArraySpec.from_zarr(stored_2) == spec_2

# test that mode and write_empty_chunks get passed through
assert spec_2.to_zarr(store, path="foo", mode="a").read_only is False
assert spec_2.to_zarr(store, path="foo", mode="r").read_only is True
assert (
spec_2.to_zarr(store, path="foo", write_empty_chunks=False)._write_empty_chunks
is False
)
assert (
spec_2.to_zarr(store, path="foo", write_empty_chunks=True)._write_empty_chunks
is True
)


@pytest.mark.parametrize("array", (np.arange(10), np.zeros((10, 10), dtype="uint8")))
def test_array_spec_from_array(array: npt.NDArray[Any]):
Expand Down Expand Up @@ -208,9 +238,18 @@ class ArrayAttrs(TypedDict):
observed = from_zarr(group)
assert observed == spec

# check that we can't overwrite the original group
# assert that we get the same group twice
assert to_zarr(spec, store, "/group_a") == group

# check that we can't call to_zarr targeting the original group with a different spec
spec_2 = spec.model_copy(update={"attributes": RootAttrs(foo=99, bar=[0, 1, 2])})
with pytest.raises(ContainsGroupError):
_ = to_zarr(spec_2, store, "/group_a")

# check that we can't call to_zarr with the original spec if the group has changed
group.attrs.put({"foo": 100})
with pytest.raises(ContainsGroupError):
group = to_zarr(spec, store, "/group_a")
_ = to_zarr(spec, store, "/group_a")

# materialize again with overwrite
group2 = to_zarr(spec, store, "/group_a", overwrite=True)
Expand Down