Skip to content

Commit

Permalink
Distinguish type error from name error in pyparser (#733)
Browse files Browse the repository at this point in the history
Distinguish undefined variable error from type error in UAST parser
  • Loading branch information
akeley98 authored Oct 23, 2024
1 parent 2c31b61 commit cd5f597
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 6 deletions.
11 changes: 7 additions & 4 deletions src/exo/frontend/pyparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,8 +883,8 @@ def parse_expr(self, e):
nm = self.locals[nm_node.id]
elif nm_node.id in self.globals:
nm = self.globals[nm_node.id]
else:
nm = None
else: # could not resolve name to anything
self.err(nm_node, f"variable '{nm_node.id}' undefined")

if isinstance(nm, SizeStub):
nm = nm.nm
Expand All @@ -899,8 +899,11 @@ def parse_expr(self, e):
)
else:
return UAST.Const(nm, self.getsrcinfo(e))
else: # could not resolve name to anything
self.err(nm_node, f"variable '{nm_node.id}' undefined")
else:
self.err(
nm_node,
f"variable '{nm_node.id}' has unsupported type {type(nm)}",
)

if is_window:
return UAST.WindowExpr(nm, idxs, self.getsrcinfo(e))
Expand Down
66 changes: 64 additions & 2 deletions tests/test_uast.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from __future__ import annotations

import pytest

from exo import DRAM
from exo.frontend.pyparser import Parser, get_src_locals, get_ast_from_python
from exo.frontend.pyparser import (
Parser,
get_src_locals,
get_ast_from_python,
ParseError,
)


def to_uast(f):
Expand All @@ -10,7 +17,7 @@ def to_uast(f):
body,
getsrcinfo,
func_globals=f.__globals__,
srclocals=get_src_locals(depth=3),
srclocals=get_src_locals(depth=2),
instr=("TEST", ""),
as_func=True,
)
Expand Down Expand Up @@ -57,3 +64,58 @@ def alloc_nest(
res[i, j] = rloc[j]

assert str(to_uast(alloc_nest)) == golden


global_str = "What is 6 times 9?"
global_num = 42


def test_variable_lookup_positive():
def func(f: f32):
for i in seq(0, 42):
f += 1

reference = to_uast(func)

def func(f: f32):
for i in seq(0, global_num):
f += 1

test_global = to_uast(func)
assert str(test_global) == str(reference)

local_num = 42

def func(f: f32):
for i in seq(0, local_num):
f += 1

test_local = to_uast(func)
assert str(test_local) == str(reference)


def test_variable_lookup_type_error():
def func(f: f32):
for i in seq(0, global_str):
f += 1

with pytest.raises(ParseError, match="type <class 'str'>"):
to_uast(func)

local_str = "xyzzy"

def func(f: f32):
for i in seq(0, local_str):
f += 1

with pytest.raises(ParseError, match="type <class 'str'>"):
to_uast(func)


def test_variable_lookup_name_error():
def func(f: f32):
for i in seq(0, xyzzy):
f += 1

with pytest.raises(ParseError, match="'xyzzy' undefined"):
to_uast(func)

0 comments on commit cd5f597

Please sign in to comment.