Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ nav:
- grid:
- statements: reference/bloqade/geometry/dialects/grid/_interface.md
- types: reference/bloqade/geometry/dialects/grid/types.md

- filled:
- statements: reference/bloqade/geometry/dialects/filled/_interface.md
- types: reference/bloqade/geometry/dialects/filled/types.md


theme:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ authors = [
{ name = "Phillip Weinberg", email = "[email protected]" }
]
dependencies = [
"kirin-toolchain~=0.18.0",
"kirin-toolchain~=0.22.0",
]
requires-python = ">= 3.10"

Expand Down
8 changes: 7 additions & 1 deletion src/bloqade/geometry/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,7 @@
from .dialects import grid as grid
from .dialects.filled import _interface as filled
from .dialects.grid import _interface as grid

__all__ = [
"grid",
"filled",
]
9 changes: 9 additions & 0 deletions src/bloqade/geometry/dialects/filled/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from ._dialect import dialect as dialect
from ._interface import (
fill as fill,
get_parent as get_parent,
vacate as vacate,
)
from .concrete import FilledGridMethods as FilledGridMethods
from .stmts import Fill as Fill, GetParent as GetParent, Vacate as Vacate
from .types import FilledGrid as FilledGrid, FilledGridType as FilledGridType
3 changes: 3 additions & 0 deletions src/bloqade/geometry/dialects/filled/_dialect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from kirin import ir

dialect = ir.Dialect("geometry.filled")
62 changes: 62 additions & 0 deletions src/bloqade/geometry/dialects/filled/_interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from typing import Any, TypeVar

from kirin.dialects import ilist
from kirin.lowering import wraps as _wraps

from bloqade.geometry.dialects import grid

from .stmts import Fill, GetParent, Vacate
from .types import FilledGrid

Nx = TypeVar("Nx")
Ny = TypeVar("Ny")


@_wraps(Vacate)
def vacate(
zone: grid.Grid[Nx, Ny],
vacancies: ilist.IList[tuple[int, int], Any],
) -> FilledGrid[Nx, Ny]:
"""Create a FilledGrid by vacating specified positions from a grid.

Args:
zone: The original grid from which positions will be vacated.
vacancies: An IList of (x_index, y_index) tuples indicating positions to vacate

Returns:
A FilledGrid with the specified vacancies.

"""
...


@_wraps(Fill)
def fill(
zone: grid.Grid[Nx, Ny],
filled: ilist.IList[tuple[int, int], Any],
) -> FilledGrid[Nx, Ny]:
"""Create a FilledGrid by filling specified positions in a grid.

Args:
zone: The original grid in which positions will be filled.
filled: An IList of (x_index, y_index) tuples indicating positions to fill

Returns:
A FilledGrid with the specified positions filled.

"""
...


@_wraps(GetParent)
def get_parent(filled_grid: FilledGrid[Nx, Ny]) -> grid.Grid[Nx, Ny]:
"""Retrieve the parent grid of a FilledGrid.

Args:
filled_grid: The FilledGrid whose parent grid is to be retrieved.

Returns:
The parent grid of the provided FilledGrid.

"""
...
36 changes: 36 additions & 0 deletions src/bloqade/geometry/dialects/filled/concrete.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any

from kirin.dialects import ilist
from kirin.interp import (
Frame,
Interpreter,
MethodTable,
impl,
)

from bloqade.geometry.dialects.grid.types import Grid

from . import stmts
from ._dialect import dialect
from .types import FilledGrid


@dialect.register
class FilledGridMethods(MethodTable):

@impl(stmts.Vacate)
def vacate(self, interp: Interpreter, frame: Frame, stmt: stmts.Vacate):
zone = frame.get_casted(stmt.zone, Grid)
vacancies = frame.get_casted(stmt.vacancies, ilist.IList[tuple[int, int], Any])
return (FilledGrid.vacate(zone, vacancies),)

