Skip to content

Commit

Permalink
Merge pull request #10 from xdslproject/alexarice/fin-supp
Browse files Browse the repository at this point in the history
dialects: (prob) add finite support distribution
  • Loading branch information
alexarice authored Nov 8, 2024
2 parents 533e3c7 + 8d2e74d commit ae7cd0a
Show file tree
Hide file tree
Showing 6 changed files with 230 additions and 15 deletions.
123 changes: 121 additions & 2 deletions inconspiquous/dialects/prob.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,35 @@
from typing import ClassVar, Sequence
from typing_extensions import Self

from xdsl.dialects.builtin import IntegerAttrTypeConstr, i1
from xdsl.ir import Dialect, VerifyException
from xdsl.ir import (
Attribute,
Dialect,
Operation,
SSAValue,
VerifyException,
)
from xdsl.irdl import (
AnyAttr,
IRDLOperation,
VarConstraint,
irdl_op_definition,
operand_def,
prop_def,
result_def,
traits_def,
var_operand_def,
)
from xdsl.parser import Float64Type, FloatAttr, IndexType, IntegerType
from xdsl.parser import (
DenseArrayBase,
Float64Type,
FloatAttr,
IndexType,
IntegerType,
UnresolvedOperand,
)
from xdsl.parser import Parser
from xdsl.printer import Printer
from xdsl.pattern_rewriter import RewritePattern
from xdsl.traits import HasCanonicalizationPatternsTrait

Expand Down Expand Up @@ -68,11 +90,108 @@ def __init__(self, out_type: IntegerType | IndexType):
)


class FinSuppOpHasCanonicalizationPatterns(HasCanonicalizationPatternsTrait):
@classmethod
def get_canonicalization_patterns(cls) -> tuple[RewritePattern, ...]:
from inconspiquous.transforms.canonicalization.prob import (
FinSuppTrivial,
FinSuppRemoveCase,
FinSuppDuplicate,
)

return (FinSuppTrivial(), FinSuppRemoveCase(), FinSuppDuplicate())


@irdl_op_definition
class FinSuppOp(IRDLOperation):
name = "prob.fin_supp"

_T: ClassVar = VarConstraint("T", AnyAttr())

ins = var_operand_def(_T)

default_value = operand_def(_T)

out = result_def(_T)

probabilities = prop_def(DenseArrayBase)

traits = traits_def(FinSuppOpHasCanonicalizationPatterns())

def __init__(
self,
probabilities: Sequence[float] | DenseArrayBase,
default_value: SSAValue,
*ins: SSAValue | Operation,
attr_dict: dict[str, Attribute] | None = None,
):
result_type = SSAValue.get(default_value).type
if not isinstance(probabilities, DenseArrayBase):
probabilities = DenseArrayBase.create_dense_float(
Float64Type(), probabilities
)
super().__init__(
operands=(ins, default_value),
result_types=(result_type,),
properties={"probabilities": probabilities},
attributes=attr_dict,
)

@staticmethod
def parse_case(parser: Parser) -> tuple[UnresolvedOperand, float]:
prob = parser.parse_number()
assert isinstance(prob, float)
parser.parse_keyword("or")
operand = parser.parse_unresolved_operand()
return (operand, prob)

@classmethod
def parse(cls, parser: Parser) -> Self:
parser.parse_punctuation("[")
probabilities: list[float] = []
cases: list[UnresolvedOperand] = []
while (n := parser.parse_optional_number()) is not None:
assert isinstance(n, float)
probabilities.append(n)
parser.parse_keyword("of")
cases.append(parser.parse_unresolved_operand())
parser.parse_punctuation(",")
if cases:
parser.parse_keyword("else")
default_unresolved = parser.parse_unresolved_operand()
parser.parse_punctuation("]")
parser.parse_punctuation(":")
result_type = parser.parse_type()
ins = tuple(parser.resolve_operand(x, result_type) for x in cases)
default_value = parser.resolve_operand(default_unresolved, result_type)
attr_dict = parser.parse_optional_attr_dict()
return cls(probabilities, default_value, *ins, attr_dict=attr_dict)

@staticmethod
def print_case(c: tuple[SSAValue, int | float], printer: Printer):
operand, prob = c
printer.print_string(repr(prob) + " of ")
printer.print_operand(operand)

def print(self, printer: Printer):
printer.print_string(" [ ")
printer.print_list(
zip(self.ins, self.probabilities.as_tuple()),
lambda c: self.print_case(c, printer),
)
if self.ins:
printer.print_string(", else ")
printer.print_operand(self.default_value)
printer.print_string(" ] : ")
printer.print_attribute(self.out.type)


Prob = Dialect(
"prob",
[
BernoulliOp,
UniformOp,
FinSuppOp,
],
[],
)
67 changes: 66 additions & 1 deletion inconspiquous/transforms/canonicalization/prob.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from xdsl.dialects.arith import Constant
from xdsl.dialects.builtin import BoolAttr
from xdsl.ir import SSAValue
from xdsl.pattern_rewriter import (
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)

