diff --git a/src/exo/LoopIR_pprint.py b/src/exo/LoopIR_pprint.py index b47bbfc2a..4464976e3 100644 --- a/src/exo/LoopIR_pprint.py +++ b/src/exo/LoopIR_pprint.py @@ -581,6 +581,9 @@ def _print_w_access(node, env: PrintEnv) -> str: def _print_cursor(cur): + if cur == None: + raise InvalidCursorError("Trying to print the Invalid Cursor!") + if isinstance(cur, Node) and not isinstance(cur._node, (LoopIR.proc, LoopIR.stmt)): raise NotImplementedError( "Cursor printing is only implemented for procs and statements" @@ -625,31 +628,43 @@ def _print_cursor_proc( def _print_cursor_block( cur: Block, target: Cursor, env: PrintEnv, indent: str ) -> list[str]: - def while_cursor(c, move, k): + def while_next(c): s = [] while True: try: - c = move(c) - s.expand(k(c)) + c = c.next() + s.extend(local_stmt(c)) except: return s + def while_prev(c): + s = [] + while True: + try: + c = c.prev() + s.append(local_stmt(c)) + except: + s.reverse() + return [x for xs in s for x in xs] + def local_stmt(c): return _print_cursor_stmt(c, target, env, indent) if isinstance(target, Gap) and target in cur: if target._type == GapType.Before: return [ - *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), + *while_prev(target.anchor()), f"{indent}[GAP - Before]", - *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), + *local_stmt(target.anchor()), + *while_next(target.anchor()), ] else: assert target._type == GapType.After return [ - *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), + *while_prev(target.anchor()), + *local_stmt(target.anchor()), f"{indent}[GAP - After]", - *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), + *while_next(target.anchor()), ] elif isinstance(target, Block) and target in cur: @@ -658,9 +673,9 @@ def local_stmt(c): block.extend(local_stmt(stmt)) block.append(f"{indent}# BLOCK END") return [ - *while_cursor(target[0], lambda g: g.prev(), local_stmt), + *while_prev(target[0]), *block, - *while_cursor(target[-1], lambda g: g.next(), local_stmt), + *while_next(target[-1]), ] else: diff --git a/tests/golden/test_cursors/test_cursor_print.txt b/tests/golden/test_cursors/test_cursor_print.txt new file mode 100644 index 000000000..dbdcd3046 --- /dev/null +++ b/tests/golden/test_cursors/test_cursor_print.txt @@ -0,0 +1,81 @@ +def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): # <-- NODE + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + [GAP - Before] + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + [GAP - After] + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): # <-- NODE + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + [GAP - Before] + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + [GAP - After] + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + # BLOCK START + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + # BLOCK END + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + # BLOCK START + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + # BLOCK END + for j in seq(0, n - 1): + x[j] = 3.0 \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt index e31e3e652..7667657dd 100644 --- a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt @@ -26,10 +26,18 @@ def baz(n: size, m: size): # BLOCK START x: f32 @ DRAM # BLOCK END + pass + pass + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + pass + pass # BLOCK START for k in seq(0, n): pass @@ -51,7 +59,11 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START pass pass - # BLOCK END \ No newline at end of file + # BLOCK END + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt index de57d7e17..cacb83a11 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt @@ -15,11 +15,14 @@ def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 # BLOCK START x = 1.0 x = 2.0 x = 3.0 # BLOCK END + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM @@ -29,11 +32,19 @@ def bar(n: size, m: size): x = 0.0 x = 1.0 # BLOCK END + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 # BLOCK START x = 4.0 x = 5.0 diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt index f8cfced04..bca928ff4 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt @@ -1,28 +1,71 @@ def bar(n: size, m: size): [GAP - Before] + x: f32 @ DRAM + for i in seq(0, n): + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): + x: f32 @ DRAM [GAP - Before] + for i in seq(0, n): + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): [GAP - Before] + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): [GAP - Before] + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 + x = 1.0 [GAP - Before] + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 [GAP - After] \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt index 8fbffdb76..1b45337ab 100644 --- a/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt @@ -24,10 +24,15 @@ def baz(n: size, m: size): # BLOCK START x: f32 @ DRAM # BLOCK END + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y = 1.1 for k in seq(0, n): @@ -49,6 +54,7 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y = 1.1 for k in seq(0, n): @@ -61,4 +67,8 @@ def baz(n: size, m: size): for j in seq(0, m): # BLOCK START x: f32 @ DRAM - # BLOCK END \ No newline at end of file + # BLOCK END + y = 1.1 + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt index 16fad7def..df3f44091 100644 --- a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt @@ -31,10 +31,19 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 # BLOCK END + pass + y: f32 @ DRAM + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + x = 0.0 + pass # BLOCK START y: f32 @ DRAM y = 1.1 diff --git a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt index 2ab01ea00..9db487954 100644 --- a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt @@ -31,10 +31,17 @@ def baz(n: size, m: size): y = 1.1 x = 0.0 # BLOCK END + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + y: f32 @ DRAM + y = 1.1 + x = 0.0 # BLOCK START for k in seq(0, n): pass @@ -57,7 +64,12 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y: f32 @ DRAM y = 1.1 - # BLOCK END \ No newline at end of file + # BLOCK END + x = 0.0 + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt index da231454d..cd1464eea 100644 --- a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt +++ b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt @@ -41,7 +41,9 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 for k in seq(0, n): + pass # BLOCK START y: f32 @ DRAM y = 1.1 - # BLOCK END \ No newline at end of file + # BLOCK END + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt index e161e1281..161b23e34 100644 --- a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt @@ -31,10 +31,21 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 # BLOCK END + for k in seq(0, 8): + y: f32 @ DRAM + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + x = 0.0 + for k in seq(0, 8): + y: f32 @ DRAM + y = 1.1 # BLOCK START for k in seq(0, n): pass diff --git a/tests/test_cursors.py b/tests/test_cursors.py index 97dee83b1..0e26f62ee 100644 --- a/tests/test_cursors.py +++ b/tests/test_cursors.py @@ -730,3 +730,34 @@ def foo(n: size, x: i8[n]): if_stmt = foo.find("if _: _ ") i_loop_alternative = if_stmt.find("for i in _: _") assert i_loop2 == i_loop_alternative + + +def test_cursor_print(golden): + @proc + def foo(n: size, x: i8[n]): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0 + + i_loop2 = foo.find("for i in _:_ #1") + i_loop1 = foo.find("for i in _:_ #0") + + res = str(i_loop2) + str(i_loop2.before()) + str(i_loop2.after()) + res += ( + str(i_loop1) + + str(i_loop1.before()) + + str(i_loop1.after()) + + str(i_loop1.expand(1, 0)) + + str(i_loop1.expand(0, 1)) + ) + + assert res == golden + + with pytest.raises(InvalidCursorError, match="Trying to print the Invalid Cursor!"): + print(i_loop1.parent())