Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kehemo committed Oct 29, 2024
1 parent 4ca7e23 commit cf2593a
Show file tree
Hide file tree
Showing 24 changed files with 64 additions and 87 deletions.
2 changes: 1 addition & 1 deletion src/exo/frontend/pattern_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def match_pattern(
# get source location where this is getting called from
caller = inspect.getframeinfo(stack_frames[call_depth][0])
func_locals = ChainMap(stack_frames[call_depth].frame.f_locals)
func_globals = ChainMap(stack_frames[call_depth].frame.f_globals)
func_globals = stack_frames[call_depth].frame.f_globals

# parse the pattern we're going to use to match
p_ast = pyparser.pattern(
Expand Down
22 changes: 15 additions & 7 deletions src/exo/frontend/pyparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,14 @@ def pattern(s, filename=None, lineno=None, srclocals=None, srcglobals=None):
SourceInfo(
src_file=srcfilename, src_line_offset=srclineno, src_col_offset=n_dedent
),
parent_scope=DummyScope({}, {}), # add globals from enclosing scope
parent_scope=DummyScope(
srcglobals if srcglobals is not None else {},
(
{k: BoundLocal(v) for k, v in srclocals.items()}
if srclocals is not None
else {}
),
), # add globals from enclosing scope
is_fragment=True,
)
return parser.result()
Expand Down Expand Up @@ -495,10 +502,10 @@ def __init__(

self.push()
special_cases = ["stride"]
for key, val in self.globals.items():
for key, val in parent_scope.get_globals().items():
if isinstance(val, Extern):
special_cases.append(key)
for key, val in self.locals.items():
for key, val in parent_scope.read_locals().items():
if isinstance(val, Extern):
special_cases.append(key)

Expand Down Expand Up @@ -579,6 +586,7 @@ def try_eval_unquote(
isinstance(unquote_node, pyast.Name)
and isinstance(unquote_node.ctx, pyast.Load)
and unquote_node.id not in self.exo_locals
and not self.is_fragment
):
cur_globals = self.parent_scope.get_globals()
cur_locals = self.parent_scope.read_locals()
Expand Down Expand Up @@ -860,8 +868,8 @@ def parse_num_type(self, node, is_arg=False):

return typ

elif isinstance(node, pyast.Name) and node.id in Parser._prim_types:
return Parser._prim_types[node.id]
elif isinstance(node, pyast.Name) and node.id in _prim_types:
return _prim_types[node.id]
elif isinstance(node, pyast.Name) and (
_is_size(node) or _is_stride(node) or _is_index(node) or _is_bool(node)
):
Expand All @@ -872,8 +880,8 @@ def parse_num_type(self, node, is_arg=False):
unquote_eval_result = self.try_eval_unquote(node)
if len(unquote_eval_result) == 1:
unquoted = unquote_eval_result[0]
if isinstance(unquoted, str) and unquoted in Parser._prim_types:
return Parser._prim_types[unquoted]
if isinstance(unquoted, str) and unquoted in _prim_types:
return _prim_types[unquoted]
else:
self.err(node, "Unquote computation did not yield valid type")

Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i32 @DRAM
// )
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_captured_closure.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// bar(
// a : i32 @DRAM
// )
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_conditional.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// bar1(
// a : i8 @DRAM
// )
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_constant_lifting.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : f64 @DRAM
// )
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_quote_complex_expr.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i32 @DRAM
// )
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_quote_elision.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i32 @DRAM,
// b : i32 @DRAM
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_scope_collision1.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i32 @DRAM
// )
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_scope_collision2.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i32 @DRAM,
// b : i32 @DRAM
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_scope_nesting.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i8 @DRAM,
// b : i8 @DRAM
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_scoping.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i8 @DRAM
// )
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i32 @DRAM
// )
Expand Down
21 changes: 21 additions & 0 deletions tests/golden/test_metaprogramming/test_statements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i32 @DRAM
// )
void foo( void *ctxt, int32_t* a ) {
*a += ((int32_t) 1);
*a += ((int32_t) 1);
for (int_fast32_t i = 0; i < 2; i++) {
*a += ((int32_t) 1);
*a += ((int32_t) 1);
}
}

4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_type_params.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// bar1(
// a : i32 @DRAM,
// b : i8 @DRAM
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_type_quote_elision.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i8 @DRAM,
// x : i8[2] @DRAM
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_unquote_elision.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i32 @DRAM
// )
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_unquote_in_slice.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// bar(
// a : i8[10, 10] @DRAM
// )
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// bar(
// a : i8[10, 10, 10] @DRAM
// )
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// bar(
// a : i8[10, 10] @DRAM
// )
Expand Down
4 changes: 0 additions & 4 deletions tests/golden/test_metaprogramming/test_unrolling.txt
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#include "test.h"



#include <stdio.h>
#include <stdlib.h>



// foo(
// a : i8 @DRAM
// )
Expand Down
19 changes: 18 additions & 1 deletion tests/test_metaprogramming.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations
from exo import proc, compile_procs_to_strings
from exo.API_scheduling import rename
from exo.pyparser import ParseError
from exo.frontend.pyparser import ParseError
import pytest


Expand Down Expand Up @@ -338,3 +338,20 @@ def foo(a: i32):

c_file, _ = compile_procs_to_strings([foo], "test.h")
assert c_file == golden


def test_statement_in_expr():
with pytest.raises(ParseError):

@proc
def foo(a: i32):
with meta:

def bar():
with ~meta:
a += 1
return 2

with ~meta:
a += {bar()}
a += {bar()}
4 changes: 2 additions & 2 deletions tests/test_typecheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,14 @@ def foo(n: size, A: R[n] @ GEMM_SCRATCH):

def test_sin1():
@proc
def sin(x: f32):
def sin_proc(x: f32):
y: f32
y = sin(x)


def test_sin2():
@proc
def sin(x: f32):
def sin_proc(x: f32):
y: f32
if False:
y = sin(x)
Expand Down
Loading

0 comments on commit cf2593a

Please sign in to comment.