Skip to content

Commit

Permalink
runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
yamaguchi1024 committed Aug 17, 2024
1 parent bd6ae4a commit 71490bb
Show file tree
Hide file tree
Showing 2 changed files with 118 additions and 127 deletions.
148 changes: 118 additions & 30 deletions src/exo/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .dataflow import (
LoopIR_to_DataflowIR,
ScalarPropagation,
GetControlPredicates,
GetValues,
D,
)
Expand Down Expand Up @@ -376,6 +375,71 @@ def lift_es(es):
return [lift_e(e) for e in es]


# --------------------------------------------------------------------------- #
# Getting control flow on DataflowIR. Will be unnecessary when we
# integrate control flow into abstract values.
# --------------------------------------------------------------------------- #


class GetControlPredicates(LoopIR_Do):
def __init__(self, proc, stmts):
self.proc = proc
self.stmts = stmts
self.preds = None
self.done = False
self.cur_preds = []

for a in self.proc.args:
if isinstance(a.type, T.Size):
size_pred = A.BinOp(
"<",
A.Const(0, T.int, null_srcinfo()),
A.Var(a.name, T.size, a.srcinfo),
T.bool,
null_srcinfo(),
)
self.cur_preds.append(size_pred)
self.do_t(a.type)

for pred in self.proc.preds:
self.cur_preds.append(lift_e(pred))
self.do_e(pred)

self.do_stmts(self.proc.body)

def do_s(self, s):
if self.done:
return

if s == self.stmts[0]:
self.preds = AAnd(*self.cur_preds)
self.done = True

styp = type(s)
if styp is LoopIR.If:
self.cur_preds.append(lift_e(s.cond))
self.do_stmts(s.body)
self.cur_preds.pop()

self.cur_preds.append(A.Not(lift_e(s.cond), T.int, null_srcinfo()))
self.do_stmts(s.orelse)
self.cur_preds.pop()

elif styp is LoopIR.For:
a_iter = A.Var(s.iter, T.int, s.srcinfo)
b1 = A.BinOp("<=", lift_e(s.lo), a_iter, T.bool, null_srcinfo())
b2 = A.BinOp("<", a_iter, lift_e(s.hi), T.bool, null_srcinfo())
cond = A.BinOp("and", b1, b2, T.bool, null_srcinfo())
self.cur_preds.append(cond)
self.do_stmts(s.body)
self.cur_preds.pop()

super().do_s(s)

def result(self):
return self.preds.simplify()


# Produce a set of AExprs which occur as right-hand-sides
# of config writes.
def possible_config_writes(stmts):
Expand Down Expand Up @@ -1531,11 +1595,13 @@ def loop_globenv(i, lo_expr, hi_expr, body):


def Check_ReorderStmts(proc, s1, s2):
datair, stmts = LoopIR_to_DataflowIR(proc, [s1, s2]).result()
# datair, stmts = LoopIR_to_DataflowIR(proc, [s1, s2]).result()

# print("here in ReorderStmts")

assert len(stmts) == 2
assert isinstance(s1, LoopIR.stmt) and isinstance(s2, LoopIR.stmt)

p = GetControlPredicates(datair, stmts).result()
p = GetControlPredicates(proc, [s1, s2]).result()

slv = SMTSolver(verbose=False)
slv.push()
Expand All @@ -1554,11 +1620,13 @@ def Check_ReorderStmts(proc, s1, s2):


def Check_ReorderLoops(proc, s):
datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result()
# datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result()

assert len(stmts) == 1
# print("here in ReorderLoops")

p = GetControlPredicates(datair, stmts).result()
assert isinstance(s, LoopIR.For)

p = GetControlPredicates(proc, [s]).result()

slv = SMTSolver(verbose=False)
slv.push()
Expand Down Expand Up @@ -1632,11 +1700,13 @@ def bds(x, lo, hi):
# /\ ( forall i,i'. May(InBound(i,i',e) /\ i < i') => Commutes(a1', a1) )
#
def Check_ParallelizeLoop(proc, s):
datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result()
# datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result()

# print("Check_ParallelizeLoop")

assert len(stmts) == 1
assert isinstance(s, LoopIR.For)

p = GetControlPredicates(datair, stmts).result()
p = GetControlPredicates(proc, [s]).result()

slv = SMTSolver(verbose=False)
slv.push()
Expand Down Expand Up @@ -1688,9 +1758,11 @@ def bds(x, lo, hi):
#
def Check_FissionLoop(proc, loop, stmts1, stmts2, no_loop_var_1=False):

datair, d_loop = LoopIR_to_DataflowIR(proc, [loop]).result()
# print("Check_FissionLoop")