from inconspiquous.dialects.prob import BernoulliOp
from inconspiquous.dialects.prob import BernoulliOp, FinSuppOp


class BernoulliConst(RewritePattern):
Expand All @@ -23,3 +24,67 @@ def match_and_rewrite(self, op: BernoulliOp, rewriter: PatternRewriter):

if prob == 0.0:
rewriter.replace_matched_op(Constant(BoolAttr.from_bool(False)))


class FinSuppTrivial(RewritePattern):
"""
prob.fin_supp [ %x ] == %x
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: FinSuppOp, rewriter: PatternRewriter):
if not op.probabilities.data.data:
rewriter.replace_matched_op((), (op.default_value,))


class FinSuppRemoveCase(RewritePattern):
"""
A case can be removed if its probability is 0 or it's equal to the default case.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: FinSuppOp, rewriter: PatternRewriter):
probs = op.probabilities.as_tuple()
if not any(
p == 0.0 or c == op.default_value
for p, c in zip(probs, op.ins, strict=True)
):
return
new_probabilities = tuple(
p
for p, c in zip(probs, op.ins, strict=True)
if p != 0.0 and c != op.default_value
)
new_ins = tuple(
c
for p, c in zip(probs, op.ins, strict=True)
if p != 0.0 and c != op.default_value
)
rewriter.replace_matched_op(
FinSuppOp(new_probabilities, op.default_value, *new_ins)
)


class FinSuppDuplicate(RewritePattern):
"""
If two cases are the same then we can merge them.
"""

@op_type_rewrite_pattern
def match_and_rewrite(self, op: FinSuppOp, rewriter: PatternRewriter):
print(op.ins)
if len(set(op.ins)) == len(op.ins):
return
seen: dict[SSAValue, int] = dict()
new_probs: list[float] = []
new_ins: list[SSAValue] = []

for p, c in zip(op.probabilities.as_tuple(), op.ins, strict=True):
if c not in seen:
seen[c] = len(new_probs)
new_probs.append(p)
new_ins.append(c)
else:
new_probs[seen[c]] += p

rewriter.replace_matched_op(FinSuppOp(new_probs, op.default_value, *new_ins))
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ dev-dependencies = [
"ruff>=0.6.5",
"pyright>=1.1.380",
"pytest>=8.3.3",
"lit<16.0.0",
"filecheck==0.0.23",
"pre-commit==3.3.1",
"lit<19.0.0",
"filecheck==1.0.1",
"pre-commit==4.0.1",
"psutil>=6.0.0",
]

Expand Down
19 changes: 19 additions & 0 deletions tests/filecheck/dialects/prob/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,22 @@

// Stop them being dead code eliminated
"test.op"(%0, %1) : (i1, i1) -> ()

// CHECK: %[[#x1:]] = "test.op"() {"fin_supp_test"} : () -> i64
%2 = "test.op"() {"fin_supp_test"} : () -> i64
%3 = prob.fin_supp [ %2 ] : i64
// CHECK-NEXT: "test.op"(%[[#x1]]) : (i64) -> ()
"test.op"(%3) : (i64) -> ()

// CHECK: %[[#first:]], %[[#second:]], %[[#third:]] = "test.op"() : () -> (i32, i32, i32)
%4, %5, %6 = "test.op"() : () -> (i32, i32, i32)
// CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.375 of %[[#first]], else %[[#second]] ] : i32
%7 = prob.fin_supp [ 0.125 of %4, 0.25 of %4, else %5 ] : i32

// CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.1 of %[[#first]], else %[[#third]] ] : i32
%8 = prob.fin_supp [ 0.1 of %4, 0.0 of %5, else %6 ] : i32

// CHECK-NEXT: %{{.*}} = prob.fin_supp [ 0.2 of %[[#second]], else %[[#first]] ] : i32
%9 = prob.fin_supp [ 0.1 of %4, 0.2 of %5, else %4 ] : i32

"test.op"(%7, %8, %9) : (i32, i32, i32) -> ()
12 changes: 12 additions & 0 deletions tests/filecheck/dialects/prob/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,15 @@
// CHECK: %{{.*}} = prob.uniform : i32
// CHECK-GENERIC: %{{.*}} = "prob.uniform"() : () -> i32
%1 = prob.uniform : i32

%2, %3, %4 = "test.op"() : () -> (i64, i64, i64)

// CHECK: %{{.*}} = prob.fin_supp [ 0.1 of %{{.*}}, 0.2 of %{{.*}}, else %{{.*}} ] : i64
%5 = prob.fin_supp [
0.1 of %2,
0.2 of %3,
else %4
] : i64

// CHECK: %{{.*}} = prob.fin_supp [ %{{.*}} ] : i64
%6 = prob.fin_supp [ %4 ] : i64
18 changes: 9 additions & 9 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit ae7cd0a

Please sign in to comment.