Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Probabilistic compilation #14

Merged
merged 5 commits into from
Dec 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 25 additions & 15 deletions inconspiquous/dialects/gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@
from typing import ClassVar
from dataclasses import dataclass

from xdsl.dialects.builtin import FloatAttr, Float64Type
from xdsl.dialects.builtin import (
FloatAttr,
Float64Type,
IndexType,
IntegerAttr,
AnyFloatConstr,
i1,
IntegerType,
)
from xdsl.ir import (
Dialect,
Operation,
Expand All @@ -30,7 +38,7 @@
result_def,
traits_def,
)
from xdsl.parser import AnyFloatConstr, AttrParser, IndexType, IntegerAttr, IntegerType
from xdsl.parser import AttrParser
from xdsl.printer import Printer
from xdsl.traits import ConstantLike, Pure

Expand Down Expand Up @@ -314,28 +322,30 @@ def __init__(self, lhs: SSAValue | Operation, rhs: SSAValue | Operation):


@irdl_op_definition
class XSGateOp(IRDLOperation):
class XZSOp(IRDLOperation):
"""
A gate for describing combinations of X and (pi/2) phase gates.
The final gate is given by:
X^(x >> 1) . S^phase

Passing in a value of 0 or 2 for x is undefined behaviour
A gadget for describing combinations of X, Z, and (pi/2) phase gates.
"""

name = "gate.xs"
name = "gate.xzs"

x = operand_def(IntegerType(2))
phase = operand_def(IntegerType(2))
x = operand_def(i1)
z = operand_def(i1)
phase = operand_def(i1)

out = result_def(GateType(1))

assembly_format = "$x `,` $phase attr-dict"
assembly_format = "$x `,` $z `,` $phase attr-dict"

traits = traits_def(Pure())

def __init__(self, x: Operation | SSAValue, phase: Operation | SSAValue):
super().__init__(operands=(x, phase), result_types=(GateType(1),))
def __init__(
self,
x: Operation | SSAValue,
z: Operation | SSAValue,
phase: Operation | SSAValue,
):
super().__init__(operands=(x, z, phase), result_types=(GateType(1),))


Gate = Dialect(
Expand All @@ -344,7 +354,7 @@ def __init__(self, x: Operation | SSAValue, phase: Operation | SSAValue):
ConstantGateOp,
QuaternionGateOp,
ComposeGateOp,
XSGateOp,
XZSOp,
],
[
AngleAttr,
Expand Down
32 changes: 16 additions & 16 deletions inconspiquous/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@ def get_convert_scf_to_cf():

return convert_scf_to_cf.ConvertScfToCf

def get_convert_to_xs():
from inconspiquous.transforms.xs import convert_to_xs
def get_convert_to_xzs():
from inconspiquous.transforms.xzs import convert_to_xzs

return convert_to_xs.ConvertToXS
return convert_to_xzs.ConvertToXZS

def get_cse():
from xdsl.transforms import common_subexpression_elimination
Expand All @@ -46,15 +46,15 @@ def get_lower_to_fin_supp():

return lower_to_fin_supp.LowerToFinSupp

def get_lower_xs_to_select():
from inconspiquous.transforms.xs import lower
def get_lower_xzs_to_select():
from inconspiquous.transforms.xzs import lower

return lower.LowerXSToSelect
return lower.LowerXZSToSelect

def get_merge_xs():
from inconspiquous.transforms.xs import merge
def get_merge_xzs():
from inconspiquous.transforms.xzs import merge

return merge.MergeXSGates
return merge.MergeXZSGates

def get_mlir_opt():
from xdsl.transforms import mlir_opt
Expand All @@ -66,23 +66,23 @@ def get_randomized_comp():

return randomized_comp.RandomizedComp

def get_xs_select():
from inconspiquous.transforms.xs import select
def get_xzs_select():
from inconspiquous.transforms.xzs import select

return select.XSSelect
return select.XZSSelect

return {
"canonicalize": get_canonicalize,
"convert-qssa-to-qref": get_convert_qssa_to_qref,
"convert-scf-to-cf": get_convert_scf_to_cf,
"convert-to-xs": get_convert_to_xs,
"convert-to-xzs": get_convert_to_xzs,
"cse": get_cse,
"dce": get_dce,
"lower-dyn-gate-to-scf": get_lower_dyn_gate_to_scf,
"lower-to-fin-supp": get_lower_to_fin_supp,
"lower-xs-to-select": get_lower_xs_to_select,
"merge-xs": get_merge_xs,
"lower-xzs-to-select": get_lower_xzs_to_select,
"merge-xzs": get_merge_xzs,
"mlir-opt": get_mlir_opt,
"randomized-comp": get_randomized_comp,
"xs-select": get_xs_select,
"xzs-select": get_xzs_select,
}
71 changes: 0 additions & 71 deletions inconspiquous/transforms/xs/lower.py

This file was deleted.

57 changes: 0 additions & 57 deletions inconspiquous/transforms/xs/merge.py

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
IdentityGate,
PhaseGate,
XGate,
XSGateOp,
XZSOp,
YGate,
ZGate,
)
Expand All @@ -40,53 +40,52 @@ def match_and_rewrite(self, op: GateOp, rewriter: PatternRewriter):
rewriter.replace_matched_op(DynGateOp(constant, *op.ins))


