Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cursor printing errors #720

Merged
merged 5 commits into from
Oct 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 24 additions & 9 deletions src/exo/LoopIR_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,9 @@ def _print_w_access(node, env: PrintEnv) -> str:


def _print_cursor(cur):
if cur == None:
raise InvalidCursorError("Trying to print the Invalid Cursor!")

if isinstance(cur, Node) and not isinstance(cur._node, (LoopIR.proc, LoopIR.stmt)):
raise NotImplementedError(
"Cursor printing is only implemented for procs and statements"
Expand Down Expand Up @@ -625,31 +628,43 @@ def _print_cursor_proc(
def _print_cursor_block(
cur: Block, target: Cursor, env: PrintEnv, indent: str
) -> list[str]:
def while_cursor(c, move, k):
def while_next(c):
s = []
while True:
try:
c = move(c)
s.expand(k(c))
c = c.next()
s.extend(local_stmt(c))
except:
return s

def while_prev(c):
s = []
while True:
try:
c = c.prev()
s.append(local_stmt(c))
except:
s.reverse()
return [x for xs in s for x in xs]

def local_stmt(c):
return _print_cursor_stmt(c, target, env, indent)

if isinstance(target, Gap) and target in cur:
if target._type == GapType.Before:
return [
*while_cursor(target.anchor(), lambda g: g.prev(), local_stmt),
*while_prev(target.anchor()),
f"{indent}[GAP - Before]",
*while_cursor(target.anchor(), lambda g: g.next(), local_stmt),
*local_stmt(target.anchor()),
*while_next(target.anchor()),
]
else:
assert target._type == GapType.After
return [
*while_cursor(target.anchor(), lambda g: g.prev(), local_stmt),
*while_prev(target.anchor()),
*local_stmt(target.anchor()),
f"{indent}[GAP - After]",
*while_cursor(target.anchor(), lambda g: g.next(), local_stmt),
*while_next(target.anchor()),
]

elif isinstance(target, Block) and target in cur:
Expand All @@ -658,9 +673,9 @@ def local_stmt(c):
block.extend(local_stmt(stmt))
block.append(f"{indent}# BLOCK END")
return [
*while_cursor(target[0], lambda g: g.prev(), local_stmt),
*while_prev(target[0]),
*block,
*while_cursor(target[-1], lambda g: g.next(), local_stmt),
*while_next(target[-1]),
]

else:
Expand Down
81 changes: 81 additions & 0 deletions tests/golden/test_cursors/test_cursor_print.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n):
pass
if n > 1:
for i in seq(0, n): # <-- NODE
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n):
pass
if n > 1:
[GAP - Before]
for i in seq(0, n):
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n):
pass
if n > 1:
for i in seq(0, n):
x[i] = 0.0
[GAP - After]
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n): # <-- NODE
pass
if n > 1:
for i in seq(0, n):
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
[GAP - Before]
for i in seq(0, n):
pass
if n > 1:
for i in seq(0, n):
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n):
pass
[GAP - After]
if n > 1:
for i in seq(0, n):
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
# BLOCK START
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n):
pass
# BLOCK END
if n > 1:
for i in seq(0, n):
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
# BLOCK START
for i in seq(0, n):
pass
if n > 1:
for i in seq(0, n):
x[i] = 0.0
# BLOCK END
for j in seq(0, n - 1):
x[j] = 3.0
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,18 @@ def baz(n: size, m: size):
# BLOCK START
x: f32 @ DRAM
# BLOCK END
pass
pass
for k in seq(0, n):
pass
pass

def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
pass
pass
# BLOCK START
for k in seq(0, n):
pass
Expand All @@ -51,7 +59,11 @@ def baz(n: size, m: size):
def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
# BLOCK START
pass
pass
# BLOCK END
# BLOCK END
for k in seq(0, n):
pass
pass
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
# BLOCK START
x = 1.0
x = 2.0
x = 3.0
# BLOCK END
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
Expand All @@ -29,11 +32,19 @@ def bar(n: size, m: size):
x = 0.0
x = 1.0
# BLOCK END
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
x = 1.0
x = 2.0
x = 3.0
# BLOCK START
x = 4.0
x = 5.0
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,71 @@
def bar(n: size, m: size):
[GAP - Before]
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
x = 1.0
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
[GAP - Before]
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
x = 1.0
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
[GAP - Before]
for j in seq(0, m):
x = 0.0
x = 1.0
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
[GAP - Before]
x = 0.0
x = 1.0
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
x = 1.0
[GAP - Before]
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
x = 1.0
x = 2.0
x = 3.0
x = 4.0
x = 5.0
[GAP - After]
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@ def baz(n: size, m: size):
# BLOCK START
x: f32 @ DRAM
# BLOCK END
y = 1.1
for k in seq(0, n):
pass
pass

def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
# BLOCK START
y = 1.1
for k in seq(0, n):
Expand All @@ -49,6 +54,7 @@ def baz(n: size, m: size):
def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
# BLOCK START
y = 1.1
for k in seq(0, n):
Expand All @@ -61,4 +67,8 @@ def baz(n: size, m: size):
for j in seq(0, m):
# BLOCK START
x: f32 @ DRAM
# BLOCK END
# BLOCK END
y = 1.1
for k in seq(0, n):
pass
pass
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,19 @@ def baz(n: size, m: size):
x: f32 @ DRAM
x = 0.0
# BLOCK END
pass
y: f32 @ DRAM
y = 1.1
for k in seq(0, n):
pass
pass

def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
x = 0.0
pass
# BLOCK START
y: f32 @ DRAM
y = 1.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,17 @@ def baz(n: size, m: size):
y = 1.1
x = 0.0
# BLOCK END
for k in seq(0, n):
pass
pass

def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
y: f32 @ DRAM
y = 1.1
x = 0.0
# BLOCK START
for k in seq(0, n):
pass
Expand All @@ -57,7 +64,12 @@ def baz(n: size, m: size):
def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
# BLOCK START
y: f32 @ DRAM
y = 1.1
# BLOCK END
# BLOCK END
x = 0.0
for k in seq(0, n):
pass
pass
Loading
Loading