fix: validate v3 dtypes when loading/creating v3 metadata
jhamman committed Sep 18, 2024
commit f1b01ac7cb1d673314320c9427049ae62046dcf3
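
Before the per-file changes, a minimal sketch (not part of the commit) of the behaviour it is aiming for, based on the new test_invalid_dtype_raises test added below; it assumes zarr-python at this revision:

import numpy as np

from zarr.core.metadata.v3 import ArrayV3Metadata

# datetime64 is not a valid zarr v3 data_type, so loading this metadata should now
# fail loudly instead of producing metadata that breaks later during serialization.
metadata_dict = {
    "zarr_format": 3,
    "node_type": "array",
    "shape": (1,),
    "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
    "data_type": "<M8[ns]",
    "chunk_key_encoding": {"name": "default", "separator": "."},
    "codecs": (),
    "fill_value": np.datetime64(0, "ns"),
}

try:
    ArrayV3Metadata.from_dict(metadata_dict)
except ValueError as e:
    print(e)  # Invalid V3 data_type: <M8[ns]
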
4 changes: 2 additions & 2 deletions src/zarr/core/array_spec.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Literal

-from zarr.core.common import parse_dtype, parse_fill_value, parse_order, parse_shapelike
+from zarr.core.common import parse_fill_value, parse_order, parse_shapelike

 if TYPE_CHECKING:
     import numpy as np
@@ -29,7 +29,7 @@ def __init__(
         prototype: BufferPrototype,
     ) -> None:
         shape_parsed = parse_shapelike(shape)
-        dtype_parsed = parse_dtype(dtype)
+        dtype_parsed = dtype  # parsing is likely not needed here
         fill_value_parsed = parse_fill_value(fill_value)
         order_parsed = parse_order(order)

7 changes: 0 additions & 7 deletions src/zarr/core/common.py
@@ -19,8 +19,6 @@
 if TYPE_CHECKING:
     from collections.abc import Awaitable, Callable, Iterator

-    import numpy as np
-    import numpy.typing as npt

 ZARR_JSON = "zarr.json"
 ZARRAY_JSON = ".zarray"
@@ -154,11 +152,6 @@ def parse_shapelike(data: int | Iterable[int]) -> tuple[int, ...]:
     return data_tuple


-def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
-    # todo: real validation
-    return np.dtype(data)
-
-
 def parse_fill_value(data: Any) -> Any:
     # todo: real validation
     return data
7 changes: 6 additions & 1 deletion src/zarr/core/metadata/v2.py
@@ -21,7 +21,7 @@
 from zarr.core.array_spec import ArraySpec
 from zarr.core.chunk_grids import RegularChunkGrid
 from zarr.core.chunk_key_encodings import parse_separator
-from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_dtype, parse_shapelike
+from zarr.core.common import ZARRAY_JSON, ZATTRS_JSON, parse_shapelike
 from zarr.core.config import config, parse_indexing_order
 from zarr.core.metadata.common import ArrayMetadata, parse_attributes

@@ -157,6 +157,11 @@ def update_attributes(self, attributes: dict[str, JSON]) -> Self:
         return replace(self, attributes=attributes)


+def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
+    # todo: real validation
+    return np.dtype(data)
+
+
 def parse_zarr_format(data: object) -> Literal[2]:
     if data == 2:
         return 2
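
The permissive parse_dtype is effectively moving out of zarr.core.common: v2 keeps the plain numpy pass-through shown above, while v3 (later in this diff) gets a strict check. A rough illustration of the intended split, assuming both modules import cleanly at this revision:

import numpy as np

from zarr.core.metadata.v2 import parse_dtype as parse_v2_dtype
from zarr.core.metadata.v3 import parse_dtype as parse_v3_dtype

# v2 stays permissive: anything numpy can interpret passes through unchanged.
assert parse_v2_dtype("datetime64[s]") == np.dtype("<M8[s]")

# v3 rejects the same input (see the parse_dtype added to v3.py below).
try:
    parse_v3_dtype("datetime64[s]")
except ValueError as e:
    print(e)  # Invalid V3 data_type: datetime64[s]
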
32 changes: 31 additions & 1 deletion src/zarr/core/metadata/v3.py
@@ -24,7 +24,7 @@
 from zarr.core.buffer import default_buffer_prototype
 from zarr.core.chunk_grids import ChunkGrid, RegularChunkGrid
 from zarr.core.chunk_key_encodings import ChunkKeyEncoding
-from zarr.core.common import ZARR_JSON, parse_dtype, parse_named_configuration, parse_shapelike
+from zarr.core.common import ZARR_JSON, parse_named_configuration, parse_shapelike
 from zarr.core.config import config
 from zarr.core.metadata.common import ArrayMetadata, parse_attributes
 from zarr.registry import get_codec_class
@@ -215,6 +215,10 @@ def from_dict(cls, data: dict[str, JSON]) -> Self:
         # check that the node_type attribute is correct
         _ = parse_node_type_array(_data.pop("node_type"))

+        # check that the data_type attribute is valid
+        if _data["data_type"] not in DataType:
+            raise ValueError(f"Invalid V3 data_type: {_data['data_type']}")
+
         # dimension_names key is optional, normalize missing to `None`
         _data["dimension_names"] = _data.pop("dimension_names", None)
         # attributes key is optional, normalize missing to `None`
@@ -345,8 +349,11 @@ class DataType(Enum):
     uint16 = "uint16"
     uint32 = "uint32"
     uint64 = "uint64"
+    float16 = "float16"
     float32 = "float32"
     float64 = "float64"
+    complex64 = "complex64"
+    complex128 = "complex128"

     @property
     def byte_count(self) -> int:
@@ -360,8 +367,11 @@ def byte_count(self) -> int:
             DataType.uint16: 2,
             DataType.uint32: 4,
             DataType.uint64: 8,
+            DataType.float16: 2,
             DataType.float32: 4,
             DataType.float64: 8,
+            DataType.complex64: 8,
+            DataType.complex128: 16,
         }
         return data_type_byte_counts[self]

@@ -381,8 +391,11 @@ def to_numpy_shortname(self) -> str:
             DataType.uint16: "u2",
             DataType.uint32: "u4",
             DataType.uint64: "u8",
+            DataType.float16: "f2",
             DataType.float32: "f4",
             DataType.float64: "f8",
+            DataType.complex64: "c8",
+            DataType.complex128: "c16",
         }
         return data_type_to_numpy[self]

@@ -399,7 +412,24 @@ def from_dtype(cls, dtype: np.dtype[Any]) -> DataType:
             "<u2": "uint16",
             "<u4": "uint32",
             "<u8": "uint64",
+            "<f2": "float16",
             "<f4": "float32",
             "<f8": "float64",
+            "<c8": "complex64",
+            "<c16": "complex128",
         }
         return DataType[dtype_to_data_type[dtype.str]]
+
+
+def parse_dtype(data: npt.DTypeLike) -> np.dtype[Any]:
+    try:
+        dtype = np.dtype(data)
+    except TypeError as e:
+        raise ValueError(f"Invalid V3 data_type: {data}") from e
+    # check that this is a valid v3 data_type
+    try:
+        _ = DataType.from_dtype(dtype)
+    except KeyError as e:
+        raise ValueError(f"Invalid V3 data_type: {dtype}") from e
+
+    return dtype
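
For reference, a small sketch (not part of the diff) of how the pieces added above fit together; the names and module path come from the diff, and the behaviour assumes a little-endian platform, since from_dtype only maps "<"-prefixed dtype strings:

import numpy as np

from zarr.core.metadata.v3 import DataType, parse_dtype

# The new members carry their byte counts...
assert DataType.float16.byte_count == 2
assert DataType.complex128.byte_count == 16

# ...and from_dtype keys on dtype.str, mapping little-endian numpy dtypes back to members.
assert DataType.from_dtype(np.dtype("<c16")) is DataType.complex128

# parse_dtype funnels both failure modes (TypeError from np.dtype, KeyError from
# DataType.from_dtype) into a single ValueError.
parse_dtype("float16")           # returns np.dtype("float16")
# parse_dtype("datetime64[s]")   # would raise ValueError: Invalid V3 data_type
# parse_dtype(object())          # would raise ValueError: Invalid V3 data_type
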
4 changes: 3 additions & 1 deletion src/zarr/testing/strategies.py
@@ -35,7 +35,9 @@
 paths = st.lists(node_names, min_size=1).map(lambda x: "/".join(x)) | st.just("/")
 np_arrays = npst.arrays(
     # TODO: re-enable timedeltas once they are supported
-    dtype=npst.scalar_dtypes().filter(lambda x: x.kind != "m"),
+    dtype=npst.scalar_dtypes().filter(
+        lambda x: (x.kind not in ["m", "M"]) and (x.byteorder not in [">"])
+    ),
     shape=npst.array_shapes(max_dims=4),
 )
 stores = st.builds(MemoryStore, st.just({}), mode=st.just("w"))
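
Presumably the broader filter is needed because the v3 DataType table above only covers little-endian, non-datetime dtypes. A hypothetical property test (not part of the diff) that captures that assumption:

from hypothesis import given
from hypothesis.extra import numpy as npst

from zarr.core.metadata.v3 import parse_dtype

# Mirrors the filter added above: no timedelta/datetime kinds, no big-endian dtypes.
v3_compatible_dtypes = npst.scalar_dtypes().filter(
    lambda x: (x.kind not in ["m", "M"]) and (x.byteorder not in [">"])
)


@given(dtype=v3_compatible_dtypes)
def test_generated_dtypes_are_valid_v3_dtypes(dtype):
    # On a little-endian machine, every generated dtype should be accepted.
    parse_dtype(dtype)
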
50 changes: 37 additions & 13 deletions tests/v3/test_metadata/test_v3.py
@@ -1,11 +1,9 @@
 from __future__ import annotations

-import json
 import re
 from typing import TYPE_CHECKING, Literal

 from zarr.codecs.bytes import BytesCodec
-from zarr.core.buffer import default_buffer_prototype
 from zarr.core.chunk_key_encodings import DefaultChunkKeyEncoding, V2ChunkKeyEncoding
 from zarr.core.metadata.v3 import ArrayV3Metadata
@@ -19,7 +17,12 @@
 import numpy as np
 import pytest

-from zarr.core.metadata.v3 import parse_dimension_names, parse_fill_value, parse_zarr_format
+from zarr.core.metadata.v3 import (
+    parse_dimension_names,
+    parse_dtype,
+    parse_fill_value,
+    parse_zarr_format,
+)

 bool_dtypes = ("bool",)

@@ -234,22 +237,43 @@ def test_metadata_to_dict(
     assert observed == expected


-@pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
-@pytest.mark.parametrize("precision", ["ns", "D"])
-async def test_datetime_metadata(fill_value: int, precision: str) -> None:
+# @pytest.mark.parametrize("fill_value", [-1, 0, 1, 2932897])
+# @pytest.mark.parametrize("precision", ["ns", "D"])
+# async def test_datetime_metadata(fill_value: int, precision: str) -> None:
+#     metadata_dict = {
+#         "zarr_format": 3,
+#         "node_type": "array",
+#         "shape": (1,),
+#         "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
+#         "data_type": f"<M8[{precision}]",
+#         "chunk_key_encoding": {"name": "default", "separator": "."},
+#         "codecs": (),
+#         "fill_value": np.datetime64(fill_value, precision),
+#     }
+#     metadata = ArrayV3Metadata.from_dict(metadata_dict)
+#     # ensure there isn't a TypeError here.
+#     d = metadata.to_buffer_dict(default_buffer_prototype())
+
+#     result = json.loads(d["zarr.json"].to_bytes())
+#     assert result["fill_value"] == fill_value
+
+
+async def test_invalid_dtype_raises() -> None:
     metadata_dict = {
         "zarr_format": 3,
         "node_type": "array",
         "shape": (1,),
         "chunk_grid": {"name": "regular", "configuration": {"chunk_shape": (1,)}},
-        "data_type": f"<M8[{precision}]",
+        "data_type": "<M8[ns]",
         "chunk_key_encoding": {"name": "default", "separator": "."},
         "codecs": (),
-        "fill_value": np.datetime64(fill_value, precision),
+        "fill_value": np.datetime64(0, "ns"),
     }
-    metadata = ArrayV3Metadata.from_dict(metadata_dict)
-    # ensure there isn't a TypeError here.
-    d = metadata.to_buffer_dict(default_buffer_prototype())
+    with pytest.raises(ValueError, match=r"Invalid V3 data_type"):
+        ArrayV3Metadata.from_dict(metadata_dict)

-    result = json.loads(d["zarr.json"].to_bytes())
-    assert result["fill_value"] == fill_value
+
+@pytest.mark.parametrize("data", ["datetime64[s]", "foo", object()])
+def test_parse_invalid_dtype_raises(data):
+    with pytest.raises(ValueError, match=r"Invalid V3 data_type"):
+        parse_dtype(data)