Skip to content

Commit

Permalink
Fix cursor printing errors (#720)
Browse files Browse the repository at this point in the history
  • Loading branch information
yamaguchi1024 authored Oct 13, 2024
1 parent 7bf0641 commit b22d0b5
Show file tree
Hide file tree
Showing 11 changed files with 250 additions and 13 deletions.
33 changes: 24 additions & 9 deletions src/exo/LoopIR_pprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,9 @@ def _print_w_access(node, env: PrintEnv) -> str:


def _print_cursor(cur):
if cur == None:
raise InvalidCursorError("Trying to print the Invalid Cursor!")

if isinstance(cur, Node) and not isinstance(cur._node, (LoopIR.proc, LoopIR.stmt)):
raise NotImplementedError(
"Cursor printing is only implemented for procs and statements"
Expand Down Expand Up @@ -625,31 +628,43 @@ def _print_cursor_proc(
def _print_cursor_block(
cur: Block, target: Cursor, env: PrintEnv, indent: str
) -> list[str]:
def while_cursor(c, move, k):
def while_next(c):
s = []
while True:
try:
c = move(c)
s.expand(k(c))
c = c.next()
s.extend(local_stmt(c))
except:
return s

def while_prev(c):
s = []
while True:
try:
c = c.prev()
s.append(local_stmt(c))
except:
s.reverse()
return [x for xs in s for x in xs]

def local_stmt(c):
return _print_cursor_stmt(c, target, env, indent)

if isinstance(target, Gap) and target in cur:
if target._type == GapType.Before:
return [
*while_cursor(target.anchor(), lambda g: g.prev(), local_stmt),
*while_prev(target.anchor()),
f"{indent}[GAP - Before]",
*while_cursor(target.anchor(), lambda g: g.next(), local_stmt),
*local_stmt(target.anchor()),
*while_next(target.anchor()),
]
else:
assert target._type == GapType.After
return [
*while_cursor(target.anchor(), lambda g: g.prev(), local_stmt),
*while_prev(target.anchor()),
*local_stmt(target.anchor()),
f"{indent}[GAP - After]",
*while_cursor(target.anchor(), lambda g: g.next(), local_stmt),
*while_next(target.anchor()),
]

elif isinstance(target, Block) and target in cur:
Expand All @@ -658,9 +673,9 @@ def local_stmt(c):
block.extend(local_stmt(stmt))
block.append(f"{indent}# BLOCK END")
return [
*while_cursor(target[0], lambda g: g.prev(), local_stmt),
*while_prev(target[0]),
*block,
*while_cursor(target[-1], lambda g: g.next(), local_stmt),
*while_next(target[-1]),
]

else:
Expand Down
81 changes: 81 additions & 0 deletions tests/golden/test_cursors/test_cursor_print.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n):
pass
if n > 1:
for i in seq(0, n): # <-- NODE
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n):
pass
if n > 1:
[GAP - Before]
for i in seq(0, n):
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n):
pass
if n > 1:
for i in seq(0, n):
x[i] = 0.0
[GAP - After]
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n): # <-- NODE
pass
if n > 1:
for i in seq(0, n):
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
[GAP - Before]
for i in seq(0, n):
pass
if n > 1:
for i in seq(0, n):
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n):
pass
[GAP - After]
if n > 1:
for i in seq(0, n):
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
# BLOCK START
for j in seq(0, n - 1):
x[j] = 2.0
for i in seq(0, n):
pass
# BLOCK END
if n > 1:
for i in seq(0, n):
x[i] = 0.0
for j in seq(0, n - 1):
x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM):
for j in seq(0, n - 1):
x[j] = 2.0
# BLOCK START
for i in seq(0, n):
pass
if n > 1:
for i in seq(0, n):
x[i] = 0.0
# BLOCK END
for j in seq(0, n - 1):
x[j] = 3.0
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,18 @@ def baz(n: size, m: size):
# BLOCK START
x: f32 @ DRAM
# BLOCK END
pass
pass
for k in seq(0, n):
pass
pass

def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
pass
pass
# BLOCK START
for k in seq(0, n):
pass
Expand All @@ -51,7 +59,11 @@ def baz(n: size, m: size):
def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
# BLOCK START
pass
pass
# BLOCK END
# BLOCK END
for k in seq(0, n):
pass
pass
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,14 @@ def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
# BLOCK START
x = 1.0
x = 2.0
x = 3.0
# BLOCK END
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
Expand All @@ -29,11 +32,19 @@ def bar(n: size, m: size):
x = 0.0
x = 1.0
# BLOCK END
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
x = 1.0
x = 2.0
x = 3.0
# BLOCK START
x = 4.0
x = 5.0
Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,71 @@
def bar(n: size, m: size):
[GAP - Before]
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
x = 1.0
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
[GAP - Before]
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
x = 1.0
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
[GAP - Before]
for j in seq(0, m):
x = 0.0
x = 1.0
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
[GAP - Before]
x = 0.0
x = 1.0
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
x = 1.0
[GAP - Before]
x = 2.0
x = 3.0
x = 4.0
x = 5.0

def bar(n: size, m: size):
x: f32 @ DRAM
for i in seq(0, n):
for j in seq(0, m):
x = 0.0
x = 1.0
x = 2.0
x = 3.0
x = 4.0
x = 5.0
[GAP - After]
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,15 @@ def baz(n: size, m: size):
# BLOCK START
x: f32 @ DRAM
# BLOCK END
y = 1.1
for k in seq(0, n):
pass
pass

def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
# BLOCK START
y = 1.1
for k in seq(0, n):
Expand All @@ -49,6 +54,7 @@ def baz(n: size, m: size):
def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
# BLOCK START
y = 1.1
for k in seq(0, n):
Expand All @@ -61,4 +67,8 @@ def baz(n: size, m: size):
for j in seq(0, m):
# BLOCK START
x: f32 @ DRAM
# BLOCK END
# BLOCK END
y = 1.1
for k in seq(0, n):
pass
pass
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,19 @@ def baz(n: size, m: size):
x: f32 @ DRAM
x = 0.0
# BLOCK END
pass
y: f32 @ DRAM
y = 1.1
for k in seq(0, n):
pass
pass

def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
x = 0.0
pass
# BLOCK START
y: f32 @ DRAM
y = 1.1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,17 @@ def baz(n: size, m: size):
y = 1.1
x = 0.0
# BLOCK END
for k in seq(0, n):
pass
pass

def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
y: f32 @ DRAM
y = 1.1
x = 0.0
# BLOCK START
for k in seq(0, n):
pass
Expand All @@ -57,7 +64,12 @@ def baz(n: size, m: size):
def baz(n: size, m: size):
for i in seq(0, n):
for j in seq(0, m):
x: f32 @ DRAM
# BLOCK START
y: f32 @ DRAM
y = 1.1
# BLOCK END
# BLOCK END
x = 0.0
for k in seq(0, n):
pass
pass
Loading

0 comments on commit b22d0b5

Please sign in to comment.