Skip to content

Commit

Permalink
Merge branch 'main' into dependabot/pip/black-24.10.0
Browse files Browse the repository at this point in the history
  • Loading branch information
yamaguchi1024 authored Oct 13, 2024
2 parents 99acf47 + cb416d4 commit 231798d
Show file tree
Hide file tree
Showing 44 changed files with 1,402 additions and 282 deletions.
2 changes: 1 addition & 1 deletion apps/x86/conv/conv.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from exo import *
from exo.builtins import *
from exo.libs.externs import *
from exo.platforms.x86 import *
from exo.syntax import *
from exo.stdlib.scheduling import *
Expand Down
10 changes: 5 additions & 5 deletions src/exo/API.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def body(self):
block = self._root()._child_block("body")
return C.lift_cursor(block, self)

def find(self, pattern, many=False):
def find(self, pattern, many=False, call_depth=1):
"""
Find the most specific possible cursor for the given pattern.
For example, a pattern matching a single assignment statement
Expand All @@ -256,7 +256,7 @@ def find(self, pattern, many=False):
In any event, if no matches are found, a SchedulingError is raised
"""
return C.find(self._root(), self, pattern, many)
return C.find(self._root(), self, pattern, many, call_depth=call_depth + 1)

def find_loop(self, pattern, many=False):
"""
Expand All @@ -273,7 +273,7 @@ def find_loop(self, pattern, many=False):
name, count = results[1], (results[2] if results[2] else "")
pattern = f"for {name} in _: _{count}"

return self.find(pattern, many)
return self.find(pattern, many, call_depth=1)

def find_alloc_or_arg(self, pattern):
_name_count_re = r"^([a-zA-Z_]\w*)\s*(\#\s*[0-9]+)?$"
Expand All @@ -286,10 +286,10 @@ def find_alloc_or_arg(self, pattern):

pattern = f"{name}: _{count}"

return self.find(pattern)
return self.find(pattern, call_depth=1)

def find_all(self, pattern):
return self.find(pattern, many=True)
return self.find(pattern, many=True, call_depth=1)

# ---------------------------------------------- #
# execution / compilation operations
Expand Down
23 changes: 11 additions & 12 deletions src/exo/API_cursors.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@ class Cursor(ABC):
| Literal( value : bool, int, or float )
| UnaryMinus( arg : Expr )
| BinaryOp( op : str, lhs : Expr, rhs : Expr )
| BuiltIn( name : str, args : ExprList )
| Extern( name : str, args : ExprList )
| WindowExpr( name : str, idx : *(see below) )
| BuiltIn( name : str, args : ExprList )
The `idx` argument of `WindowExpr` is a list containing either
`Expr` or `(Expr,Expr)` (a pair of expressions) at each position.
Expand Down Expand Up @@ -128,8 +127,8 @@ def parent(self):
return InvalidCursor()
return lift_cursor(impl_parent, self._proc)

def find(self, pattern, many=False):
return find(self._impl, self._proc, pattern, many)
def find(self, pattern, many=False, call_depth=1):
return find(self._impl, self._proc, pattern, many, call_depth=call_depth + 1)

def _child_node(self, *args, **kwargs):
return lift_cursor(self._impl._child_node(*args, **kwargs), self._proc)
Expand Down Expand Up @@ -783,21 +782,21 @@ def rhs(self) -> ExprCursor:
return self._child_node("rhs")


class BuiltInFunctionCursor(ExprCursor):
class ExternFunctionCursor(ExprCursor):
"""
Cursor pointing to the call to some built-in function
`name ( args )`
"""

def name(self) -> str:
assert isinstance(self._impl, C.Node)
assert isinstance(self._impl._node, LoopIR.BuiltIn)
assert isinstance(self._impl._node, LoopIR.Extern)

return self._impl._node.f.name()

def args(self) -> ExprListCursor:
assert isinstance(self._impl, C.Node)
assert isinstance(self._impl._node, LoopIR.BuiltIn)
assert isinstance(self._impl._node, LoopIR.Extern)

return ExprListCursor(self._impl._child_block("args"), self._proc)

Expand Down Expand Up @@ -923,8 +922,8 @@ def lift_cursor(impl, proc):
return UnaryMinusCursor(impl, proc)
elif isinstance(n, LoopIR.BinOp):
return BinaryOpCursor(impl, proc)
elif isinstance(n, LoopIR.BuiltIn):
return BuiltInFunctionCursor(impl, proc)
elif isinstance(n, LoopIR.Extern):
return ExternFunctionCursor(impl, proc)
elif isinstance(n, LoopIR.WindowExpr):
return WindowExprCursor(impl, proc)
elif isinstance(n, LoopIR.StrideExpr):
Expand All @@ -937,7 +936,7 @@ def lift_cursor(impl, proc):
assert False, f"bad case: {type(impl)}"


def find(scope: C, proc: API.Procedure, pattern: str, many: bool):
def find(scope: C, proc: API.Procedure, pattern: str, many: bool, call_depth=1):
"""
Find the most specific possible cursor for the given pattern in
the given scope of the proc. For example, a pattern matching a
Expand All @@ -953,7 +952,7 @@ def find(scope: C, proc: API.Procedure, pattern: str, many: bool):
raise TypeError("expected a pattern string")
default_match_no = None if many else 0
raw_cursors = match_pattern(
scope, pattern, call_depth=1, default_match_no=default_match_no
scope, pattern, call_depth=call_depth + 1, default_match_no=default_match_no
)
assert isinstance(raw_cursors, list)
cursors = []
Expand Down Expand Up @@ -1000,7 +999,7 @@ def find(scope: C, proc: API.Procedure, pattern: str, many: bool):
"LiteralCursor",
"UnaryMinusCursor",
"BinaryOpCursor",
"BuiltInFunctionCursor",
"ExternFunctionCursor",
"WindowExprCursor",
"StrideExprCursor",
#
Expand Down
11 changes: 4 additions & 7 deletions src/exo/API_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,8 +381,7 @@ def _cursor_call(self, expr_pattern, all_args):
self.err("expected an ExprCursor or pattern string")

