Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/mdio/api/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
from typing import Any
from typing import Literal

import zarr
from upath import UPath
from xarray import Dataset as xr_Dataset
from xarray import open_zarr as xr_open_zarr
from xarray.backends.api import to_zarr as xr_to_zarr

from mdio.constants import ZarrFormat

if TYPE_CHECKING:
from collections.abc import Mapping
from pathlib import Path
Expand Down Expand Up @@ -47,7 +50,13 @@ def open_mdio(input_path: UPath | Path | str, chunks: T_Chunks = None) -> xr_Dat
"""
input_path = _normalize_path(input_path)
storage_options = _normalize_storage_options(input_path)
return xr_open_zarr(input_path.as_posix(), chunks=chunks, storage_options=storage_options)
zarr_format = zarr.config.get("default_zarr_format")
return xr_open_zarr(
input_path.as_posix(),
chunks=chunks,
storage_options=storage_options,
mask_and_scale=zarr_format == ZarrFormat.V3, # off for v2, on for v3
)


def to_mdio( # noqa: PLR0913
Expand All @@ -57,7 +66,6 @@ def to_mdio( # noqa: PLR0913
*,
compute: bool = True,
region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None = None,
zarr_format: int = 3,
) -> None:
"""Write dataset contents to an MDIO output_path.

Expand All @@ -74,18 +82,17 @@ def to_mdio( # noqa: PLR0913
can be computed to write array data later. Metadata is always updated eagerly.
region: Optional mapping from dimension names to either a) ``"auto"``, or b) integer slices, indicating
the region of existing MDIO array(s) in which to write this dataset's data.
zarr_format: The desired zarr format to target. The default is 3.
"""
output_path = _normalize_path(output_path)
storage_options = _normalize_storage_options(output_path)
zarr_format = zarr.config.get("default_zarr_format")
xr_to_zarr(
dataset,
store=output_path.as_posix(), # xarray doesn't like URI when file:// is protocol
mode=mode,
compute=compute,
consolidated=False,
consolidated=zarr_format == ZarrFormat.V2, # off for v3, on for v2
region=region,
storage_options=storage_options,
zarr_format=zarr_format,
write_empty_chunks=False,
)
10 changes: 10 additions & 0 deletions src/mdio/constants.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
"""Constant values used across MDIO."""

from enum import IntEnum

import numpy as np

from mdio.schemas.dtype import ScalarType


class ZarrFormat(IntEnum):
    """Zarr specification versions that MDIO can target.

    Integer-valued so members compare equal to the plain version
    numbers used by ``zarr.config`` (e.g. ``ZarrFormat.V3 == 3``).
    """

    V2 = 2
    V3 = 3


FLOAT16_MAX = np.finfo("float16").max
FLOAT16_MIN = np.finfo("float16").min

Expand Down
13 changes: 11 additions & 2 deletions src/mdio/core/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@

import numpy as np
import zarr
from numcodecs.zarr3 import Blosc
from zarr.codecs import BloscCodec

from mdio.constants import UINT32_MAX
from mdio.constants import ZarrFormat
from mdio.core.utils_write import get_constrained_chunksize

if TYPE_CHECKING:
Expand Down Expand Up @@ -108,8 +110,15 @@ def build_map(self, index_headers: HeaderArray) -> None:
dtype=map_dtype,
max_bytes=self._INTERNAL_CHUNK_SIZE_TARGET,
)
grid_compressor = BloscCodec(cname="zstd")
common_kwargs = {"shape": live_shape, "chunks": chunks, "compressors": grid_compressor, "store": None}

zarr_format = zarr.config.get("default_zarr_format")

common_kwargs = {"shape": live_shape, "chunks": chunks, "store": None}
if zarr_format == ZarrFormat.V2:
common_kwargs["compressors"] = Blosc(cname="zstd")
else:
common_kwargs["compressors"] = BloscCodec(cname="zstd")

self.map = zarr.create_array(fill_value=fill_value, dtype=map_dtype, **common_kwargs)
self.live_mask = zarr.create_array(fill_value=0, dtype=bool, **common_kwargs)

Expand Down
14 changes: 1 addition & 13 deletions src/mdio/core/utils_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,13 @@

if TYPE_CHECKING:
from numpy.typing import DTypeLike
from zarr import Group


MAX_SIZE_LIVE_MASK = 512 * 1024**2

JsonSerializable = str | int | float | bool | None | dict[str, "JsonSerializable"] | list["JsonSerializable"]


def write_attribute(name: str, attribute: JsonSerializable, zarr_group: "Group") -> None:
"""Write a mappable to Zarr array or group attribute.