@impl(stmts.Fill)
def fill(self, interp: Interpreter, frame: Frame, stmt: stmts.Fill):
zone = frame.get_casted(stmt.zone, Grid)
filled = frame.get_casted(stmt.filled, ilist.IList[tuple[int, int], Any])
return (FilledGrid.fill(zone, filled),)

@impl(stmts.GetParent)
def get_parent(self, interp: Interpreter, frame: Frame, stmt: stmts.GetParent):
filled_grid = frame.get_casted(stmt.filled_grid, FilledGrid)
return (filled_grid.parent,)
43 changes: 43 additions & 0 deletions src/bloqade/geometry/dialects/filled/stmts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from kirin import decl, ir, lowering, types
from kirin.decl import info
from kirin.dialects import ilist

from bloqade.geometry.dialects import grid

from ._dialect import dialect
from .types import FilledGridType

NumVacant = types.TypeVar("NumVacant")
Nx = types.TypeVar("Nx")
Ny = types.TypeVar("Ny")


@decl.statement(dialect=dialect)
class Vacate(ir.Statement):
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})

zone: ir.SSAValue = info.argument(grid.GridType[Nx, Ny])
vacancies: ir.SSAValue = info.argument(
ilist.IListType[types.Tuple[types.Int, types.Int], NumVacant]
)
result: ir.ResultValue = info.result(FilledGridType[Nx, Ny])


@decl.statement(dialect=dialect)
class Fill(ir.Statement):
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})

zone: ir.SSAValue = info.argument(grid.GridType[Nx, Ny])
filled: ir.SSAValue = info.argument(
ilist.IListType[types.Tuple[types.Int, types.Int], NumVacant]
)
result: ir.ResultValue = info.result(FilledGridType[Nx, Ny])


@decl.statement(dialect=dialect)
class GetParent(ir.Statement):
name = "get_parent"
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})

filled_grid: ir.SSAValue = info.argument(FilledGridType[Nx, Ny])
result: ir.ResultValue = info.result(grid.GridType[Nx, Ny])
143 changes: 143 additions & 0 deletions src/bloqade/geometry/dialects/filled/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from dataclasses import dataclass, field
from functools import cached_property
from itertools import product
from typing import Any, Iterable, Sequence, TypeVar

from kirin import types
from kirin.dialects import ilist

from bloqade.geometry.dialects import grid

NumX = TypeVar("NumX")
NumY = TypeVar("NumY")


@dataclass(eq=False)
class FilledGrid(grid.Grid[NumX, NumY]):
x_spacing: tuple[float, ...] = field(init=False)
y_spacing: tuple[float, ...] = field(init=False)
x_init: float | None = field(init=False)
y_init: float | None = field(init=False)

parent: grid.Grid[NumX, NumY]
vacancies: frozenset[tuple[int, int]]

def __post_init__(self):
self.x_spacing = self.parent.x_spacing
self.y_spacing = self.parent.y_spacing
self.x_init = self.parent.x_init
self.y_init = self.parent.y_init

self.type = types.Generic(
FilledGrid,
types.Literal(len(self.x_spacing) + 1),
types.Literal(len(self.y_spacing) + 1),
)

def __hash__(self):
return hash((self.parent, self.vacancies))

def __eq__(self, other: Any) -> bool:
return (
isinstance(other, FilledGrid)
and self.parent == other.parent
and self.vacancies == other.vacancies
)

def is_equal(self, other: Any) -> bool:
return self == other

@cached_property
def positions(self) -> ilist.IList[tuple[float, float], Any]:
positions = tuple(
(x, y)
for (ix, x), (iy, y) in product(
enumerate(self.x_positions), enumerate(self.y_positions)
)
if (ix, iy) not in self.vacancies
)

return ilist.IList(positions)

@classmethod
def fill(
cls, grid_obj: grid.Grid[NumX, NumY], filled: Sequence[tuple[int, int]]
) -> "FilledGrid[NumX, NumY]":
num_x, num_y = grid_obj.shape