proc = all_args["proc"]
# TODO: Remove all need for `call_depth`
matches = proc.find(expr_pattern, many=self.match_many)
matches = proc.find(expr_pattern, many=self.match_many, call_depth=1)

if self.match_many:
for m in matches:
Expand Down Expand Up @@ -411,8 +410,7 @@ def _cursor_call(self, stmt_pattern, all_args):
self.err("expected a StmtCursor or pattern string")

proc = all_args["proc"]
# TODO: Remove all need for `call_depth`
matches = proc.find(stmt_pattern, many=self.match_many)
matches = proc.find(stmt_pattern, many=self.match_many, call_depth=1)

match = matches[0] if self.match_many else matches
if not isinstance(match, PC.StmtCursor):
Expand Down Expand Up @@ -441,8 +439,7 @@ def _cursor_call(self, block_pattern, all_args):
self.err("expected a Cursor or pattern string")

proc = all_args["proc"]
# TODO: Remove all need for `call_depth`
matches = proc.find(block_pattern, many=self.match_many)
matches = proc.find(block_pattern, many=self.match_many, call_depth=1)

match = matches[0] if self.match_many else matches
if isinstance(match, PC.StmtCursor):
Expand Down Expand Up @@ -540,7 +537,7 @@ def _cursor_call(self, alloc_pattern, all_args):
if not isinstance(cursor, (PC.AllocCursor, PC.ArgCursor)):
proc = all_args["proc"]
try:
cursor = proc.find(alloc_pattern)
cursor = proc.find(alloc_pattern, call_depth=1)
except:
for arg in proc.args():
if arg.name() == name:
Expand Down
19 changes: 9 additions & 10 deletions src/exo/LoopIR.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from asdl_adt import ADT, validators