Args:
name: Name of the attribute.
attribute: Mapping to write. Must be JSON serializable.
zarr_group: Output group or array.
"""
zarr_group.attrs[name] = attribute


def get_constrained_chunksize(
shape: tuple[int, ...],
dtype: "DTypeLike",
Expand All @@ -45,7 +33,7 @@ def get_constrained_chunksize(
return tuple(_balance_chunksizes(chunk)[0] for chunk in chunks)


def get_live_mask_chunksize(shape: tuple[int, ...]) -> tuple[int]:
def get_live_mask_chunksize(shape: tuple[int, ...]) -> tuple[int, ...]:
"""Given a live_mask shape, calculate the optimal write chunk size.

Args:
Expand Down
21 changes: 18 additions & 3 deletions src/mdio/schemas/v1/dataset_serializer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Convert MDIO v1 schema Dataset to Xarray DataSet and write it in Zarr."""

import numpy as np
import zarr
from dask import array as dask_array
from dask.array.core import normalize_chunks
from numcodecs import Blosc
from xarray import DataArray as xr_DataArray
from xarray import Dataset as xr_Dataset
from zarr.codecs import BloscCodec
Expand All @@ -16,6 +18,7 @@
except ImportError:
zfpy_ZFPY = None # noqa: N816

from mdio.constants import ZarrFormat
from mdio.constants import fill_value_map
from mdio.schemas.compressors import ZFP as mdio_ZFP # noqa: N811
from mdio.schemas.compressors import Blosc as mdio_Blosc
Expand Down Expand Up @@ -121,13 +124,17 @@ def _get_zarr_chunks(var: Variable, all_named_dims: dict[str, NamedDimension]) -

def _convert_compressor(
compressor: mdio_Blosc | mdio_ZFP | None,
) -> BloscCodec | zfpy_ZFPY | None:
) -> BloscCodec | Blosc | zfpy_ZFPY | None:
"""Convert a compressor to a numcodecs compatible format."""
if compressor is None:
return None

if isinstance(compressor, mdio_Blosc):
return BloscCodec(**compressor.model_dump(exclude={"name"}))
blosc_kwargs = compressor.model_dump(exclude={"name"}, mode="json")
if zarr.config.get("default_zarr_format") == ZarrFormat.V2:
blosc_kwargs["shuffle"] = -1 if blosc_kwargs["shuffle"] is None else blosc_kwargs["shuffle"]
return Blosc(**blosc_kwargs)
return BloscCodec(**blosc_kwargs)

if isinstance(compressor, mdio_ZFP):
if zfpy_ZFPY is None:
Expand Down Expand Up @@ -215,12 +222,20 @@ def to_xarray_dataset(mdio_ds: Dataset) -> xr_Dataset: # noqa: PLR0912
if v.long_name:
data_array.attrs["long_name"] = v.long_name

zarr_format = zarr.config.get("default_zarr_format")
fill_value_key = "_FillValue" if zarr_format == ZarrFormat.V2 else "fill_value"
fill_value = _get_fill_value(v.data_type) if v.name != "headers" else None

encoding = {
"chunks": original_chunks,
"compressor": _convert_compressor(v.compressor),
"fill_value": _get_fill_value(v.data_type),
fill_value_key: fill_value,
}

if zarr_format == ZarrFormat.V2:
encoding["chunk_key_encoding"] = {"name": "v2", "configuration": {"separator": "/"}}
print(encoding)

data_array.encoding = encoding

# Let's store the data array for the second pass
Expand Down
5 changes: 5 additions & 0 deletions src/mdio/segy/blocked_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@
from typing import TYPE_CHECKING

import numpy as np
import zarr
from dask.array import Array
from dask.array import map_blocks
from psutil import cpu_count
from tqdm.auto import tqdm
from zarr import open_group as zarr_open_group

from mdio.api.io import _normalize_storage_options
from mdio.constants import ZarrFormat
from mdio.core.indexing import ChunkIterator
from mdio.schemas.v1.stats import CenteredBinHistogram
from mdio.schemas.v1.stats import SummaryStatistics
Expand Down Expand Up @@ -118,6 +120,9 @@ def to_zarr( # noqa: PLR0913, PLR0915
attr_json = final_stats.model_dump_json()
zarr_group[data_variable_name].attrs.update({"statsV1": attr_json})

if zarr.config.get("default_zarr_format") == ZarrFormat.V2:
zarr.consolidate_metadata(zarr_group.store)

return final_stats


Expand Down
Loading