p = GetControlPredicates(datair, d_loop).result()
# datair, d_loop = LoopIR_to_DataflowIR(proc, [loop]).result()

p = GetControlPredicates(proc, [loop]).result()

slv = SMTSolver(verbose=False)
slv.push()
Expand Down Expand Up @@ -1774,9 +1846,9 @@ def lift_dexpr(e, key=None):
def Check_DeleteConfigWrite(proc, stmts):
assert len(stmts) > 0

ir1, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
p = GetControlPredicates(ir1, d_stmts).result()
print("here in DeleteConfigWrite")

p = GetControlPredicates(proc, stmts).result()
slv = SMTSolver(verbose=False)
slv.push()
slv.assume(AMay(p))
Expand All @@ -1801,6 +1873,7 @@ def Check_DeleteConfigWrite(proc, stmts):
)

# Below are the actual checks
ir1, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()

ScalarPropagation(ir1)

Expand Down Expand Up @@ -1869,6 +1942,8 @@ def Check_ExtendEqv(proc1, proc2, stmts1, stmts2, cfg_mod):
assert len(stmts1) == 1
assert len(stmts2) == 1

print("here in Check_ExtendEqv")

slv = SMTSolver(verbose=False)
slv.push()

Expand Down Expand Up @@ -1928,16 +2003,18 @@ def make_point(key):


def Check_ExprEqvInContext(proc, expr0, stmts0, expr1, stmts1=None):

# print("Check_ExprEqvInContext")
assert len(stmts0) > 0
stmts1 = stmts1 or stmts0

len_0 = len(stmts0)
datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts0 + stmts1).result()
d_stmts0 = d_stmts[0:len_0]
d_stmts1 = d_stmts[len_0:]
# len_0 = len(stmts0)
# datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts0 + stmts1).result()
# d_stmts0 = d_stmts[0:len_0]
# d_stmts1 = d_stmts[len_0:]

p0 = GetControlPredicates(datair, d_stmts0).result()
p1 = GetControlPredicates(datair, d_stmts1).result()
p0 = GetControlPredicates(proc, stmts0).result()
p1 = GetControlPredicates(proc, stmts1).result()

slv = SMTSolver(verbose=False)
slv.push()
Expand All @@ -1954,11 +2031,13 @@ def Check_ExprEqvInContext(proc, expr0, stmts0, expr1, stmts1=None):


def Check_BufferReduceOnly(proc, stmts, buf, ndim):

print("Check_BufferReduceOnly")
assert len(stmts) > 0

datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
# datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()

p = GetControlPredicates(datair, d_stmts).result()
p = GetControlPredicates(proc, stmts).result()