from .builtins import BuiltIn
from .extern import Extern
from .configs import Config
from .memory import Memory
from .prelude import Sym, SrcInfo, extclass
Expand Down Expand Up @@ -92,7 +92,7 @@ def __new__(cls, op):
| Const( object val )
| USub( expr arg ) -- i.e. -(...)
| BinOp( binop op, expr lhs, expr rhs )
| BuiltIn( builtin f, expr* args )
| Extern( extern f, expr* args )
| WindowExpr( sym name, w_access* idx )
| StrideExpr( sym name, int dim )
| ReadConfig( config config, string field )
Expand Down Expand Up @@ -130,7 +130,7 @@ def __new__(cls, op):
"name": validators.instance_of(Identifier, convert=True),
"sym": Sym,
"mem": Type[Memory],
"builtin": BuiltIn,
"extern": Extern,
"config": Config,
"binop": validators.instance_of(Operator, convert=True),
"srcinfo": SrcInfo,
Expand Down Expand Up @@ -190,7 +190,7 @@ def __new__(cls, op):
| Const ( object val )
| USub ( expr arg ) -- i.e. -(...)
| BinOp ( op op, expr lhs, expr rhs )
| BuiltIn( builtin f, expr* args )
| Extern( extern f, expr* args )
| WindowExpr( sym name, w_access* idx )
| StrideExpr( sym name, int dim )
| ParRange( expr lo, expr hi ) -- only use for loop cond
Expand Down Expand Up @@ -221,7 +221,7 @@ def __new__(cls, op):
"name": validators.instance_of(Identifier, convert=True),
"sym": Sym,
"mem": Type[Memory],
"builtin": BuiltIn,
"extern": Extern,
"config": Config,
"loopir_proc": LoopIR.proc,
"op": validators.instance_of(Operator, convert=True),
Expand Down Expand Up @@ -270,14 +270,13 @@ def __new__(cls, op):
| Const ( object val )
| USub ( expr arg ) -- i.e. -(...)
| BinOp ( op op, expr lhs, expr rhs )
| BuiltIn ( builtin f, expr* args )
| Extern ( name f, expr* args )
| ReadConfig( string config, string field )
attributes( srcinfo srcinfo )
} """,
ext_types={
"name": validators.instance_of(IdentifierOrHole, convert=True),
"builtin": BuiltIn,
"op": validators.instance_of(Operator, convert=True),
"srcinfo": SrcInfo,
},
Expand Down Expand Up @@ -673,7 +672,7 @@ def map_e(self, e):
rhs=new_rhs or e.rhs,
type=new_type or e.type,
)
elif isinstance(e, LoopIR.BuiltIn):
elif isinstance(e, LoopIR.Extern):
new_type = self.map_t(e.type)
new_args = self.map_exprs(e.args)
if any((new_type, new_args is not None)):
Expand Down Expand Up @@ -810,7 +809,7 @@ def do_e(self, e):
elif etyp is LoopIR.BinOp:
self.do_e(e.lhs)
self.do_e(e.rhs)
elif etyp is LoopIR.BuiltIn:
elif etyp is LoopIR.Extern:
for a in e.args:
self.do_e(a)
elif etyp is LoopIR.USub:
Expand Down Expand Up @@ -914,7 +913,7 @@ def match_e(self, e1, e2):
and self.match_e(e1.lhs, e2.lhs)
and self.match_e(e1.rhs, e2.rhs)
)
elif isinstance(e1, LoopIR.BuiltIn):
elif isinstance(e1, LoopIR.Extern):
# TODO: check f equality
return e1.f is e2.f and all(
self.match_e(a1, a2) for a1, a2 in zip(e1.args, e2.args)
Expand Down
46 changes: 25 additions & 21 deletions src/exo/LoopIR_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,18 +196,18 @@ def do_t(self, t):
pass


class LoopIR_FindBuiltIns(LoopIR_Do):
class LoopIR_FindExterns(LoopIR_Do):
def __init__(self, proc):
self._builtins = set()
self._externs = set()
super().__init__(proc)

def result(self):
return self._builtins
return self._externs

# to improve efficiency
def do_e(self, e):
if isinstance(e, LoopIR.BuiltIn):
self._builtins.add(e.f)
if isinstance(e, LoopIR.Extern):
self._externs.add((e.f, e.type.basetype().ctype()))
else:
super().do_e(e)

Expand Down Expand Up @@ -247,12 +247,12 @@ def find_all_mems(proc_list):
return [m for m in mems]


def find_all_builtins(proc_list):
builtins = set()
def find_all_externs(proc_list):
externs = set()
for p in proc_list:
builtins.update(LoopIR_FindBuiltIns(p).result())
externs.update(LoopIR_FindExterns(p).result())

return [b for b in builtins]
return externs


def find_all_configs(proc_list):
Expand Down Expand Up @@ -376,10 +376,10 @@ def from_lines(x):

# Body contents
memory_code = _compile_memories(find_all_mems(proc_list))
builtin_code = _compile_builtins(find_all_builtins(proc_list))
private_fwd_decls = []
proc_bodies = []
instrs_global = []
analyzed_proc_list = []

needed_helpers = set()

Expand Down Expand Up @@ -424,6 +424,8 @@ def from_lines(x):

proc_bodies.append(b)

analyzed_proc_list.append(p)

# Structs are just blobs of code... still sort them for output stability
struct_defns = [x.definition for x in sorted(struct_defns, key=lambda x: x.name)]

Expand Down Expand Up @@ -454,12 +456,14 @@ def from_lines(x):
{from_lines(public_fwd_decls)}
"""