class ToXSGate(RewritePattern):
class ToXZSGate(RewritePattern):
"""
Rewrite a constant Identity/X/Y/Z/Phase gate to an xs gate
Rewrite a constant Identity/X/Y/Z/Phase gate to an xzs gadget
"""

@staticmethod
def get_const(i: int, rewriter: PatternRewriter) -> ConstantOp:
n = ConstantOp.from_int_and_width(i, 2)
n.result.name_hint = f"c{i}"
def get_const(b: bool, rewriter: PatternRewriter) -> ConstantOp:
n = ConstantOp(builtin.BoolAttr.from_bool(b))
n.result.name_hint = f"c{b}"
rewriter.insert_op(n, InsertPoint.before(rewriter.current_operation))
return n

@op_type_rewrite_pattern
def match_and_rewrite(self, op: ConstantGateOp, rewriter: PatternRewriter):
match op.gate:
case IdentityGate():
rewriter.replace_matched_op(
XSGateOp(self.get_const(1, rewriter), self.get_const(0, rewriter))
)
false = self.get_const(False, rewriter)
rewriter.replace_matched_op(XZSOp(false, false, false))
case XGate():
rewriter.replace_matched_op(
XSGateOp(self.get_const(3, rewriter), self.get_const(0, rewriter))
)
false = self.get_const(False, rewriter)
true = self.get_const(True, rewriter)
rewriter.replace_matched_op(XZSOp(true, false, false))
case YGate():
rewriter.replace_matched_op(
XSGateOp(self.get_const(3, rewriter), self.get_const(2, rewriter))
)
false = self.get_const(False, rewriter)
true = self.get_const(True, rewriter)
rewriter.replace_matched_op(XZSOp(true, true, false))
case ZGate():
rewriter.replace_matched_op(
XSGateOp(self.get_const(1, rewriter), self.get_const(2, rewriter))
)
false = self.get_const(False, rewriter)
true = self.get_const(True, rewriter)
rewriter.replace_matched_op(XZSOp(false, true, false))
case PhaseGate():
rewriter.replace_matched_op(
XSGateOp(self.get_const(1, rewriter), self.get_const(1, rewriter))
)
false = self.get_const(False, rewriter)
true = self.get_const(True, rewriter)
rewriter.replace_matched_op(XZSOp(false, false, true))
case _:
return


class ConvertToXS(ModulePass):
class ConvertToXZS(ModulePass):
"""
Convert all Identity/X/Y/Z/Phase gates to xs gates
Convert all Identity/X/Y/Z/Phase gates to xzs gadgets
"""

name = "convert-to-xs"
name = "convert-to-xzs"

def apply(self, ctx: MLContext, op: builtin.ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier([ToDynGate(), ToXSGate()])
GreedyRewritePatternApplier([ToDynGate(), ToXZSGate()])
).rewrite_module(op)
Loading
Loading