slv = SMTSolver(verbose=False)
slv.push()
Expand Down Expand Up @@ -1988,13 +2067,15 @@ def Check_Access_In_Window(proc, access_cursor, w_exprs, block_cursor):
block_cursor is the context in which to interpret the access in.
"""

# print("Check_Access_In_Window")

access = access_cursor._node
block = [x._node for x in block_cursor]
idxs = access.idx
assert len(idxs) == len(w_exprs)

datair, d_stmts = LoopIR_to_DataflowIR(proc, block).result()
p = GetControlPredicates(datair, d_stmts).result()
# datair, d_stmts = LoopIR_to_DataflowIR(proc, block).result()
p = GetControlPredicates(proc, block).result()

slv = SMTSolver(verbose=False)
slv.push()
Expand Down Expand Up @@ -2067,9 +2148,10 @@ def Check_Bounds(proc, alloc_stmt, block):
if len(block) == 0:
return

datair, stmts = LoopIR_to_DataflowIR(proc, block).result()
# print("Check_Bounds")
# datair, stmts = LoopIR_to_DataflowIR(proc, block).result()

p = GetControlPredicates(datair, stmts).result()
p = GetControlPredicates(proc, block).result()

slv = SMTSolver(verbose=False)
slv.push()
Expand Down Expand Up @@ -2105,6 +2187,8 @@ def Check_Bounds(proc, alloc_stmt, block):


def Check_IsDeadAfter(proc, stmts, bufname, ndim):

print("Check_IsDeadAfter")
assert len(stmts) > 0

ap = PostEnv(proc, stmts).get_posteffs()
Expand All @@ -2126,11 +2210,13 @@ def Check_IsDeadAfter(proc, stmts, bufname, ndim):


def Check_IsIdempotent(proc, stmts):

print("Check_IsIdempotent")
assert len(stmts) > 0

datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
# datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()

p = GetControlPredicates(datair, d_stmts).result()
p = GetControlPredicates(proc, stmts).result()

slv = SMTSolver(verbose=False)
slv.push()
Expand All @@ -2144,10 +2230,11 @@ def Check_IsIdempotent(proc, stmts):


def Check_ExprBound(proc, stmts, expr, op, value, exception=True):
print("Check_ExprBound")
assert len(stmts) > 0

datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
p = GetControlPredicates(datair, d_stmts).result()
# datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result()
p = GetControlPredicates(proc, stmts).result()

# TODO: Check_ExprBound does not depend on configuration states so this can be skipped, but more fundamentally running abstract interpretation this many times is simply too slow.
# ScalarPropagation(datair)
Expand Down Expand Up @@ -2335,5 +2422,6 @@ def do_s(self, s):


def Check_Aliasing(proc):
print("Check_Aliasing")
helper = _Check_Aliasing_Helper(proc)
# that's it
97 changes: 0 additions & 97 deletions src/exo/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -829,100 +829,3 @@ def abs_builtin(self, builtin, args):

# TODO: write a short circuit for select builtin
return D.Const(builtin.interpret(vargs), args[0].typ)


# --------------------------------------------------------------------------- #
# Getting control flow on DataflowIR. Will be unnecessary when we
# integrate control flow into abstract values.
# --------------------------------------------------------------------------- #


def lift_dataflow(e):
if e.type.is_indexable() or e.type.is_stridable() or e.type == T.bool:
if isinstance(e, DataflowIR.Read):
assert len(e.idx) == 0
return A.Var(e.name, e.type, e.srcinfo)
elif isinstance(e, DataflowIR.Const):
return A.Const(e.val, e.type, e.srcinfo)
elif isinstance(e, DataflowIR.BinOp):
return A.BinOp(
e.op, lift_dataflow(e.lhs), lift_dataflow(e.rhs), e.type, e.srcinfo
)
elif isinstance(e, DataflowIR.USub):
return A.USub(lift_dataflow(e.arg), e.type, e.srcinfo)
elif isinstance(e, DataflowIR.StrideExpr):
return A.Stride(e.name, e.dim, e.type, e.srcinfo)
elif isinstance(e, DataflowIR.ReadConfig):
return A.Var(e.config_field, e.type, e.srcinfo)
else:
f"bad case: {type(e)}"
else:
assert e.type.is_numeric()
if e.type.is_real_scalar():
if isinstance(e, DataflowIR.Const):
return A.Const(e.val, e.type, e.srcinfo)
elif isinstance(e, DataflowIR.Read):
return A.ConstSym(e.name, e.type, e.srcinfo)
elif isinstance(e, DataflowIR.ReadConfig):
return A.Var(e.config_field, e.type, e.srcinfo)

return A.Unk(T.err, e.srcinfo)


class GetControlPredicates(DataflowIR_Do):
def __init__(self, datair, stmts):
self.datair = datair
self.stmts = stmts
self.preds = None
self.done = False
self.cur_preds = []

for a in self.datair.args:
if isinstance(a.type, T.Size):
size_pred = A.BinOp(
"<",
A.Const(0, T.int, null_srcinfo()),
A.Var(a.name, T.size, a.srcinfo),
T.bool,
null_srcinfo(),
)
self.cur_preds.append(size_pred)
self.do_t(a.type)

for pred in self.datair.preds:
self.cur_preds.append(lift_dataflow(pred))
self.do_e(pred)

self.do_stmts(self.datair.body.stmts)

def do_s(self, s):
if self.done:
return

if s == self.stmts[0]:
self.preds = AAnd(*self.cur_preds)
self.done = True

styp = type(s)
if styp is DataflowIR.If:
self.cur_preds.append(lift_dataflow(s.cond))
self.do_stmts(s.body.stmts)
self.cur_preds.pop()

self.cur_preds.append(A.Not(lift_dataflow(s.cond), T.int, null_srcinfo()))
self.do_stmts(s.orelse.stmts)
self.cur_preds.pop()

elif styp is DataflowIR.For:
a_iter = A.Var(s.iter, T.int, s.srcinfo)
b1 = A.BinOp("<=", lift_dataflow(s.lo), a_iter, T.bool, null_srcinfo())
b2 = A.BinOp("<", a_iter, lift_dataflow(s.hi), T.bool, null_srcinfo())
cond = A.BinOp("and", b1, b2, T.bool, null_srcinfo())
self.cur_preds.append(cond)
self.do_stmts(s.body.stmts)
self.cur_preds.pop()

super().do_s(s)

def result(self):
return self.preds.simplify()

0 comments on commit 71490bb

Please sign in to comment.