if isinstance(grid_obj, FilledGrid):
vacancies = grid_obj.vacancies
parent = grid_obj.parent
else:
vacancies = frozenset(product(range(num_x), range(num_y)))
parent = grid_obj

vacancies = vacancies - frozenset(filled)

return cls(parent=parent, vacancies=vacancies)

@classmethod
def vacate(
cls, grid_obj: grid.Grid[NumX, NumY], vacancies: Iterable[tuple[int, int]]
) -> "FilledGrid[NumX, NumY]":

if isinstance(grid_obj, FilledGrid):
input_vacancies = grid_obj.vacancies
parent = grid_obj.parent
else:
input_vacancies = frozenset()
parent = grid_obj

input_vacancies = input_vacancies.union(vacancies)

return cls(parent=parent, vacancies=input_vacancies)

def get_view( # type: ignore
self, x_indices: ilist.IList[int, Any], y_indices: ilist.IList[int, Any]
):
remapping_x = {ix: i for i, ix in enumerate(x_indices)}
remapping_y = {iy: i for i, iy in enumerate(y_indices)}
return FilledGrid(
parent=self.parent.get_view(x_indices, y_indices),
vacancies=frozenset(
(remapping_x[x], remapping_y[y])
for x, y in self.vacancies
if x in remapping_x and y in remapping_y
),
)

def shift(self, x_shift: float, y_shift: float):
return FilledGrid(
parent=self.parent.shift(x_shift, y_shift),
vacancies=self.vacancies,
)

def scale(self, x_scale: float, y_scale: float):
return FilledGrid(
parent=self.parent.scale(x_scale, y_scale),
vacancies=self.vacancies,
)

def repeat(self, x_times: int, y_times: int, x_gap: float, y_gap: float):
new_parent = self.parent.repeat(x_times, y_times, x_gap, y_gap)
x_dim, y_dim = self.shape
vacancies = frozenset(
(x + x_dim * i, y + y_dim * j)
for i, j, (x, y) in product(range(x_times), range(y_times), self.vacancies)
)
return FilledGrid.vacate(new_parent, vacancies)

def row_x_pos(self, row_index: int):
x_vacancies = {x for x, y in self.vacancies if y == row_index}
return ilist.IList(
[x for i, x in enumerate(self.parent.x_positions) if i not in x_vacancies]
)

def col_y_pos(self, column_index: int):
y_vacancies = {y for x, y in self.vacancies if x == column_index}
return ilist.IList(
[y for i, y in enumerate(self.y_positions) if i not in y_vacancies]
)


FilledGridType = types.Generic(FilledGrid, types.TypeVar("NumX"), types.TypeVar("NumY"))
32 changes: 32 additions & 0 deletions src/bloqade/geometry/dialects/grid/_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from kirin.lowering import wraps as _wraps

from .stmts import (
ColYPos,
FromPositions,
Get,
GetSubGrid,
Expand All @@ -14,6 +15,7 @@
New,
Positions,
Repeat,
RowXPos,
Scale,
Shape,
Shift,
Expand Down Expand Up @@ -275,3 +277,33 @@ def shape(grid: Grid) -> tuple[int, int]:
tuple[int, int]: a tuple of (num_x, num_y)
"""
...


@_wraps(RowXPos)
def row_xpos(
grid: Grid[typing.Any, typing.Any], row_index: int
) -> ilist.IList[float, typing.Any]:
"""Get the x positions of a specific row in the grid.

Args:
grid (Grid): a grid object
row_index (int): the index of the row to get x positions for
Returns:
ilist.IList[float, typing.Any]: a list of x positions for the specified row
"""
...


@_wraps(ColYPos)
def col_ypos(
grid: Grid[typing.Any, typing.Any], column_index: int
) -> ilist.IList[float, typing.Any]:
"""Get the y positions of a specific column in the grid.

Args:
grid (Grid): a grid object
column_index (int | None): the index of the column to get y positions for, or None for all columns
Returns:
ilist.IList[float, typing.Any]: a list of y positions for the specified column
"""
...
Loading
Loading