Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixup
  • Loading branch information
TomAugspurger committed Sep 26, 2024
commit 926c71acc9e8886581540088d6f457997181d8d6
39 changes: 29 additions & 10 deletions src/zarr/store/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@
from zarr.core.common import AccessModeLiteral


# T = TypeVar("T", bound=Buffer | gpu.Buffer)


# class _MemoryStore


# TODO: this store could easily be extended to wrap any MutableMapping store from v2
# When that is done, the `MemoryStore` will just be a store that wraps a dict.
class MemoryStore(Store):
Expand Down Expand Up @@ -163,9 +157,13 @@ async def list_dir(self, prefix: str) -> AsyncGenerator[str, None]:

class GpuMemoryStore(MemoryStore):
"""A GPU only memory store that stores every chunk in GPU memory irrespective
of the original location. This guarantees that chunks will always be in GPU
memory for downstream processing. For location agnostic use cases, it would
be better to use `MemoryStore` instead.
of the original location.

The dictionary of buffers to initialize this memory store with *must* be
GPU Buffers.

Writing data to this store through ``.set`` will move the buffer to the GPU
if necessary.

Parameters
----------
Expand All @@ -174,7 +172,7 @@ class GpuMemoryStore(MemoryStore):
values.
"""

_store_dict: MutableMapping[str, Buffer]
_store_dict: MutableMapping[str, gpu.Buffer] # type: ignore[assignment]

def __init__(
self,
Expand All @@ -190,6 +188,27 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return f"GpuMemoryStore({str(self)!r})"

@classmethod
def from_dict(cls, store_dict: MutableMapping[str, Buffer]) -> Self:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this API!

"""
Create a GpuMemoryStore from a dictionary of buffers at any location.

The dictionary backing the newly created ``GpuMemoryStore`` will not be
the same as ``store_dict``.

Parameters
----------
store_dict: mapping
A mapping of strings keys to arbitrary Buffers. The buffer data
will be moved into a :class:`gpu.Buffer`.

Returns
-------
GpuMemoryStore
"""
gpu_store_dict = {k: gpu.Buffer.from_buffer(v) for k, v in store_dict.items()}
return cls(gpu_store_dict)

async def set(self, key: str, value: Buffer, byte_range: tuple[int, int] | None = None) -> None:
self._check_writable()
assert isinstance(key, str)
Expand Down
22 changes: 16 additions & 6 deletions src/zarr/testing/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,13 +267,23 @@ async def test_with_mode(self, store: S) -> None:
assert isinstance(clone, type(store))

# earlier writes are visible
assert self.get(clone, "key").to_bytes() == data
result = await clone.get("key", default_buffer_prototype())
assert result is not None
assert result.to_bytes() == data

# writes to original after with_mode is visible
# # writes to original after with_mode is visible
self.set(store, "key-2", self.buffer_cls.from_bytes(data))
assert self.get(clone, "key-2").to_bytes() == data
result = await clone.get("key-2", default_buffer_prototype())
assert result is not None
assert result.to_bytes() == data

if mode == "w":
if mode == "a":
# writes to clone is visible in the original
self.set(store, "key-3", self.buffer_cls.from_bytes(data))
assert self.get(clone, "key-3").to_bytes() == data
await clone.set("key-3", self.buffer_cls.from_bytes(data))
result = await clone.get("key-3", default_buffer_prototype())
assert result is not None
assert result.to_bytes() == data

else:
with pytest.raises(ValueError):
await clone.set("key-3", self.buffer_cls.from_bytes(data))
9 changes: 9 additions & 0 deletions tests/v3/test_store/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,3 +90,12 @@ def test_dict_reference(self, store: GpuMemoryStore) -> None:
store_dict = {}
result = GpuMemoryStore(store_dict=store_dict)
assert result._store_dict is store_dict

def test_from_dict(self):
d = {
"a": gpu.Buffer.from_bytes(b"aaaa"),
"b": cpu.Buffer.from_bytes(b"bbbb"),
}
result = GpuMemoryStore.from_dict(d)
for v in result._store_dict.values():
assert type(v) is gpu.Buffer