Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 72 additions & 1 deletion src/bloqade/geometry/dialects/grid/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import dataclasses
from functools import cached_property
from itertools import chain, product
from typing import Any, Generic, Sequence, TypeVar
from typing import Any, Generic, Literal, Sequence, TypeVar, overload

from kirin import ir, types
from kirin.dialects import ilist
Expand All @@ -11,6 +11,26 @@
NumY = TypeVar("NumY")


def get_indices(size: int, index: Any) -> ilist.IList[int, Any]:
if isinstance(index, slice):
return ilist.IList(range(size)[index])
elif isinstance(index, slice):
slice_value = index
return ilist.IList(range(size)[slice_value])
elif isinstance(index, int):
if index < 0:
index += size

if index < 0 or index >= size:
raise IndexError("Index out of range")

return ilist.IList([index])
elif isinstance(index, ilist.IList):
return index
else:
raise TypeError("Index must be an int, slice, or IList")


@dataclasses.dataclass
class Grid(ir.Data["Grid"], Generic[NumX, NumY]):
x_spacing: tuple[float, ...]
Expand Down Expand Up @@ -198,6 +218,57 @@ def get_view(
"""
return SubGrid(parent=self, x_indices=x_indices, y_indices=y_indices)

@overload
def __getitem__(
self, indices: tuple[int, int]
) -> "Grid[Literal[1], Literal[1]]": ...
@overload
def __getitem__(
self, indices: tuple[int, slice | list[int]]
) -> "Grid[Literal[1], Any]": ...

@overload
def __getitem__(
self, indices: tuple[int, ilist.IList[int, Ny]]
) -> "Grid[Literal[1], Ny]": ...
@overload
def __getitem__(
self, indices: tuple[slice | list[int], int]
) -> "Grid[Any, Literal[1]]": ...
@overload
def __getitem__(
self, indices: tuple[slice | list[int], slice]
) -> "Grid[Any, Any]": ...

@overload
def __getitem__(
self, indices: tuple[slice | list[int], ilist.IList[int, Ny]]
) -> "Grid[Any, Ny]": ...
@overload
def __getitem__(
self, indices: tuple[ilist.IList[int, Nx], int]
) -> "Grid[Nx, Literal[1]]": ...

@overload
def __getitem__(
self, indices: tuple[ilist.IList[int, Nx], slice | list[int]]
) -> "Grid[Nx, Any]": ...

@overload
def __getitem__(
self, indices: tuple[ilist.IList[int, Nx], ilist.IList[int, Ny]]
) -> "Grid[Nx, Ny]": ...

def __getitem__(self, indices):
if len(indices) != 2:
raise IndexError("Grid indexing requires two indices (x, y)")

x_index, y_index = indices
x_indices = get_indices(len(self.x_spacing) + 1, x_index)
y_indices = get_indices(len(self.y_spacing) + 1, y_index)

return self.get_view(x_indices=x_indices, y_indices=y_indices)

def __hash__(self) -> int:
return id(self)

Expand Down
Loading