Skip to content
Merged
Prev Previous commit
Next Next commit
fixup
  • Loading branch information
TomAugspurger committed Sep 18, 2024
commit 0ea04afa6a92c13450776819b2ad9ca2336e7465
28 changes: 21 additions & 7 deletions src/zarr/store/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
from pathlib import Path
from typing import TYPE_CHECKING, Any, Literal

import fsspec
import fsspec.implementations

from zarr.abc.store import AccessMode, Store
from zarr.core.buffer import Buffer, default_buffer_prototype
from zarr.core.common import ZARR_JSON, ZARRAY_JSON, ZGROUP_JSON, ZarrFormat
Expand Down Expand Up @@ -104,10 +101,11 @@ async def make_store_path(
result = StorePath(await LocalStore.open(root=store_like, mode=mode or "r"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we can pass **storage_options to LocalStore as well

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could be missing something, but I don't think that'll work. LocalStore.open will call LocalStore.__init__, which just takes root and mode, which are passed as regular args here.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see what you mean. I was thinking auto_mkdir would be passed through but if that's not the case, let's not get distracted here.

elif isinstance(store_like, str):
storage_options = storage_options or {}
fs, _ = fsspec.url_to_fs(store_like, **storage_options)
if "file" not in fs.protocol:
storage_options = storage_options or {}
result = StorePath(RemoteStore(url=store_like, mode=mode or "r", **storage_options))

if _is_fsspec_uri(store_like):
result = StorePath(
RemoteStore.from_url(store_like, storage_options=storage_options, mode=mode or "r")
)
else:
result = StorePath(await LocalStore.open(root=Path(store_like), mode=mode or "r"))
elif isinstance(store_like, dict):
Expand All @@ -120,6 +118,22 @@ async def make_store_path(
return result


def _is_fsspec_uri(uri: str) -> bool:
"""
Check if a URI looks like a non-local fsspec URI.

Examples
--------
>>> _is_fsspec_uri("s3://bucket")
True
>>> _is_fsspec_uri("my-directory")
False
>>> _is_fsspec_uri("local://my-directory")
False
"""
return "://" in uri or "::" in uri and "local://" not in uri


async def ensure_no_existing_node(store_path: StorePath, zarr_format: ZarrFormat) -> None:
"""
Check if a store_path is safe for array / group creation.
Expand Down
8 changes: 1 addition & 7 deletions src/zarr/store/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ def __init__(
fs: AsyncFileSystem,
mode: AccessModeLiteral = "r",
path: str = "/",
# url: UPath | str,
allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS,
):
"""
Expand Down Expand Up @@ -82,7 +81,7 @@ def from_url(
storage_options: dict[str, Any] | None = None,
mode: AccessModeLiteral = "r",
allowed_exceptions: tuple[type[Exception], ...] = ALLOWED_EXCEPTIONS,
):
) -> RemoteStore:
fs, path = fsspec.url_to_fs(url, **storage_options)
return cls(fs=fs, path=path, mode=mode, allowed_exceptions=allowed_exceptions)

Expand All @@ -97,9 +96,6 @@ async def clear(self) -> None:
async def empty(self) -> bool:
    """
    Report whether nothing is found under this store's path.

    Awaits the filesystem's ``_find`` on ``self.path`` with
    ``withdirs=True`` and returns True when the result is falsy.
    """
    found = await self.fs._find(self.path, withdirs=True)
    return not found

# def __str__(self) -> str:
# return f"RemoteStore<fs={self.fs} mode={self._mode}>"

def __repr__(self) -> str:
return f"<RemoteStore({type(self.fs).__name__}, {self.path})>"

Expand All @@ -109,8 +105,6 @@ def __eq__(self, other: object) -> bool:
and self.path == other.path
and self.mode == other.mode
and self.fs == other.fs
# and self._url == other._url
# and self._storage_options == other._storage_options # FIXME: this isn't working for some reason
)

async def get(
Expand Down
26 changes: 10 additions & 16 deletions tests/v3/test_store/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import botocore.client
import fsspec
import pytest
from upath import UPath

import zarr.api.asynchronous
from zarr.core.buffer import Buffer, cpu, default_buffer_prototype
Expand Down Expand Up @@ -117,18 +118,6 @@ def store_kwargs(self, request) -> dict[str, str | bool]:
def store(self, store_kwargs: dict[str, str | bool]) -> RemoteStore:
return self.store_cls(**store_kwargs)

# url = store_kwargs["url"]
# mode = store_kwargs["mode"]
# if isinstance(url, UPath):
# out = self.store_cls.from_upath(url, mode=mode)
# else:
# storage_options = {
# "anon": store_kwargs["anon"],
# "endpoint_url": store_kwargs["endpoint_url"]
# }
# out = self.store_cls.from_url(url=url, storage_options=storage_options, mode=mode)
# return out

def get(self, store: RemoteStore, key: str) -> Buffer:
# make a new, synchronous instance of the filesystem because this test is run in sync code
new_fs = fsspec.filesystem(
Expand All @@ -144,7 +133,7 @@ def set(self, store: RemoteStore, key: str, value: Buffer) -> None:
new_fs.write_bytes(f"{store.path}/{key}", value.to_bytes())

def test_store_repr(self, store: RemoteStore) -> None:
assert str(store) == f"s3://{test_bucket_name}"
assert str(store) == "<RemoteStore(S3FileSystem, test)>"

def test_store_supports_writes(self, store: RemoteStore) -> None:
assert True
Expand All @@ -171,7 +160,7 @@ async def test_remote_store_from_uri(
self.buffer_cls.from_bytes(json.dumps(meta).encode()),
)
group = await zarr.api.asynchronous.open_group(
store=store._url, storage_options=storage_options
store=f"s3://{test_bucket_name}", storage_options=storage_options
)
assert dict(group.attrs) == {"key": "value"}

Expand All @@ -181,7 +170,7 @@ async def test_remote_store_from_uri(
self.buffer_cls.from_bytes(json.dumps(meta).encode()),
)
group = await zarr.api.asynchronous.open_group(
store="/".join([store._url.rstrip("/"), "directory-2"]), storage_options=storage_options
store=f"s3://{test_bucket_name}/directory-2", storage_options=storage_options
)
assert dict(group.attrs) == {"key": "value-2"}

Expand All @@ -191,6 +180,11 @@ async def test_remote_store_from_uri(
self.buffer_cls.from_bytes(json.dumps(meta).encode()),
)
group = await zarr.api.asynchronous.open_group(
store=store._url, path="directory-3", storage_options=storage_options
store=f"s3://{test_bucket_name}", path="directory-3", storage_options=storage_options
)
assert dict(group.attrs) == {"key": "value-3"}

def test_from_upath(self) -> None:
    """
    Build a RemoteStore from a UPath and check that the ``endpoint_url``
    given to the UPath is the one exposed by the store's filesystem.
    """
    upath = UPath(f"s3://{test_bucket_name}", endpoint_url=endpoint_url, anon=False)
    store = RemoteStore.from_upath(upath)
    assert store.fs.endpoint_url == endpoint_url