-
Notifications
You must be signed in to change notification settings - Fork 1
QASM2 to squin #644
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
QASM2 to squin #644
Changes from all commits
2f6e8a1
90d8578
509b9cc
006be62
1fa9865
6552782
d4f46bb
6b9e52f
cd53e5d
59dfcb0
99e66af
045fa00
5137de0
6390a23
0645d11
df35694
3468702
5f137de
aea7567
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| from .qasm2_to_squin import QASM2ToSquin as QASM2ToSquin |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| from kirin import ir, passes | ||
| from kirin.rewrite import Walk, Chain | ||
| from kirin.dialects import func | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade.rewrite.passes import CallGraphPass | ||
| from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule | ||
|
|
||
| from ..rewrite import qasm2 as qasm2_rule | ||
|
|
||
|
|
||
| class QASM2GateFuncToKirinFunc(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
| from bloqade.qasm2.dialects.expr.stmts import GateFunction | ||
|
|
||
| if not isinstance(node, GateFunction): | ||
| return RewriteResult() | ||
|
|
||
| kirin_func = func.Function( | ||
| sym_name=node.sym_name, | ||
| signature=node.signature, | ||
| body=node.body, | ||
| slots=node.slots, | ||
| ) | ||
| node.replace_by(kirin_func) | ||
|
|
||
| return RewriteResult(has_done_something=True) | ||
|
|
||
|
|
||
| class QASM2GateFuncToSquinPass(passes.Pass): | ||
|
|
||
| def unsafe_run(self, mt: ir.Method) -> RewriteResult: | ||
| convert_to_kirin_func = CallGraphPass( | ||
| dialects=mt.dialects, rule=Walk(QASM2GateFuncToKirinFunc()) | ||
| ) | ||
| rewrite_result = convert_to_kirin_func(mt) | ||
|
|
||
| combined_qasm2_rules = Walk( | ||
| Chain( | ||
| QASM2ToPyRule(), | ||
| qasm2_rule.QASM2CoreToSquin(), | ||
| qasm2_rule.QASM2GlobParallelToSquin(), | ||
| qasm2_rule.QASM2NoiseToSquin(), | ||
| qasm2_rule.QASM2IdToSquin(), | ||
| qasm2_rule.QASM2UOp1QToSquin(), | ||
| qasm2_rule.QASM2ParametrizedUOp1QToSquin(), | ||
| qasm2_rule.QASM2UOp2QToSquin(), | ||
| ) | ||
| ) | ||
|
|
||
| body_conversion_pass = CallGraphPass( | ||
| dialects=mt.dialects, rule=combined_qasm2_rules | ||
| ) | ||
| rewrite_result = body_conversion_pass(mt).join(rewrite_result) | ||
|
|
||
| return rewrite_result |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,66 @@ | ||
| from dataclasses import dataclass | ||
|
|
||
| from kirin import ir | ||
| from kirin.passes import Fold, Pass, TypeInfer | ||
| from kirin.rewrite import Walk, Chain | ||
| from kirin.rewrite.abc import RewriteResult | ||
| from kirin.dialects.ilist.passes import IListDesugar | ||
|
|
||
| from bloqade import squin | ||
| from bloqade.squin.rewrite.qasm2 import ( | ||
| QASM2IdToSquin, | ||
| QASM2CoreToSquin, | ||
| QASM2NoiseToSquin, | ||
| QASM2UOp1QToSquin, | ||
| QASM2UOp2QToSquin, | ||
| QASM2GlobParallelToSquin, | ||
| QASM2ParametrizedUOp1QToSquin, | ||
| ) | ||
|
|
||
| # There's a QASM2Py pass that only applies an _QASM2Py rewrite rule, | ||
| # I just want the rule here. | ||
| from bloqade.qasm2.passes.qasm2py import _QASM2Py as QASM2ToPyRule | ||
|
|
||
| from .qasm2_gate_func_to_squin import QASM2GateFuncToSquinPass | ||
|
|
||
|
|
||
| @dataclass | ||
| class QASM2ToSquin(Pass): | ||
|
|
||
| def unsafe_run(self, mt: ir.Method) -> RewriteResult: | ||
|
|
||
| # rewrite all QASM2 to squin first | ||
| rewrite_result = Walk( | ||
| Chain( | ||
| QASM2ToPyRule(), | ||
| QASM2CoreToSquin(), | ||
| QASM2GlobParallelToSquin(), | ||
| QASM2NoiseToSquin(), | ||
| QASM2IdToSquin(), | ||
| QASM2UOp1QToSquin(), | ||
| QASM2ParametrizedUOp1QToSquin(), | ||
| QASM2UOp2QToSquin(), | ||
| ) | ||
| ).rewrite(mt.code) | ||
|
|
||
| # go into subkernels | ||
| rewrite_result = ( | ||
| QASM2GateFuncToSquinPass(dialects=mt.dialects) | ||
| .unsafe_run(mt) | ||
| .join(rewrite_result) | ||
| ) | ||
|
|
||
| # kernel should be entirely in squin dialect now | ||
| mt.dialects = squin.kernel | ||
|
|
||
| # the rest is taken from the squin kernel | ||
| rewrite_result = Fold(dialects=mt.dialects).fixpoint(mt) | ||
| rewrite_result = ( | ||
| TypeInfer(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) | ||
| ) | ||
| rewrite_result = ( | ||
| IListDesugar(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) | ||
| ) | ||
| TypeInfer(dialects=mt.dialects).unsafe_run(mt).join(rewrite_result) | ||
|
|
||
| return rewrite_result | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,9 @@ | ||
| from .id_to_squin import QASM2IdToSquin as QASM2IdToSquin | ||
| from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin | ||
| from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin | ||
| from .uop_1q_to_squin import QASM2UOp1QToSquin as QASM2UOp1QToSquin | ||
| from .uop_2q_to_squin import QASM2UOp2QToSquin as QASM2UOp2QToSquin | ||
| from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin | ||
| from .parametrized_uop_1q_to_squin import ( | ||
| QASM2ParametrizedUOp1QToSquin as QASM2ParametrizedUOp1QToSquin, | ||
| ) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,38 @@ | ||
| from kirin import ir | ||
| from kirin.dialects import py, func | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade import squin | ||
| from bloqade.qasm2.dialects.core import stmts as core_stmts | ||
|
|
||
| CORE_TO_SQUIN_MAP = { | ||
| core_stmts.QRegNew: squin.qubit.qalloc, | ||
| core_stmts.Reset: squin.qubit.reset, | ||
| } | ||
|
|
||
|
|
||
| class QASM2CoreToSquin(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
|
|
||
| if isinstance(node, core_stmts.QRegGet): | ||
| py_get_item = py.GetItem( | ||
| obj=node.reg, | ||
| index=node.idx, | ||
| ) | ||
| node.replace_by(py_get_item) | ||
| return RewriteResult(has_done_something=True) | ||
|
|
||
| if isinstance(node, core_stmts.QRegNew): | ||
| args = (node.n_qubits,) | ||
| elif isinstance(node, core_stmts.Reset): | ||
| args = (node.qarg,) | ||
| else: | ||
| return RewriteResult() | ||
|
|
||
| new_stmt = func.Invoke( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I feel like you could just do this within the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. something like if isinstance(node, (core_stmts.QRegGet, core_stmts.QRegNew, core_stmts.Reset):
node.replace_by(func.Invoke(node.args, callee = callees[type(note))
else:
return RewriteResult() |
||
| callee=CORE_TO_SQUIN_MAP[type(node)], | ||
johnzl-777 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| inputs=args, | ||
| ) | ||
| node.replace_by(new_stmt) | ||
| return RewriteResult(has_done_something=True) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| from kirin import ir | ||
| from kirin.dialects import func | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade import squin | ||
| from bloqade.qasm2.dialects import glob, parallel | ||
|
|
||
| GLOBAL_PARALLEL_TO_SQUIN_MAP = { | ||
| glob.UGate: squin.broadcast.u3, | ||
| parallel.UGate: squin.broadcast.u3, | ||
| parallel.RZ: squin.broadcast.rz, | ||
| } | ||
|
|
||
|
|
||
| class QASM2GlobParallelToSquin(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
|
|
||
| if isinstance(node, glob.UGate): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, this seems a bit redundant: you might as well just assign
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The dictionary was something that @weinbe58 recommended, I could be doing the pattern wrong here but in cases where I do see a dictionary used it only turns out nice if the attribute you're accessing exits across all the statements. Like for arithmetic operation conversion, you'll always have an I actually realize if I wanted to be clever I could add some strings to the values in the dictionary and then Wish I could still do pattern matching but I'm told the performance would take a hit
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you do not need to unpack the arguments, unless they do not match the signature of the standard library function in squin but I think we have been consistent on this front so you should not need to do this @johnzl-777 see my comment above. |
||
| args = (node.theta, node.phi, node.lam, node.registers) | ||
| elif isinstance(node, parallel.UGate): | ||
| args = (node.theta, node.phi, node.lam, node.qargs) | ||
| elif isinstance(node, parallel.RZ): | ||
| args = (node.theta, node.qargs) | ||
| else: | ||
| return RewriteResult() | ||
|
|
||
| squin_equivalent_stmt = GLOBAL_PARALLEL_TO_SQUIN_MAP[type(node)] | ||
| invoke_stmt = func.Invoke( | ||
| callee=squin_equivalent_stmt, | ||
| inputs=args, | ||
| ) | ||
| node.replace_by(invoke_stmt) | ||
| return RewriteResult(has_done_something=True) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
| from kirin import ir | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| import bloqade.qasm2.dialects.uop.stmts as uop_stmts | ||
|
|
||
|
|
||
| class QASM2IdToSquin(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
|
|
||
| if not isinstance(node, uop_stmts.Id): | ||
| return RewriteResult() | ||
|
|
||
| node.delete() | ||
| return RewriteResult(has_done_something=True) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,91 @@ | ||
| from kirin import ir | ||
| from kirin.dialects import py, func | ||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||
|
|
||
| from bloqade import squin | ||
| from bloqade.qasm2.dialects.noise import stmts as noise_stmts | ||
|
|
||
| NOISE_TO_SQUIN_MAP = { | ||
| noise_stmts.AtomLossChannel: squin.broadcast.qubit_loss, | ||
| noise_stmts.PauliChannel: squin.broadcast.single_qubit_pauli_channel, | ||
| } | ||
|
|
||
|
|
||
| def num_to_py_constant( | ||
| values: list[int | float], stmt_to_insert_before: ir.Statement | ||
| ) -> list[ir.SSAValue]: | ||
|
|
||
| py_const_ssa_vals = [] | ||
| for value in values: | ||
| const_form = py.Constant(value=value) | ||
| const_form.insert_before(stmt_to_insert_before) | ||
| py_const_ssa_vals.append(const_form.result) | ||
|
|
||
| return py_const_ssa_vals | ||
|
|
||
|
|
||
| class QASM2NoiseToSquin(RewriteRule): | ||
|
|
||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||
|
|
||
| if isinstance(node, noise_stmts.AtomLossChannel): | ||
| qargs = node.qargs | ||
| prob = node.prob | ||
| prob_ssas = num_to_py_constant([prob], stmt_to_insert_before=node) | ||
| elif isinstance(node, noise_stmts.PauliChannel): | ||
| qargs = node.qargs | ||
| p_x = node.px | ||
| p_y = node.py | ||
| p_z = node.pz | ||
| prob_ssas = num_to_py_constant([p_x, p_y, p_z], stmt_to_insert_before=node) | ||
| elif isinstance(node, noise_stmts.CZPauliChannel): | ||
| return self.rewrite_CZPauliChannel(node) | ||
| else: | ||
| return RewriteResult() | ||
|
|
||
| squin_noise_stmt = NOISE_TO_SQUIN_MAP[type(node)] | ||
| invoke_stmt = func.Invoke( | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If you know that you are going for the broadcast version, why not just rewrite to the statement directly instead of adding an invoke to the stdlib?
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see the comment above |
||
| callee=squin_noise_stmt, | ||
| inputs=(*prob_ssas, qargs), | ||
| ) | ||
| node.replace_by(invoke_stmt) | ||
| return RewriteResult(has_done_something=True) | ||
|
|
||
| def rewrite_CZPauliChannel(self, stmt: noise_stmts.CZPauliChannel) -> RewriteResult: | ||
|
|
||
| ctrls = stmt.ctrls | ||
| qargs = stmt.qargs | ||
|
|
||
| px_ctrl = stmt.px_ctrl | ||
| py_ctrl = stmt.py_ctrl | ||
| pz_ctrl = stmt.pz_ctrl | ||
| px_qarg = stmt.px_qarg | ||
| py_qarg = stmt.py_qarg | ||
| pz_qarg = stmt.pz_qarg | ||
|
|
||
| error_probs = [px_ctrl, py_ctrl, pz_ctrl, px_qarg, py_qarg, pz_qarg] | ||
| # first half of entries for control qubits, other half for targets | ||
|
|
||
| error_prob_ssas = num_to_py_constant(error_probs, stmt_to_insert_before=stmt) | ||
|
|
||
| ctrl_pauli_channel_invoke = func.Invoke( | ||
| callee=squin.broadcast.single_qubit_pauli_channel, | ||
| inputs=( | ||
| *error_prob_ssas[:3], | ||
| ctrls, | ||
| ), | ||
| ) | ||
|
|
||
| qarg_pauli_channel_invoke = func.Invoke( | ||
| callee=squin.broadcast.single_qubit_pauli_channel, | ||
| inputs=( | ||
| *error_prob_ssas[3:], | ||
| qargs, | ||
| ), | ||
| ) | ||
|
|
||
| ctrl_pauli_channel_invoke.insert_before(stmt) | ||
| qarg_pauli_channel_invoke.insert_before(stmt) | ||
| stmt.delete() | ||
|
|
||
| return RewriteResult(has_done_something=True) | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,46 @@ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from math import pi | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from kirin import ir | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from kirin.dialects import py, func | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from kirin.rewrite.abc import RewriteRule, RewriteResult | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from bloqade import squin | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| from bloqade.qasm2.dialects.uop import stmts as uop_stmts | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uop_stmts.UGate: squin.u3, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uop_stmts.U1: squin.u3, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uop_stmts.U2: squin.u3, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uop_stmts.RZ: squin.rz, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uop_stmts.RX: squin.rx, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| uop_stmts.RY: squin.ry, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| class QASM2ParametrizedUOp1QToSquin(RewriteRule): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def rewrite_Statement(self, node: ir.Statement) -> RewriteResult: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if isinstance(node, (uop_stmts.RX, uop_stmts.RY, uop_stmts.RZ)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args = (node.theta, node.qarg) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(node, (uop_stmts.UGate)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args = (node.theta, node.phi, node.lam, node.qarg) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(node, (uop_stmts.U1)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| zero_stmt = py.Constant(value=0.0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| zero_stmt.insert_before(node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args = (zero_stmt.result, zero_stmt.result, node.lam, node.qarg) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif isinstance(node, (uop_stmts.U2)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| half_pi_stmt = py.Constant(value=pi / 2) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| half_pi_stmt.insert_before(node) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| args = (half_pi_stmt.result, node.phi, node.lam, node.qarg) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return RewriteResult() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+24
to
+37
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| squin_equivalent_stmt = PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP[type(node)] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| invoke_stmt = func.Invoke( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| callee=squin_equivalent_stmt, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| inputs=args, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| node.replace_by(invoke_stmt) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return RewriteResult(has_done_something=True) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Uh oh!
There was an error while loading. Please reload this page.