diff --git a/src/bloqade/geometry/dialects/grid/types.py b/src/bloqade/geometry/dialects/grid/types.py index 0d2ef3f..34b5328 100644 --- a/src/bloqade/geometry/dialects/grid/types.py +++ b/src/bloqade/geometry/dialects/grid/types.py @@ -1,7 +1,7 @@ import dataclasses from functools import cached_property from itertools import chain, product -from typing import Any, Generic, Sequence, TypeVar +from typing import Any, Generic, Literal, Sequence, TypeVar, overload from kirin import ir, types from kirin.dialects import ilist @@ -11,6 +11,26 @@ NumY = TypeVar("NumY") +def get_indices(size: int, index: Any) -> ilist.IList[int, Any]: + if isinstance(index, slice): + return ilist.IList(range(size)[index]) + elif isinstance(index, slice): + slice_value = index + return ilist.IList(range(size)[slice_value]) + elif isinstance(index, int): + if index < 0: + index += size + + if index < 0 or index >= size: + raise IndexError("Index out of range") + + return ilist.IList([index]) + elif isinstance(index, ilist.IList): + return index + else: + raise TypeError("Index must be an int, slice, or IList") + + @dataclasses.dataclass class Grid(ir.Data["Grid"], Generic[NumX, NumY]): x_spacing: tuple[float, ...] @@ -198,6 +218,57 @@ def get_view( """ return SubGrid(parent=self, x_indices=x_indices, y_indices=y_indices) + @overload + def __getitem__( + self, indices: tuple[int, int] + ) -> "Grid[Literal[1], Literal[1]]": ... + @overload + def __getitem__( + self, indices: tuple[int, slice | list[int]] + ) -> "Grid[Literal[1], Any]": ... + + @overload + def __getitem__( + self, indices: tuple[int, ilist.IList[int, Ny]] + ) -> "Grid[Literal[1], Ny]": ... + @overload + def __getitem__( + self, indices: tuple[slice | list[int], int] + ) -> "Grid[Any, Literal[1]]": ... + @overload + def __getitem__( + self, indices: tuple[slice | list[int], slice] + ) -> "Grid[Any, Any]": ... + + @overload + def __getitem__( + self, indices: tuple[slice | list[int], ilist.IList[int, Ny]] + ) -> "Grid[Any, Ny]": ... + @overload + def __getitem__( + self, indices: tuple[ilist.IList[int, Nx], int] + ) -> "Grid[Nx, Literal[1]]": ... + + @overload + def __getitem__( + self, indices: tuple[ilist.IList[int, Nx], slice | list[int]] + ) -> "Grid[Nx, Any]": ... + + @overload + def __getitem__( + self, indices: tuple[ilist.IList[int, Nx], ilist.IList[int, Ny]] + ) -> "Grid[Nx, Ny]": ... + + def __getitem__(self, indices): + if len(indices) != 2: + raise IndexError("Grid indexing requires two indices (x, y)") + + x_index, y_index = indices + x_indices = get_indices(len(self.x_spacing) + 1, x_index) + y_indices = get_indices(len(self.y_spacing) + 1, y_index) + + return self.get_view(x_indices=x_indices, y_indices=y_indices) + def __hash__(self) -> int: return id(self)