From 58c1858b8a9efed8ac6361de7dc1b8afe32241cd Mon Sep 17 00:00:00 2001 From: Yuka Ikarashi Date: Fri, 11 Oct 2024 11:09:55 -0400 Subject: [PATCH 1/3] Fix cursor printing errors --- src/exo/LoopIR_pprint.py | 13 ++- .../golden/test_cursors/test_cursor_print.txt | 81 +++++++++++++++++++ tests/test_cursors.py | 31 +++++++ 3 files changed, 124 insertions(+), 1 deletion(-) create mode 100644 tests/golden/test_cursors/test_cursor_print.txt diff --git a/src/exo/LoopIR_pprint.py b/src/exo/LoopIR_pprint.py index 79eb13e5f..50aa18043 100644 --- a/src/exo/LoopIR_pprint.py +++ b/src/exo/LoopIR_pprint.py @@ -581,6 +581,9 @@ def _print_w_access(node, env: PrintEnv) -> str: def _print_cursor(cur): + if cur == None: + raise InvalidCursorError("Trying to print the Invalid Cursor!") + if isinstance(cur, Node) and not isinstance(cur._node, (LoopIR.proc, LoopIR.stmt)): raise NotImplementedError( "Cursor printing is only implemented for procs and statements" @@ -630,10 +633,16 @@ def while_cursor(c, move, k): while True: try: c = move(c) - s.expand(k(c)) + s.extend(k(c)) except: return s + def if_cursor(c, move, k): + try: + return k(move(c)) + except InvalidCursorError: + return [] + def local_stmt(c): return _print_cursor_stmt(c, target, env, indent) @@ -642,12 +651,14 @@ def local_stmt(c): return [ *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), f"{indent}[GAP - Before]", + *if_cursor(target, lambda g: g.anchor(), local_stmt), *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), ] else: assert target._type == GapType.After return [ *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), + *if_cursor(target, lambda g: g.anchor(), local_stmt), f"{indent}[GAP - After]", *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), ] diff --git a/tests/golden/test_cursors/test_cursor_print.txt b/tests/golden/test_cursors/test_cursor_print.txt new file mode 100644 index 000000000..dbdcd3046 --- /dev/null +++ b/tests/golden/test_cursors/test_cursor_print.txt @@ -0,0 +1,81 @@ +def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): # <-- NODE + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + [GAP - Before] + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + [GAP - After] + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): # <-- NODE + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + [GAP - Before] + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + [GAP - After] + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + # BLOCK START + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + # BLOCK END + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + # BLOCK START + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + # BLOCK END + for j in seq(0, n - 1): + x[j] = 3.0 \ No newline at end of file diff --git a/tests/test_cursors.py b/tests/test_cursors.py index 672ca0f52..8965918a6 100644 --- a/tests/test_cursors.py +++ b/tests/test_cursors.py @@ -729,3 +729,34 @@ def foo(n: size, x: i8[n]): if_stmt = foo.find("if _: _ ") i_loop_alternative = if_stmt.find("for i in _: _") assert i_loop2 == i_loop_alternative + + +def test_cursor_print(golden): + @proc + def foo(n: size, x: i8[n]): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0 + + i_loop2 = foo.find("for i in _:_ #1") + i_loop1 = foo.find("for i in _:_ #0") + + res = str(i_loop2) + str(i_loop2.before()) + str(i_loop2.after()) + res += ( + str(i_loop1) + + str(i_loop1.before()) + + str(i_loop1.after()) + + str(i_loop1.expand(1, 0)) + + str(i_loop1.expand(0, 1)) + ) + + assert res == golden + + with pytest.raises(InvalidCursorError, match="Trying to print the Invalid Cursor!"): + print(i_loop1.parent()) From 359e127f83904acc4b2c6d24944b233591dd0708 Mon Sep 17 00:00:00 2001 From: Yuka Ikarashi Date: Fri, 11 Oct 2024 11:44:52 -0400 Subject: [PATCH 2/3] update golden --- ...st_block_replace_forwarding_for_blocks.txt | 14 +++++- .../test_cursor_pretty_print_blocks.txt | 11 +++++ .../test_cursor_pretty_print_gaps.txt | 43 +++++++++++++++++++ .../test_delete_forwarding_for_blocks.txt | 12 +++++- .../test_insert_forwarding_for_blocks.txt | 9 ++++ .../test_move_forwarding_for_blocks.txt | 14 +++++- ...t_move_forwarding_for_blocks_gap_after.txt | 4 +- .../test_wrap_forwarding_for_blocks.txt | 11 +++++ 8 files changed, 114 insertions(+), 4 deletions(-) diff --git a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt index e31e3e652..034193e1d 100644 --- a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt @@ -26,10 +26,18 @@ def baz(n: size, m: size): # BLOCK START x: f32 @ DRAM # BLOCK END + pass + pass + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + pass + pass + x: f32 @ DRAM # BLOCK START for k in seq(0, n): pass @@ -51,7 +59,11 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START pass pass - # BLOCK END \ No newline at end of file + # BLOCK END + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt index de57d7e17..08f8cac13 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt @@ -15,11 +15,14 @@ def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 # BLOCK START x = 1.0 x = 2.0 x = 3.0 # BLOCK END + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM @@ -29,11 +32,19 @@ def bar(n: size, m: size): x = 0.0 x = 1.0 # BLOCK END + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 3.0 + x = 2.0 + x = 1.0 + x = 0.0 # BLOCK START x = 4.0 x = 5.0 diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt index f8cfced04..31c69fc62 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt @@ -1,28 +1,71 @@ def bar(n: size, m: size): [GAP - Before] + x: f32 @ DRAM + for i in seq(0, n): + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): + x: f32 @ DRAM [GAP - Before] + for i in seq(0, n): + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): [GAP - Before] + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): [GAP - Before] + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 1.0 + x = 0.0 [GAP - Before] + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 4.0 + x = 3.0 + x = 2.0 + x = 1.0 + x = 0.0 + x = 5.0 [GAP - After] \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt index 8fbffdb76..1b45337ab 100644 --- a/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt @@ -24,10 +24,15 @@ def baz(n: size, m: size): # BLOCK START x: f32 @ DRAM # BLOCK END + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y = 1.1 for k in seq(0, n): @@ -49,6 +54,7 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y = 1.1 for k in seq(0, n): @@ -61,4 +67,8 @@ def baz(n: size, m: size): for j in seq(0, m): # BLOCK START x: f32 @ DRAM - # BLOCK END \ No newline at end of file + # BLOCK END + y = 1.1 + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt index 16fad7def..c76411e8d 100644 --- a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt @@ -31,10 +31,19 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 # BLOCK END + pass + y: f32 @ DRAM + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + pass + x = 0.0 + x: f32 @ DRAM # BLOCK START y: f32 @ DRAM y = 1.1 diff --git a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt index 2ab01ea00..c7311b674 100644 --- a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt @@ -31,10 +31,17 @@ def baz(n: size, m: size): y = 1.1 x = 0.0 # BLOCK END + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x = 0.0 + y = 1.1 + y: f32 @ DRAM + x: f32 @ DRAM # BLOCK START for k in seq(0, n): pass @@ -57,7 +64,12 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y: f32 @ DRAM y = 1.1 - # BLOCK END \ No newline at end of file + # BLOCK END + x = 0.0 + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt index da231454d..cd1464eea 100644 --- a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt +++ b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt @@ -41,7 +41,9 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 for k in seq(0, n): + pass # BLOCK START y: f32 @ DRAM y = 1.1 - # BLOCK END \ No newline at end of file + # BLOCK END + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt index e161e1281..fea2ab41c 100644 --- a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt @@ -31,10 +31,21 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 # BLOCK END + for k in seq(0, 8): + y: f32 @ DRAM + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + for k in seq(0, 8): + y: f32 @ DRAM + y = 1.1 + x = 0.0 + x: f32 @ DRAM # BLOCK START for k in seq(0, n): pass From 6aa96bdd2d2e24a638ed13b06ca8bec6fdcb027e Mon Sep 17 00:00:00 2001 From: Yuka Ikarashi Date: Fri, 11 Oct 2024 15:39:13 -0400 Subject: [PATCH 3/3] fix --- src/exo/LoopIR_pprint.py | 36 ++++++++++--------- ...st_block_replace_forwarding_for_blocks.txt | 2 +- .../test_cursor_pretty_print_blocks.txt | 6 ++-- .../test_cursor_pretty_print_gaps.txt | 10 +++--- .../test_insert_forwarding_for_blocks.txt | 4 +-- .../test_move_forwarding_for_blocks.txt | 6 ++-- .../test_wrap_forwarding_for_blocks.txt | 4 +-- 7 files changed, 36 insertions(+), 32 deletions(-) diff --git a/src/exo/LoopIR_pprint.py b/src/exo/LoopIR_pprint.py index 50aa18043..b40762511 100644 --- a/src/exo/LoopIR_pprint.py +++ b/src/exo/LoopIR_pprint.py @@ -628,20 +628,24 @@ def _print_cursor_proc( def _print_cursor_block( cur: Block, target: Cursor, env: PrintEnv, indent: str ) -> list[str]: - def while_cursor(c, move, k): + def while_next(c): s = [] while True: try: - c = move(c) - s.extend(k(c)) + c = c.next() + s.extend(local_stmt(c)) except: return s - def if_cursor(c, move, k): - try: - return k(move(c)) - except InvalidCursorError: - return [] + def while_prev(c): + s = [] + while True: + try: + c = c.prev() + s.append(local_stmt(c)) + except: + s.reverse() + return [x for xs in s for x in xs] def local_stmt(c): return _print_cursor_stmt(c, target, env, indent) @@ -649,18 +653,18 @@ def local_stmt(c): if isinstance(target, Gap) and target in cur: if target._type == GapType.Before: return [ - *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), + *while_prev(target.anchor()), f"{indent}[GAP - Before]", - *if_cursor(target, lambda g: g.anchor(), local_stmt), - *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), + *local_stmt(target.anchor()), + *while_next(target.anchor()), ] else: assert target._type == GapType.After return [ - *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), - *if_cursor(target, lambda g: g.anchor(), local_stmt), + *while_prev(target.anchor()), + *local_stmt(target.anchor()), f"{indent}[GAP - After]", - *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), + *while_next(target.anchor()), ] elif isinstance(target, Block) and target in cur: @@ -669,9 +673,9 @@ def local_stmt(c): block.extend(local_stmt(stmt)) block.append(f"{indent}# BLOCK END") return [ - *while_cursor(target[0], lambda g: g.prev(), local_stmt), + *while_prev(target[0]), *block, - *while_cursor(target[-1], lambda g: g.next(), local_stmt), + *while_next(target[-1]), ] else: diff --git a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt index 034193e1d..7667657dd 100644 --- a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt @@ -35,9 +35,9 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM pass pass - x: f32 @ DRAM # BLOCK START for k in seq(0, n): pass diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt index 08f8cac13..cacb83a11 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt @@ -41,10 +41,10 @@ def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): - x = 3.0 - x = 2.0 - x = 1.0 x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 # BLOCK START x = 4.0 x = 5.0 diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt index 31c69fc62..bca928ff4 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt @@ -50,8 +50,8 @@ def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): - x = 1.0 x = 0.0 + x = 1.0 [GAP - Before] x = 2.0 x = 3.0 @@ -62,10 +62,10 @@ def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): - x = 4.0 - x = 3.0 - x = 2.0 - x = 1.0 x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 x = 5.0 [GAP - After] \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt index c76411e8d..df3f44091 100644 --- a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt @@ -41,9 +41,9 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): - pass - x = 0.0 x: f32 @ DRAM + x = 0.0 + pass # BLOCK START y: f32 @ DRAM y = 1.1 diff --git a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt index c7311b674..9db487954 100644 --- a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt @@ -38,10 +38,10 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): - x = 0.0 - y = 1.1 - y: f32 @ DRAM x: f32 @ DRAM + y: f32 @ DRAM + y = 1.1 + x = 0.0 # BLOCK START for k in seq(0, n): pass diff --git a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt index fea2ab41c..161b23e34 100644 --- a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt @@ -41,11 +41,11 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + x = 0.0 for k in seq(0, 8): y: f32 @ DRAM y = 1.1 - x = 0.0 - x: f32 @ DRAM # BLOCK START for k in seq(0, n): pass