diff --git a/src/mdio/api/io.py b/src/mdio/api/io.py index e72cafeb3..2e34e67ea 100644 --- a/src/mdio/api/io.py +++ b/src/mdio/api/io.py @@ -6,11 +6,14 @@ from typing import Any from typing import Literal +import zarr from upath import UPath from xarray import Dataset as xr_Dataset from xarray import open_zarr as xr_open_zarr from xarray.backends.api import to_zarr as xr_to_zarr +from mdio.constants import ZarrFormat + if TYPE_CHECKING: from collections.abc import Mapping from pathlib import Path @@ -47,7 +50,13 @@ def open_mdio(input_path: UPath | Path | str, chunks: T_Chunks = None) -> xr_Dat """ input_path = _normalize_path(input_path) storage_options = _normalize_storage_options(input_path) - return xr_open_zarr(input_path.as_posix(), chunks=chunks, storage_options=storage_options) + zarr_format = zarr.config.get("default_zarr_format") + return xr_open_zarr( + input_path.as_posix(), + chunks=chunks, + storage_options=storage_options, + mask_and_scale=zarr_format == ZarrFormat.V3, # off for v2, on for v3 + ) def to_mdio( # noqa: PLR0913 @@ -57,7 +66,6 @@ def to_mdio( # noqa: PLR0913 *, compute: bool = True, region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None, - zarr_format: int = 3, ) -> None: """Write dataset contents to an MDIO output_path. @@ -74,18 +82,17 @@ def to_mdio( # noqa: PLR0913 can be computed to write array data later. Metadata is always updated eagerly. region: Optional mapping from dimension names to either a) ``"auto"``, or b) integer slices, indicating the region of existing MDIO array(s) in which to write this dataset's data. - zarr_format: The desired zarr format to target. The default is 3. 
""" output_path = _normalize_path(output_path) storage_options = _normalize_storage_options(output_path) + zarr_format = zarr.config.get("default_zarr_format") xr_to_zarr( dataset, store=output_path.as_posix(), # xarray doesn't like URI when file:// is protocol mode=mode, compute=compute, - consolidated=False, + consolidated=zarr_format == ZarrFormat.V2, # off for v3, on for v2 region=region, storage_options=storage_options, - zarr_format=zarr_format, write_empty_chunks=False, ) diff --git a/src/mdio/constants.py b/src/mdio/constants.py index 0036f5a90..783623281 100644 --- a/src/mdio/constants.py +++ b/src/mdio/constants.py @@ -1,9 +1,19 @@ """Constant values used across MDIO.""" +from enum import IntEnum + import numpy as np from mdio.schemas.dtype import ScalarType + +class ZarrFormat(IntEnum): + """Zarr version enum.""" + + V2 = 2 + V3 = 3 + + FLOAT16_MAX = np.finfo("float16").max FLOAT16_MIN = np.finfo("float16").min diff --git a/src/mdio/core/grid.py b/src/mdio/core/grid.py index 97e3bfac2..2a515334d 100644 --- a/src/mdio/core/grid.py +++ b/src/mdio/core/grid.py @@ -7,9 +7,11 @@ import numpy as np import zarr +from numcodecs.zarr3 import Blosc from zarr.codecs import BloscCodec from mdio.constants import UINT32_MAX +from mdio.constants import ZarrFormat from mdio.core.utils_write import get_constrained_chunksize if TYPE_CHECKING: @@ -108,8 +110,15 @@ def build_map(self, index_headers: HeaderArray) -> None: dtype=map_dtype, max_bytes=self._INTERNAL_CHUNK_SIZE_TARGET, ) - grid_compressor = BloscCodec(cname="zstd") - common_kwargs = {"shape": live_shape, "chunks": chunks, "compressors": grid_compressor, "store": None} + + zarr_format = zarr.config.get("default_zarr_format") + + common_kwargs = {"shape": live_shape, "chunks": chunks, "store": None} + if zarr_format == ZarrFormat.V2: + common_kwargs["compressors"] = Blosc(cname="zstd") + else: + common_kwargs["compressors"] = BloscCodec(cname="zstd") + self.map = zarr.create_array(fill_value=fill_value, 
dtype=map_dtype, **common_kwargs) self.live_mask = zarr.create_array(fill_value=0, dtype=bool, **common_kwargs) diff --git a/src/mdio/core/utils_write.py b/src/mdio/core/utils_write.py index 0414ecf6d..ab3c14f33 100644 --- a/src/mdio/core/utils_write.py +++ b/src/mdio/core/utils_write.py @@ -7,7 +7,6 @@ if TYPE_CHECKING: from numpy.typing import DTypeLike - from zarr import Group MAX_SIZE_LIVE_MASK = 512 * 1024**2 @@ -15,17 +14,6 @@ JsonSerializable = str | int | float | bool | None | dict[str, "JsonSerializable"] | list["JsonSerializable"] -def write_attribute(name: str, attribute: JsonSerializable, zarr_group: "Group") -> None: - """Write a mappable to Zarr array or group attribute. - - Args: - name: Name of the attribute. - attribute: Mapping to write. Must be JSON serializable. - zarr_group: Output group or array. - """ - zarr_group.attrs[name] = attribute - - def get_constrained_chunksize( shape: tuple[int, ...], dtype: "DTypeLike", @@ -45,7 +33,7 @@ def get_constrained_chunksize( return tuple(_balance_chunksizes(chunk)[0] for chunk in chunks) -def get_live_mask_chunksize(shape: tuple[int, ...]) -> tuple[int]: +def get_live_mask_chunksize(shape: tuple[int, ...]) -> tuple[int, ...]: """Given a live_mask shape, calculate the optimal write chunk size. 
Args: diff --git a/src/mdio/schemas/v1/dataset_serializer.py b/src/mdio/schemas/v1/dataset_serializer.py index 1e0a92d77..0df24a019 100644 --- a/src/mdio/schemas/v1/dataset_serializer.py +++ b/src/mdio/schemas/v1/dataset_serializer.py @@ -1,8 +1,10 @@ """Convert MDIO v1 schema Dataset to Xarray DataSet and write it in Zarr.""" import numpy as np +import zarr from dask import array as dask_array from dask.array.core import normalize_chunks +from numcodecs import Blosc from xarray import DataArray as xr_DataArray from xarray import Dataset as xr_Dataset from zarr.codecs import BloscCodec @@ -16,6 +18,7 @@ except ImportError: zfpy_ZFPY = None # noqa: N816 +from mdio.constants import ZarrFormat from mdio.constants import fill_value_map from mdio.schemas.compressors import ZFP as mdio_ZFP # noqa: N811 from mdio.schemas.compressors import Blosc as mdio_Blosc @@ -121,13 +124,17 @@ def _get_zarr_chunks(var: Variable, all_named_dims: dict[str, NamedDimension]) - def _convert_compressor( compressor: mdio_Blosc | mdio_ZFP | None, -) -> BloscCodec | zfpy_ZFPY | None: +) -> BloscCodec | Blosc | zfpy_ZFPY | None: """Convert a compressor to a numcodecs compatible format.""" if compressor is None: return None if isinstance(compressor, mdio_Blosc): - return BloscCodec(**compressor.model_dump(exclude={"name"})) + blosc_kwargs = compressor.model_dump(exclude={"name"}, mode="json") + if zarr.config.get("default_zarr_format") == ZarrFormat.V2: + blosc_kwargs["shuffle"] = -1 if blosc_kwargs["shuffle"] is None else blosc_kwargs["shuffle"] + return Blosc(**blosc_kwargs) + return BloscCodec(**blosc_kwargs) if isinstance(compressor, mdio_ZFP): if zfpy_ZFPY is None: @@ -215,12 +222,20 @@ def to_xarray_dataset(mdio_ds: Dataset) -> xr_Dataset: # noqa: PLR0912 if v.long_name: data_array.attrs["long_name"] = v.long_name + zarr_format = zarr.config.get("default_zarr_format") + fill_value_key = "_FillValue" if zarr_format == ZarrFormat.V2 else "fill_value" + fill_value = 
_get_fill_value(v.data_type) if v.name != "headers" else None + encoding = { "chunks": original_chunks, "compressor": _convert_compressor(v.compressor), - "fill_value": _get_fill_value(v.data_type), + fill_value_key: fill_value, } + if zarr_format == ZarrFormat.V2: + encoding["chunk_key_encoding"] = {"name": "v2", "configuration": {"separator": "/"}} + data_array.encoding = encoding # Let's store the data array for the second pass diff --git a/src/mdio/segy/blocked_io.py b/src/mdio/segy/blocked_io.py index db066b260..fbd7531ae 100644 --- a/src/mdio/segy/blocked_io.py +++ b/src/mdio/segy/blocked_io.py @@ -10,6 +10,7 @@ from typing import TYPE_CHECKING import numpy as np +import zarr from dask.array import Array from dask.array import map_blocks from psutil import cpu_count @@ -17,6 +18,7 @@ from zarr import open_group as zarr_open_group from mdio.api.io import _normalize_storage_options +from mdio.constants import ZarrFormat from mdio.core.indexing import ChunkIterator from mdio.schemas.v1.stats import CenteredBinHistogram from mdio.schemas.v1.stats import SummaryStatistics @@ -118,6 +120,9 @@ def to_zarr( # noqa: PLR0913, PLR0915 attr_json = final_stats.model_dump_json() zarr_group[data_variable_name].attrs.update({"statsV1": attr_json}) + if zarr.config.get("default_zarr_format") == ZarrFormat.V2: + zarr.consolidate_metadata(zarr_group.store) + return final_stats