Skip to content

Commit f056c41

Browse files
committed
Adding statement that generates all positions of the grid (#16)
* Adding statement that generates all positions of the grid * updating pre-commit * run linter
1 parent 9d37b57 commit f056c41

File tree

6 files changed

+45
-2
lines changed

6 files changed

+45
-2
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ repos:
1414
- id: isort
1515
name: isort (python)
1616
- repo: https://github.com/psf/black
17-
rev: 24.3.0
17+
rev: 25.1.0
1818
hooks:
1919
- id: black
2020
- repo: https://github.com/charliermarsh/ruff-pre-commit

src/bloqade/geometry/dialects/grid/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
get_xpos as get_xpos,
66
get_ypos as get_ypos,
77
new as new,
8+
positions as positions,
89
repeat as repeat,
910
scale as scale,
1011
shape as shape,
@@ -25,6 +26,7 @@
2526
GetYBounds as GetYBounds,
2627
GetYPos as GetYPos,
2728
New as New,
29+
Positions as Positions,
2830
Repeat as Repeat,
2931
Scale as Scale,
3032
Shape as Shape,

src/bloqade/geometry/dialects/grid/_interface.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
GetYBounds,
1313
GetYPos,
1414
New,
15+
Positions,
1516
Repeat,
1617
Scale,
1718
Shape,
@@ -164,6 +165,22 @@ def y_bounds(grid: Grid[typing.Any, typing.Any]) -> tuple[float, float]:
164165
...
165166

166167

168+
@_wraps(Positions)
169+
def positions(
170+
grid: Grid[typing.Any, typing.Any],
171+
) -> ilist.IList[tuple[float, float], typing.Any]:
172+
"""Get the positions of a grid as a list of (x, y) tuples.
173+
174+
Args:
175+
grid (Grid): a grid object
176+
177+
Returns:
178+
ilist.IList[tuple[float, float], typing.Any]: a list of (x, y) tuples representing the positions of the grid points
179+
180+
"""
181+
...
182+
183+
167184
@_wraps(Repeat)
168185
def repeat(
169186
grid: Grid, x_times: int, y_times: int, x_spacing: float, y_spacing: float

src/bloqade/geometry/dialects/grid/concrete.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,13 @@ def repeat(
181181
y_gap = frame.get_casted(stmt.y_gap, float)
182182

183183
return (grid.repeat(x_times, y_times, x_gap, y_gap),)
184+
185+
@impl(stmts.Positions)
186+
def positions(
187+
self,
188+
interp: Interpreter,
189+
frame: Frame,
190+
stmt: stmts.Positions,
191+
):
192+
grid = frame.get_casted(stmt.zone, Grid)
193+
return (grid.positions,)

src/bloqade/geometry/dialects/grid/stmts.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,16 @@ class New(ir.Statement):
5252
)
5353

5454

55+
@statement(dialect=dialect)
56+
class Positions(ir.Statement):
57+
name = "positions"
58+
traits = frozenset({ir.Pure(), lowering.FromPythonCall()})
59+
zone: ir.SSAValue = info.argument(type=GridType[types.Any, types.Any])
60+
result: ir.ResultValue = info.result(
61+
ilist.IListType[types.Tuple[types.Float, types.Float], types.Any]
62+
)
63+
64+
5565
# Maybe do this with hints?
5666
@statement(dialect=dialect)
5767
class Shape(ir.Statement):

src/bloqade/geometry/dialects/grid/types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import dataclasses
22
from functools import cached_property
3-
from itertools import chain
3+
from itertools import chain, product
44
from typing import Any, Generic, Sequence, TypeVar
55

66
from kirin import ir, types
@@ -118,6 +118,10 @@ def y_positions(self) -> tuple[float, ...]:
118118
)
119119
)
120120

121+
@cached_property
122+
def positions(self) -> ilist.IList[tuple[float, float], Any]:
123+
return ilist.IList(tuple(product(self.x_positions, self.y_positions)))
124+
121125
def get(self, idx: tuple[int, int]) -> tuple[float, float]:
122126
return (self.x_positions[idx[0]], self.y_positions[idx[1]])
123127

0 commit comments

Comments
 (0)