extern_code = _compile_externs(find_all_externs(analyzed_proc_list))

helper_code = [_static_helpers[v] for v in needed_helpers]
body_contents = [
helper_code,
instrs_global,
memory_code,
builtin_code,
extern_code,
private_fwd_decls,
proc_bodies,
]
Expand All @@ -470,12 +474,12 @@ def from_lines(x):
return header_contents, body_contents


def _compile_builtins(builtins):
builtin_code = []
for b in sorted(builtins, key=lambda x: x.name()):
if glb := b.globl():
builtin_code.append(glb)
return builtin_code
def _compile_externs(externs):
extern_code = []
for f, t in sorted(externs, key=lambda x: x[0].name() + x[1]):
if glb := f.globl(t):
extern_code.append(glb)
return extern_code


def _compile_memories(mems):
Expand Down Expand Up @@ -971,7 +975,7 @@ def comp_fnarg(self, e, fn, i, *, prec=0):
x for x, _ in get_writes_of_stmts(fn.body)
)
else:
raise NotImplementedError("Passing windows to built-ins")
raise NotImplementedError("Passing windows to externs")
win_struct = self.get_window_type(e.type, is_const)
data, strides = self.window_struct_fields(e)
return f"(struct {win_struct}){{ &{data}, {{ {strides} }} }}"
Expand Down Expand Up @@ -1044,9 +1048,9 @@ def comp_e(self, e, prec=0):
elif isinstance(e, LoopIR.USub):
return f'-{self.comp_e(e.arg, op_prec["~"])}'

elif isinstance(e, LoopIR.BuiltIn):
args = [self.comp_fnarg(a, e, i) for i, a in enumerate(e.args)]
return e.f.compile(args)
elif isinstance(e, LoopIR.Extern):
args = [self.comp_e(a) for a in e.args]
return e.f.compile(args, e.type.basetype().ctype())

elif isinstance(e, LoopIR.StrideExpr):
basetyp = self.envtyp[e.name]
Expand Down
4 changes: 2 additions & 2 deletions src/exo/LoopIR_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def pacc(w):
return f"{self.get_name(e.name)}[{', '.join([pacc(w) for w in e.idx])}]"
elif isinstance(e, UAST.StrideExpr):
return f"stride({self.get_name(e.name)}, {e.dim})"
elif isinstance(e, UAST.BuiltIn):
elif isinstance(e, UAST.Extern):
pname = e.f.name() or "_anon_"
args = [self.pexpr(a) for a in e.args]
return f"{pname}({','.join(args)})"
Expand Down Expand Up @@ -507,7 +507,7 @@ def _print_expr(e, env: PrintEnv, prec: int = 0) -> str:
elif isinstance(e, LoopIR.StrideExpr):
return f"stride({env.get_name(e.name)}, {e.dim})"

elif isinstance(e, LoopIR.BuiltIn):
elif isinstance(e, LoopIR.Extern):
pname = e.f.name() or "_anon_"
args = [_print_expr(a, env) for a in e.args]
return f"{pname}({', '.join(args)})"
Expand Down
Loading

0 comments on commit 231798d

Please sign in to comment.