From 87d6e49513e19d30fb7f7cd81a59d61e63eed339 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 13 Oct 2024 19:28:48 +0000 Subject: [PATCH 1/7] Bump pre-commit from 3.8.0 to 4.0.1 (#717) --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index f9c2b4a4e..4b845fe28 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,6 +1,6 @@ black==24.8.0 coverage==7.6.1 -pre-commit==3.8.0 +pre-commit==4.0.1 pytest-cov==5.0.0 pytest-xdist==3.6.1 pytest==8.3.3 From ff52bcc2a6b0abd50d19267a05d71413c0b31a93 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 13 Oct 2024 16:25:29 -0400 Subject: [PATCH 2/7] Bump build from 1.2.2 to 1.2.2.post1 (#714) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index ed467db0d..87a68067b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ PySMT==0.9.6 asdl-adt==0.1.0 asdl==0.1.5 -build==1.2.2 +build==1.2.2.post1 z3-solver==4.13.2.0 yapf==0.40.2 From e4232bf57ed4d509d3e88cb3dbabda92f8209803 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 13 Oct 2024 16:25:50 -0400 Subject: [PATCH 3/7] Bump numpy from 2.1.1 to 2.1.2 (#715) --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 4b845fe28..f735badda 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -5,5 +5,5 @@ pytest-cov==5.0.0 pytest-xdist==3.6.1 pytest==8.3.3 tox==4.21.2 -numpy==2.1.1 +numpy==2.1.2 Pillow==10.4.0 From b5579b14a86018a8d06291a10d6bd6766e482ed3 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 13 Oct 2024 16:26:12 -0400 Subject: [PATCH 4/7] Bump black from 24.8.0 to 24.10.0 (#716) --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index f735badda..3dd8b772d 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,4 +1,4 @@ -black==24.8.0 +black==24.10.0 coverage==7.6.1 pre-commit==4.0.1 pytest-cov==5.0.0 From 7bf0641f26100b97b65b568773d59393ddd96827 Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sun, 13 Oct 2024 16:27:10 -0400 Subject: [PATCH 5/7] Bump coverage from 7.6.1 to 7.6.2 (#719) --- dev-requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 3dd8b772d..76470cef0 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,5 +1,5 @@ black==24.10.0 -coverage==7.6.1 +coverage==7.6.2 pre-commit==4.0.1 pytest-cov==5.0.0 pytest-xdist==3.6.1 From b22d0b5667207e68ebb297a4a3fd5465a7350650 Mon Sep 17 00:00:00 2001 From: Yuka Ikarashi Date: Sun, 13 Oct 2024 17:35:38 -0400 Subject: [PATCH 6/7] Fix cursor printing errors (#720) --- src/exo/LoopIR_pprint.py | 33 +++++--- .../golden/test_cursors/test_cursor_print.txt | 81 +++++++++++++++++++ ...st_block_replace_forwarding_for_blocks.txt | 14 +++- .../test_cursor_pretty_print_blocks.txt | 11 +++ .../test_cursor_pretty_print_gaps.txt | 43 ++++++++++ .../test_delete_forwarding_for_blocks.txt | 12 ++- .../test_insert_forwarding_for_blocks.txt | 9 +++ .../test_move_forwarding_for_blocks.txt | 14 +++- ...t_move_forwarding_for_blocks_gap_after.txt | 4 +- .../test_wrap_forwarding_for_blocks.txt | 11 +++ tests/test_cursors.py | 31 +++++++ 11 files changed, 250 insertions(+), 13 deletions(-) create mode 100644 tests/golden/test_cursors/test_cursor_print.txt diff --git a/src/exo/LoopIR_pprint.py b/src/exo/LoopIR_pprint.py index b47bbfc2a..4464976e3 100644 --- a/src/exo/LoopIR_pprint.py +++ b/src/exo/LoopIR_pprint.py @@ -581,6 +581,9 @@ def _print_w_access(node, env: PrintEnv) -> str: def _print_cursor(cur): + if cur == None: + raise InvalidCursorError("Trying to print the Invalid Cursor!") + if isinstance(cur, Node) and not isinstance(cur._node, (LoopIR.proc, LoopIR.stmt)): raise NotImplementedError( "Cursor printing is only implemented for procs and statements" @@ -625,31 +628,43 @@ def _print_cursor_proc( def _print_cursor_block( cur: Block, target: Cursor, env: PrintEnv, indent: str ) -> list[str]: - def while_cursor(c, move, k): + def while_next(c): s = [] while True: try: - c = move(c) - s.expand(k(c)) + c = c.next() + s.extend(local_stmt(c)) except: return s + def while_prev(c): + s = [] + while True: + try: + c = c.prev() + s.append(local_stmt(c)) + except: + s.reverse() + return [x for xs in s for x in xs] + def local_stmt(c): return _print_cursor_stmt(c, target, env, indent) if isinstance(target, Gap) and target in cur: if target._type == GapType.Before: return [ - *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), + *while_prev(target.anchor()), f"{indent}[GAP - Before]", - *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), + *local_stmt(target.anchor()), + *while_next(target.anchor()), ] else: assert target._type == GapType.After return [ - *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), + *while_prev(target.anchor()), + *local_stmt(target.anchor()), f"{indent}[GAP - After]", - *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), + *while_next(target.anchor()), ] elif isinstance(target, Block) and target in cur: @@ -658,9 +673,9 @@ def local_stmt(c): block.extend(local_stmt(stmt)) block.append(f"{indent}# BLOCK END") return [ - *while_cursor(target[0], lambda g: g.prev(), local_stmt), + *while_prev(target[0]), *block, - *while_cursor(target[-1], lambda g: g.next(), local_stmt), + *while_next(target[-1]), ] else: diff --git a/tests/golden/test_cursors/test_cursor_print.txt b/tests/golden/test_cursors/test_cursor_print.txt new file mode 100644 index 000000000..dbdcd3046 --- /dev/null +++ b/tests/golden/test_cursors/test_cursor_print.txt @@ -0,0 +1,81 @@ +def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): # <-- NODE + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + [GAP - Before] + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + [GAP - After] + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): # <-- NODE + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + [GAP - Before] + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + [GAP - After] + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + # BLOCK START + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + # BLOCK END + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + # BLOCK START + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + # BLOCK END + for j in seq(0, n - 1): + x[j] = 3.0 \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt index e31e3e652..7667657dd 100644 --- a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt @@ -26,10 +26,18 @@ def baz(n: size, m: size): # BLOCK START x: f32 @ DRAM # BLOCK END + pass + pass + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + pass + pass # BLOCK START for k in seq(0, n): pass @@ -51,7 +59,11 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START pass pass - # BLOCK END \ No newline at end of file + # BLOCK END + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt index de57d7e17..cacb83a11 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt @@ -15,11 +15,14 @@ def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 # BLOCK START x = 1.0 x = 2.0 x = 3.0 # BLOCK END + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM @@ -29,11 +32,19 @@ def bar(n: size, m: size): x = 0.0 x = 1.0 # BLOCK END + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 # BLOCK START x = 4.0 x = 5.0 diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt index f8cfced04..bca928ff4 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt @@ -1,28 +1,71 @@ def bar(n: size, m: size): [GAP - Before] + x: f32 @ DRAM + for i in seq(0, n): + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): + x: f32 @ DRAM [GAP - Before] + for i in seq(0, n): + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): [GAP - Before] + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): [GAP - Before] + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 + x = 1.0 [GAP - Before] + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 [GAP - After] \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt index 8fbffdb76..1b45337ab 100644 --- a/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt @@ -24,10 +24,15 @@ def baz(n: size, m: size): # BLOCK START x: f32 @ DRAM # BLOCK END + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y = 1.1 for k in seq(0, n): @@ -49,6 +54,7 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y = 1.1 for k in seq(0, n): @@ -61,4 +67,8 @@ def baz(n: size, m: size): for j in seq(0, m): # BLOCK START x: f32 @ DRAM - # BLOCK END \ No newline at end of file + # BLOCK END + y = 1.1 + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt index 16fad7def..df3f44091 100644 --- a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt @@ -31,10 +31,19 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 # BLOCK END + pass + y: f32 @ DRAM + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + x = 0.0 + pass # BLOCK START y: f32 @ DRAM y = 1.1 diff --git a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt index 2ab01ea00..9db487954 100644 --- a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt @@ -31,10 +31,17 @@ def baz(n: size, m: size): y = 1.1 x = 0.0 # BLOCK END + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + y: f32 @ DRAM + y = 1.1 + x = 0.0 # BLOCK START for k in seq(0, n): pass @@ -57,7 +64,12 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y: f32 @ DRAM y = 1.1 - # BLOCK END \ No newline at end of file + # BLOCK END + x = 0.0 + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt index da231454d..cd1464eea 100644 --- a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt +++ b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt @@ -41,7 +41,9 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 for k in seq(0, n): + pass # BLOCK START y: f32 @ DRAM y = 1.1 - # BLOCK END \ No newline at end of file + # BLOCK END + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt index e161e1281..161b23e34 100644 --- a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt @@ -31,10 +31,21 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 # BLOCK END + for k in seq(0, 8): + y: f32 @ DRAM + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + x = 0.0 + for k in seq(0, 8): + y: f32 @ DRAM + y = 1.1 # BLOCK START for k in seq(0, n): pass diff --git a/tests/test_cursors.py b/tests/test_cursors.py index 97dee83b1..0e26f62ee 100644 --- a/tests/test_cursors.py +++ b/tests/test_cursors.py @@ -730,3 +730,34 @@ def foo(n: size, x: i8[n]): if_stmt = foo.find("if _: _ ") i_loop_alternative = if_stmt.find("for i in _: _") assert i_loop2 == i_loop_alternative + + +def test_cursor_print(golden): + @proc + def foo(n: size, x: i8[n]): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0 + + i_loop2 = foo.find("for i in _:_ #1") + i_loop1 = foo.find("for i in _:_ #0") + + res = str(i_loop2) + str(i_loop2.before()) + str(i_loop2.after()) + res += ( + str(i_loop1) + + str(i_loop1.before()) + + str(i_loop1.after()) + + str(i_loop1.expand(1, 0)) + + str(i_loop1.expand(0, 1)) + ) + + assert res == golden + + with pytest.raises(InvalidCursorError, match="Trying to print the Invalid Cursor!"): + print(i_loop1.parent()) From 9fc5a80ec9dc25337cefb19b4ed76dd90a5cca23 Mon Sep 17 00:00:00 2001 From: Julien de Castelnau Date: Sun, 13 Oct 2024 23:56:37 +0200 Subject: [PATCH 7/7] Add RVM example (#708) --- README.md | 6 +- examples/{ => avx2_matmul}/Makefile | 0 examples/{ => avx2_matmul}/README.md | 2 +- examples/{ => avx2_matmul}/main.c | 0 examples/{ => avx2_matmul}/x86_matmul.py | 0 examples/rvm_conv1d/.gitignore | 1 + examples/rvm_conv1d/Makefile | 31 ++ examples/rvm_conv1d/README.md | 48 +++ examples/rvm_conv1d/conv1Di32.h | 20 ++ examples/rvm_conv1d/exo/.gitignore | 4 + examples/rvm_conv1d/exo/conv1d.py | 399 +++++++++++++++++++++++ examples/rvm_conv1d/gen_stimuli.py | 87 +++++ examples/rvm_conv1d/main.c | 173 ++++++++++ 13 files changed, 767 insertions(+), 4 deletions(-) rename examples/{ => avx2_matmul}/Makefile (100%) rename examples/{ => avx2_matmul}/README.md (99%) rename examples/{ => avx2_matmul}/main.c (100%) rename examples/{ => avx2_matmul}/x86_matmul.py (100%) create mode 100644 examples/rvm_conv1d/.gitignore create mode 100644 examples/rvm_conv1d/Makefile create mode 100644 examples/rvm_conv1d/README.md create mode 100644 examples/rvm_conv1d/conv1Di32.h create mode 100644 examples/rvm_conv1d/exo/.gitignore create mode 100644 examples/rvm_conv1d/exo/conv1d.py create mode 100644 examples/rvm_conv1d/gen_stimuli.py create mode 100644 examples/rvm_conv1d/main.c diff --git a/README.md b/README.md index bbbe3cf1f..25c05b613 100644 --- a/README.md +++ b/README.md @@ -31,7 +31,7 @@ You can use optional arguments to customize the output: # Examples -Take a look at [examples](examples/README.md) for scheduling examples, and [API documentation](docs/API.md) for complete scheduling interface documentation. +Take a look at [examples](examples/avx2_matmul/README.md) for scheduling examples, and [API documentation](docs/API.md) for scheduling interface documentation. # Build Exo from source @@ -126,7 +126,7 @@ In this repository, folders are structured as follows: - **APIs.** Documentation for the APIs can be found in the [API documentation](docs/API.md). - `API.py` defines a stable API for top-level decorators (`proc`, `instr`, and `config`). - `API_scheduling.py` defines a API for scheduling primitives. - - `API_cursors.py` defines a API for scheduling primitives. + - `API_cursors.py` defines a API for Cursors. - **Standard libraries.** These could be user-defined, but we provide them for convenience. - `libs/` contains some common memory definitions (`memories.py`) and custom malloc implementations. - `platforms/` contains instruction definitions that are part of the release. @@ -141,7 +141,7 @@ In this repository, folders are structured as follows: # Contact -Please contact [exo@mit.edu](mailto:exo@mit.edu) if you have any questions. +Please contact [exo@mit.edu](mailto:exo@mit.edu) or [yuka@csail.mit.edu](mailto:yuka@csail.mit.edu) if you have any questions. # Publication diff --git a/examples/Makefile b/examples/avx2_matmul/Makefile similarity index 100% rename from examples/Makefile rename to examples/avx2_matmul/Makefile diff --git a/examples/README.md b/examples/avx2_matmul/README.md similarity index 99% rename from examples/README.md rename to examples/avx2_matmul/README.md index a586f4786..4328da356 100644 --- a/examples/README.md +++ b/examples/avx2_matmul/README.md @@ -6,7 +6,7 @@ This tutorial assumes some familiarity with SIMD instructions. Exo provides *scheduling operators* to transform program and rewrite them to make use of complex hardware instructions. We'll show you how to take a simple matrix multiplication kernel and transform it into an implementation that can make use of [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) vector instructions. -The complete code with scheduling operations can be found in `exo/examples/x86_matmul.py`, and running `make` will compile the Exo code and generate an executable `avx2_matmul`. +The complete code with scheduling operations can be found in `exo/examples/avx2_matmul/x86_matmul.py`, and running `make` will compile the Exo code and generate an executable `avx2_matmul`. ## Basic Implementation diff --git a/examples/main.c b/examples/avx2_matmul/main.c similarity index 100% rename from examples/main.c rename to examples/avx2_matmul/main.c diff --git a/examples/x86_matmul.py b/examples/avx2_matmul/x86_matmul.py similarity index 100% rename from examples/x86_matmul.py rename to examples/avx2_matmul/x86_matmul.py diff --git a/examples/rvm_conv1d/.gitignore b/examples/rvm_conv1d/.gitignore new file mode 100644 index 000000000..466e24805 --- /dev/null +++ b/examples/rvm_conv1d/.gitignore @@ -0,0 +1 @@ +out/ \ No newline at end of file diff --git a/examples/rvm_conv1d/Makefile b/examples/rvm_conv1d/Makefile new file mode 100644 index 000000000..4720f3da6 --- /dev/null +++ b/examples/rvm_conv1d/Makefile @@ -0,0 +1,31 @@ +PROG = conv1d +OUT = out/ +CC = "${RISCV}/bin/clang" +SPIKE = "${RISCV}/bin/spike" +ASFLAGS = -march=rv32imc_xtheadmatrix0p1 -menable-experimental-extensions +CFLAGS = -O2 -g3 $(ASFLAGS) + +default: sim +exo_comp: exo/conv1d_exo.c + +$(OUT)/$(PROG).elf: $(OUT)/$(PROG).o $(OUT)/conv1d_exo.o + $(CC) $(LDFLAGS) -o $@ $^ + +$(OUT)/$(PROG).o: main.c exo/conv1d_exo.h conv1Di32.h $(OUT) + $(CC) $(CFLAGS) -o $@ -c $< + +$(OUT)/conv1d_exo.o: exo/conv1d_exo.c $(OUT) + $(CC) $(CFLAGS) -o $@ -c $< + +$(OUT): + @mkdir -p $(OUT) + +exo/conv1d_exo.h: exo/conv1d_exo.c +exo/conv1d_exo.c: exo/conv1d.py + exocc -o exo/ --stem conv1d_exo exo/conv1d.py + +conv1Di32.h: gen_stimuli.py + python3 $< + +sim: $(OUT)/$(PROG).elf + @$(SPIKE) --isa=RV32IMC_xmatrix pk -s $< \ No newline at end of file diff --git a/examples/rvm_conv1d/README.md b/examples/rvm_conv1d/README.md new file mode 100644 index 000000000..93b2e77ca --- /dev/null +++ b/examples/rvm_conv1d/README.md @@ -0,0 +1,48 @@ +# Conv1D on RVM example + +This is an implementation of a simplified 1D convolution routine, using a custom [RISC-V ISA extension called RVM](https://github.com/esl-epfl/xheep_matrix_spec/tree/main). + +The tutorial accompanying this example is on [the main website](https://exo-lang.dev/tutorial.html). This page will just show you how to first compile the Exo program to C, and how to run it as well (optional.) + +## File organization + +* `main.c` - driver program testing handwritten vs Exo routine +* `gen_stimuli.py` - generate C arrays used as test vectors for conv1d routine, with expected output +* `conv1Di32.h` - generated output from `gen_stimuli.py` +* `exo/conv1d.py` - Exo code for conv1d +* `exo/conv1d_exo.{c,d,h}` - generated outputs from Exo + + +## Setup Exo & Compile + +First follow [the documentation](https://github.com/exo-lang/exo#install-exo) to install Exo, if you have not already. We assume `exocc` is in `$PATH`, and you have `make` installed. To compile the exo program in `exo/conv1d.py`, run: + +```bash +make exo_comp +``` + +The resulting C code for the example will be in `exo/conv1d_exo.c`. + +From here, if you would like to also compile the program to a RISC-V binary, and run it in a simulator, you will need the custom RVM toolchain. The following steps walk through that process. Otherwise, you can stop here. + +## Install RVM toolchain + +RVM is the custom RISC-V extension, which supports instructions and registers to do matrix operations. It requires a custom LLVM toolchain to build code, and in order to run programs, a fork of the Spike simulator. [The repo for RVM has a guide to set up these components](https://github.com/esl-epfl/xheep_matrix_spec/blob/main/BUILDING.md). In the end you should have the LLVM tools as well as Spike installed under `$RISCV/bin`. + + +## Build + +Run `make` to build the driver program, and simulate it in spike. **This assumes you have `$RISCV` defined from the installation step.** You should see an output like this: + +``` +$ make +... +handwritten err: 0 +exo err: 0 +2350 ticks +93797 cycles +93799 instructions +0.99 CPI +``` + +Note that the cycle counts are *not* accurate, and they should not be used to measure performance. Unfortunately, the hardware for RVM is not public as of today, and the Spike simulator is not meant to simulate these details, so it is only used for testing functional correctness. \ No newline at end of file diff --git a/examples/rvm_conv1d/conv1Di32.h b/examples/rvm_conv1d/conv1Di32.h new file mode 100644 index 000000000..8d6f78ab7 --- /dev/null +++ b/examples/rvm_conv1d/conv1Di32.h @@ -0,0 +1,20 @@ +#ifndef _CONV1Di32 +#define _CONV1Di32 +// This file is automatically generated +int32_t __attribute__((section(".xheep_data_interleaved"))) DATA[] = { + 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}; + +int32_t __attribute__((section(".xheep_data_interleaved"))) KERNELS[] = { + 0,1,2,3,10,11,12,13,20,21,22,23,30,31,32,33,100,101,102,103,110,111,112,113,120,121,122,123,130,131,132,133,200,201,202,203,210,211,212,213,220,221,222,223,230,231,232,233,300,301,302,303,310,311,312,313,320,321,322,323,330,331,332,333,400,401,402,403,410,411,412,413,420,421,422,423,430,431,432,433,500,501,502,503,510,511,512,513,520,521,522,523,530,531,532,533,600,601,602,603,610,611,612,613,620,621,622,623,630,631,632,633,700,701,702,703,710,711,712,713,720,721,722,723,730,731,732,733, + 800,801,802,803,810,811,812,813,820,821,822,823,830,831,832,833,900,901,902,903,910,911,912,913,920,921,922,923,930,931,932,933,1000,1001,1002,1003,1010,1011,1012,1013,1020,1021,1022,1023,1030,1031,1032,1033,1100,1101,1102,1103,1110,1111,1112,1113,1120,1121,1122,1123,1130,1131,1132,1133,1200,1201,1202,1203,1210,1211,1212,1213,1220,1221,1222,1223,1230,1231,1232,1233,1300,1301,1302,1303,1310,1311,1312,1313,1320,1321,1322,1323,1330,1331,1332,1333,1400,1401,1402,1403,1410,1411,1412,1413,1420,1421,1422,1423,1430,1431,1432,1433,1500,1501,1502,1503,1510,1511,1512,1513,1520,1521,1522,1523,1530,1531,1532,1533}; + +int32_t __attribute__((section(".xheep_data_interleaved"))) EXPECTED[] = { + 416,680,944,1208,1472,1736,2000,2264,2528,2792,3056,3320,3584,2696,1800,900,2816,4680,6544,8408,10272,12136,14000,15864,17728,19592,21456,23320,25184,19496,13400,6900,5216,8680,12144,15608,19072,22536,26000,29464,32928,36392,39856,43320,46784,36296,25000,12900,7616,12680,17744,22808,27872,32936,38000,43064,48128,53192,58256,63320,68384,53096,36600,18900,10016,16680,23344,30008,36672,43336,50000,56664,63328,69992,76656,83320,89984,69896,48200,24900,12416,20680,28944,37208,45472,53736,62000,70264,78528,86792,95056,103320,111584,86696,59800,30900,14816,24680,34544,44408,54272,64136,74000,83864,93728,103592,113456,123320,133184,103496,71400,36900,17216,28680,40144,51608,63072,74536,86000,97464,108928,120392,131856,143320,154784,120296,83000,42900, + 19616,32680,45744,58808,71872,84936,98000,111064,124128,137192,150256,163320,176384,137096,94600,48900,22016,36680,51344,66008,80672,95336,110000,124664,139328,153992,168656,183320,197984,153896,106200,54900,24416,40680,56944,73208,89472,105736,122000,138264,154528,170792,187056,203320,219584,170696,117800,60900,26816,44680,62544,80408,98272,116136,134000,151864,169728,187592,205456,223320,241184,187496,129400,66900,29216,48680,68144,87608,107072,126536,146000,165464,184928,204392,223856,243320,262784,204296,141000,72900,31616,52680,73744,94808,115872,136936,158000,179064,200128,221192,242256,263320,284384,221096,152600,78900,34016,56680,79344,102008,124672,147336,170000,192664,215328,237992,260656,283320,305984,237896,164200,84900,36416,60680,84944,109208,133472,157736,182000,206264,230528,254792,279056,303320,327584,254696,175800,90900}; + +#define N 16 +#define IC 4 +#define W 4 +#define OC 16 +#define PAD 1 +#endif \ No newline at end of file diff --git a/examples/rvm_conv1d/exo/.gitignore b/examples/rvm_conv1d/exo/.gitignore new file mode 100644 index 000000000..c06c83d3c --- /dev/null +++ b/examples/rvm_conv1d/exo/.gitignore @@ -0,0 +1,4 @@ +__pycache__/ +conv1d_exo.c +conv1d_exo.h +conv1d_exo.d diff --git a/examples/rvm_conv1d/exo/conv1d.py b/examples/rvm_conv1d/exo/conv1d.py new file mode 100644 index 000000000..23b3681c0 --- /dev/null +++ b/examples/rvm_conv1d/exo/conv1d.py @@ -0,0 +1,399 @@ +from __future__ import annotations + +import os +import sys + +import exo.API_cursors as pc +from exo import proc +from exo.libs.memories import * +from exo.platforms.x86 import * +from exo.stdlib.scheduling import * +from exo.stdlib.stdlib import * + +############# +# ALGORITHM # +############# +N = 16 +IC = 4 +W = 4 +OC = 16 +TILE = 4 +def gen_conv1d(): + @proc + def generic_conv1d( + data: i32[IC, N], + kernels: i32[OC, IC, W], + out: i32[OC, N], + ): + # do the convolution + for i in seq(0, OC): + for j in seq(0, N): + # zero out the result memory + out[i, j] = 0.0 + for c in seq(0, IC): + for r in seq(0, W): + y: i32 + if j + r < N: + y = data[c, j + r] + else: + y = 0 + out[i, j] += kernels[i, c, r] * y + return generic_conv1d + +############## +# HW LIBRARY # +############## + +class RVM_TILE(StaticMemory): + NUM_RVM_TILES = 8 + StaticMemory.init_state(NUM_RVM_TILES) + tile_dict = {} + + @classmethod + def reset_allocations(cls): + cls.init_state(cls.NUM_RVM_TILES) + cls.tile_dict = {} + + @classmethod + def can_read(cls): + return False + + @classmethod + def alloc(cls, new_name, prim_type, shape, srcinfo): + if not (len(shape) == 2): + raise MemGenError("Must be a 2D tile.") + if not (shape[0].isdecimal() and int(shape[0]) == 4): + raise MemGenError("Number of tile rows must be 4.") + if not (shape[1].isdecimal() and int(shape[1]) == 4): + raise MemGenError("Number of tile columns must be 4.") + + tile_num = cls.find_free_chunk() + cls.mark(tile_num) + cls.tile_dict[new_name] = tile_num + return f'#define {new_name} "m{7-tile_num}"' + + @classmethod + def free(cls, new_name, prim_type, shape, srcinfo): + tile_num = cls.tile_dict[new_name] + del cls.tile_dict[new_name] + cls.unmark(tile_num) + return f"#undef {new_name}" + + +@instr( + 'asm volatile("mld.w "{dst_int}", (%1), %0" :: "r"(4*({src}.strides[0])), "r"(&{src_data}));' +) +def rvm_mld(dst: [i32][4, 4] @ RVM_TILE, src: [i32][4, 4] @ DRAM): + assert stride(src, 1) == 1 + assert stride(dst, 1) == 1 + + for i in seq(0, 4): + for j in seq(0, 4): + dst[i, j] = src[i, j] + + +@instr('asm volatile("mzero "{dst_int});') +def rvm_mzero(dst: [i32][4, 4] @ RVM_TILE): + assert stride(dst, 1) == 1 + + for i in seq(0, 4): + for j in seq(0, 4): + dst[i, j] = 0.0 + + +@instr( + 'asm volatile("mst.w "{src_int}", (%1), %0" :: "r"(4*({dst}.strides[0])), "r"(&{dst_data}));' +) +def rvm_mst(src: [i32][4, 4] @ RVM_TILE, dst: [i32][4, 4] @ DRAM): + assert stride(src, 1) == 1 + assert stride(dst, 1) == 1 + + for i in seq(0, 4): + for j in seq(0, 4): + dst[i, j] = src[i, j] + + +@instr('asm volatile("mmasa.w "{md_int}", "{ms1_int}", "{ms2_int});') +def rvm_mmasa( + md: [i32][4, 4] @ RVM_TILE, ms1: [i32][4, 4] @ RVM_TILE, ms2: [i32][4, 4] @ RVM_TILE +): + assert stride(md, 1) == 1 + assert stride(ms1, 1) == 1 + assert stride(ms2, 1) == 1 + for i in seq(0, 4): + for j in seq(0, 4): + for k in seq(0, 4): + md[i, j] += ms2[i, k] * ms1[j, k] + +########################## +# CUSTOM REWRITING RULES # +########################## + +def fuse_two_loops(p, c): + """ + for i in ...: <- c + for j in ...: + s1 + for k in ...: <- c.next() + for i in ...: + s2 + ----> + for i in ...: <- c + for j in ...: + s1 + for k in ...: + s2 + """ + try: + next_c = c.next() + except: + return p, False + + if isinstance(c, pc.ForCursor) and isinstance(next_c, pc.ForCursor): + if c.name() == next_c.name() and expr_to_string(c.hi()) == expr_to_string( + next_c.hi() + ): + p = fuse(p, c, next_c, unsafe_disable_check=False) + return p, True + else: + tgt_c, count = find_child_loop(next_c, c.name()) + if tgt_c: + p = lift_scope_n(p, tgt_c, n_lifts=count) + p = fuse(p, c, tgt_c, unsafe_disable_check=False) + return p, True + + return p, False + + +def fuse_all_loops(p, cursor): + """ + recursively calls fuse_two_loops to all the loops + """ + while True: + if isinstance(cursor, pc.ForCursor): + p = fuse_all_loops(p, cursor.body()[0]) + + # Fuse in current scope + p, b = fuse_two_loops(p, cursor) + + if b: + cursor = p.forward(cursor) + else: + try: + cursor = p.forward(cursor).next() + except: + break + + return p + +def autolift_alloc(p, alloc_c, dep_set=None, max_size=0, lift=True): + """ + for i in seq(0, 10): + for j in seq(0, 20): + a : R <- alloc_c, dep_set = {'i'} + a[i] = ... + ----> + a : R[10] <- if size is less than max_size + for i in seq(0, n): + for j in seq(0, m): + a[i] = ... + """ + alloc_c = p.forward(alloc_c) + loop_c = get_enclosing_loop(p, alloc_c) + accum_size = 1 + while True: + try: + if not isinstance(loop_c, pc.ForCursor): + break + if dep_set == None or loop_c.name() in dep_set: + if ( + isinstance(loop_c.hi(), LiteralCursor) + and accum_size * loop_c.hi().value() <= max_size + ): + p = expand_dim(p, alloc_c, loop_c.hi().value(), loop_c.name()) + accum_size = accum_size * loop_c.hi().value() + if lift: + p = lift_alloc(p, alloc_c) + loop_c = loop_c.parent() + except: + break + return p + +def reorder_top(p, c): + """ + for i in seq(0, 10): + s1 + s2 + s3 <- c + ----> + for i in seq(0, 10): + s3 <- c + s1 + s2 + """ + c = p.forward(c) + while True: + try: + p = reorder_stmts(p, c.expand(1, 0)) + c = p.forward(c) + except: + break + return p + + +def fission_as_much_as_possible(p, cursor): + """ + for i in ...: + for j in ...: + s1 + s2 <- cursor + s3 + ---> + for i in ...: + for j in ...: + s2 + + for i in ...: + for j in ...: + s1 + s3 + """ + cursor = p.forward(cursor) + p = reorder_top(p, cursor) + gap_c = cursor.after() + while True: + try: + p = fission(p, gap_c) + gap_c = p.forward(gap_c).parent().after() + except: + break + + return p + + +def lift_scope_n(p, c, n_lifts=1): + """ + for i in seq(0, 10): + for j in seq(0, 10): + for k in seq(0, 10): + if ...: <- c + s1 + ----> if n_lifts == 2: + for i in seq(0, 10): + if ...: <- c + for j in seq(0, 10): + for k in seq(0, 10): + s1 + """ + for i in range(0, n_lifts): + p = lift_scope(p, c) + return p + + +def remove_redundant_loops(p, c, num=0): + """ + for i in ...: + for j in ...: + s1[j] <- c + ---> + for j in ...: + s1[j] <- c + """ + c = p.forward(c) + cur_depth = 0 + while True: + c = c.parent() + if not isinstance(c, pc.ForCursor): + break + try: + if cur_depth >= num: + break + hi = c.hi().value() + name = c.name() + child = p.forward(c).body()[0] + p = remove_loop(p, c) + cur_depth += 1 + except: + continue + return p + +############## +# SCHEDULING # +############## + +def optimize_conv(p): + p = rename(p, "exo_conv1d_tile_lt_kw") + + # Before scheduling, grab cursors to the object code. + i_loop = p.find("for i in _:_") + j_loop = p.find("for j in _:_") + c_loop = p.find("for c in _:_") + y_alloc = p.find("y : _") + y_assign = p.find("y = data[_]") + + # Tile outer loops to TILE size for RVM + p, _ = tile_loops(p, [(i_loop, TILE), (j_loop, TILE)], perfect=True) + p, _ = tile_loops(p, [(i_loop, 4)], perfect=True) + i_loop_reg = p.find("for ioi in _:_") + p = reorder_loops(p, i_loop_reg) + + # Stage output to out_tile + p, (out_alloc, out_tile, body, _) = auto_stage_mem( + p, p.find_loop("c").expand(1, 0), "out", "out_tile", rc=True + ) + p = autolift_alloc(p, out_tile, max_size=4 * 4 * 4, dep_set=["ioi","ii","ji"]) + + # Block the zero initialization and store blocks + p = fission_as_much_as_possible(p, body) + p = fission_as_much_as_possible(p, body[0]) + + # Reorder c loop to the top + p = lift_scope_n(p, c_loop, 3) + + # Stage y + p = autolift_alloc(p, y_alloc, max_size=4 * 4, dep_set=["r","ji"]) + p = lift_alloc(p, y_alloc, n_lifts=2) + + # Fission the initialization loop and remove redundant loops + p = fission_as_much_as_possible(p, y_assign.parent()) + p = remove_redundant_loops(p, y_assign.parent(), num=2) + + # Stage kernels to kernel_tile and y to data_tile + ii_loop = p.forward(c_loop).body()[2].body()[0] + p, (kernel_alloc, _, _, _) = auto_stage_mem( + p, ii_loop, "kernels", "kernel_tile", rc=True + ) + p = simplify(expand_dim(p, kernel_alloc, 4, ii_loop.parent().name())) + p = lift_alloc(p, kernel_alloc) + p, (data_alloc, _, _, _) = auto_stage_mem( + p, ii_loop.parent(), "y", "data_tile", rc=True + ) + + # Set adequate memories + p = set_memory(p, y_alloc, DRAM_STATIC) + p = set_memory(p, out_tile, RVM_TILE) + p = set_memory(p, kernel_alloc, RVM_TILE) + p = set_memory(p, data_alloc, RVM_TILE) + + # Replace inner loops to calls to RVM instructions + p = replace_all(p, [rvm_mzero, rvm_mst, rvm_mld, rvm_mmasa]) + + # Clean up + p = unroll_loop(p, "ioi") + p = unroll_loop(p, "ioi") + p = unroll_loop(p, "ioi") + p = simplify(p) + p = unroll_buffer(p, kernel_alloc, 0) + p = reuse_buffer(p, "kernel_tile_0: _", "kernel_tile_3: _") + p = unroll_buffer(p, "out_tile", 0) + + return p + + +def make_routine(): + generic_conv1d = gen_conv1d() + rvm_optimized = optimize_conv(generic_conv1d) + return rvm_optimized + + +exo_conv1d_tile_lt_kw = make_routine() \ No newline at end of file diff --git a/examples/rvm_conv1d/gen_stimuli.py b/examples/rvm_conv1d/gen_stimuli.py new file mode 100644 index 000000000..39de2410a --- /dev/null +++ b/examples/rvm_conv1d/gen_stimuli.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python + +import sys +import random + +# Copyright 2017 ETH Zurich and University of Bologna. +# Copyright and related rights are licensed under the Solderpad Hardware +# License, Version 0.51 (the License); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# http://solderpad.org/licenses/SHL-0.51. Unless required by applicable law +# or agreed to in writing, software, hardware and materials distributed under +# this License is distributed on an AS IS BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. + +def write_arr(f, name, arr, ctype, size, linebreak): + f.write(ctype + " " + name + "[] = {\n\t") + i = 1 + for v in arr: + if i % size == 0: + f.write('%d};\n\n' % (v)) + elif i % linebreak == 0: + f.write('%d,\n\t' % (v)) + else: + f.write('%d,' % (v)) + i+=1 + return + + +################################################################################ +f = open('conv1Di32.h', 'w') +f.write('#ifndef _CONV1Di32 \n') +f.write('#define _CONV1Di32 \n') +f.write('// This file is automatically generated\n') + + +N = 16 +IC = 4 +OC = 16 +W = 4 +RANGE = 4095 + +data = [] +kernel = [] +expected = [] + +pad = 1 + +# N C W format +for i in range(0,IC): + for j in range(0,N): + data.append(j) + #data.append(random.randint(-RANGE, RANGE-1)) + +# O I W format +for i in range(0,OC): + for j in range(0,IC): + for k in range(0,W): + kernel.append(i*100+j*10+k) + #kernel.append(random.randint(-RANGE, RANGE-1)) + +# O W format +for i in range(0,OC): + for j in range(0,N): + sum = 0 + for w_i in range(0,W): + for w_j in range(0,IC): + data_idx = j + w_i + data_at_idx = 0 + if data_idx < N: + data_at_idx = data[w_j * N + j + w_i] + sum += kernel[(IC * i + w_j)*W + w_i] * data_at_idx + expected.append(sum) + + +write_arr(f, 'DATA' , data, 'int32_t __attribute__((section(".xheep_data_interleaved")))', IC * N, 128) +write_arr(f, 'KERNELS' , kernel, 'int32_t __attribute__((section(".xheep_data_interleaved")))', OC * IC * W, 128) +write_arr(f, 'EXPECTED', expected, 'int32_t __attribute__((section(".xheep_data_interleaved")))', OC * N, 128) + +f.write('#define N %d\n' % N) +f.write('#define IC %d\n' % IC) +f.write('#define W %d\n' % W) +f.write('#define OC %d\n' % OC) +f.write('#define PAD %d\n' % 1) + + +f.write('#endif') diff --git a/examples/rvm_conv1d/main.c b/examples/rvm_conv1d/main.c new file mode 100644 index 000000000..0a38286f1 --- /dev/null +++ b/examples/rvm_conv1d/main.c @@ -0,0 +1,173 @@ + +/* Includes */ +#include +#include +#include +#include + +#include "conv1Di32.h" +#include "exo/conv1d_exo.h" + +//////////////////// +// CONFIGURATION // +////////////////// + +#define TILE 4 + +///////////// +// MACROS // +/////////// + +#define CEIL_DIV(a, b) ((((a) % (b)) != 0) ? (((a) / (b)) + 1) : (a) / (b)) + +int32_t out[OC * N]; +int32_t data_tile[TILE][IC * W]; +int32_t result[OC * N]; +int32_t small_data_tile_a[TILE*TILE]; +int32_t small_data_tile_b[TILE*TILE]; + +//////////////// +// MAIN CODE // +////////////// + +void conv1d_tile_lt_kw_reord(int32_t *data, int32_t *kernels, int32_t *out) +{ + // should be ceil_div(ic*kw, tile) * tile + // and initialized to 0 + int tile_i_len = CEIL_DIV(OC, TILE*4); + int tile_j_len = CEIL_DIV(N, TILE); + int data_base; + int cycles; + int32_t *kernel_base = kernels; + register int32_t *small_data_tile = small_data_tile_a; + register int32_t *temp; + for (int tile_i = 0; tile_i < tile_i_len; tile_i++) + { + data_base = 0; + for (int tile_j = 0; tile_j < tile_j_len; tile_j++) + { + asm volatile("mzero m1"); + asm volatile("mzero m2"); + asm volatile("mzero m3"); + asm volatile("mzero m4"); + int data_row = 0; + for (int tile_k = 0; tile_k < IC; tile_k++) + { + //CSR_CLEAR_BITS(CSR_REG_MCOUNTINHIBIT, 0x1); + //CSR_WRITE(CSR_REG_MCYCLE, 0); + for (int replica = 0; replica < TILE; replica++) + { + int ofs = data_base + replica; + int drow_ofs = data_row + ofs; + int dtile_ofs = replica*TILE; + for (int i = 0; i < W; i++) + { + // Check that we are not out of bounds of the input in the current channel + // this should not block: addresses are different + small_data_tile[dtile_ofs] = 0; + if (ofs < N) { + small_data_tile[dtile_ofs] = data[drow_ofs]; + } + + ofs++; + drow_ofs++; + dtile_ofs++; + } + //CSR_READ(CSR_REG_MCYCLE, &cycles); + //printf("cyc: %d\n", cycles); + } + data_row += N; + + asm volatile("mld.w m0, (%1), %0" ::"r"(TILE * 4), "r"(small_data_tile)); + asm volatile("mld.w m5, (%1), %0" ::"r"(IC * W * 4), "r"(kernel_base)); + asm volatile("mmasa.w m1, m0, m5"); + asm volatile("mld.w m6, (%1), %0" ::"r"(IC * W * 4), "r"(kernel_base+TILE * IC * W)); + asm volatile("mmasa.w m2, m0, m6"); + asm volatile("mld.w m7, (%1), %0" ::"r"(IC * W * 4), "r"(kernel_base+TILE * IC * W*2)); + asm volatile("mmasa.w m3, m0, m7"); + asm volatile("mld.w m5, (%1), %0" ::"r"(IC * W * 4), "r"(kernel_base+TILE * IC * W*3)); + asm volatile("mmasa.w m4, m0, m5"); + kernel_base += W; + // swap + // asm ("xor %0, %0, %1" : "=r"(small_data_tile_cur) : "r"(small_data_tile_old)); + // asm ("xor %0, %0, %1" : "=r"(small_data_tile_old) : "r"(small_data_tile_cur)); + // asm ("xor %0, %0, %1" : "=r"(small_data_tile_cur) : "r"(small_data_tile_old)); + } + int32_t *outptr = (out + (tile_i * N * 4 + tile_j) * TILE); + asm volatile("mst.w m1, (%1), %0" ::"r"(N * 4), "r"(outptr)); + asm volatile("mst.w m2, (%1), %0" ::"r"(N * 4), "r"(outptr + TILE*N)); + asm volatile("mst.w m3, (%1), %0" ::"r"(N * 4), "r"(outptr + TILE*N*2)); + asm volatile("mst.w m4, (%1), %0" ::"r"(N * 4), "r"(outptr + TILE*N*3)); + + data_base += TILE; + kernel_base -= W * IC; + } + kernel_base += TILE * IC * W*4; + } +} + +#define BRANCHLESS_TERNARY(c, x, y) ((-(c) & x) | (~(-(c)) & y)); +void conv1d_cpu(int32_t *data, int32_t *kernels, int32_t *out) +{ + for (int i = 0; i < OC; i++) + { + for (int j = 0; j < N; j++) + { + out[N * i + j] = 0; + for (int w_i = 0; w_i < W; w_i++) + { + for (int w_j = 0; w_j < IC; w_j++) + { + int data_idx = j + w_i; + int kernel_idx = (IC * i + w_j) * W + w_i; + int data_at_idx = BRANCHLESS_TERNARY(data_idx < N, data[w_j * N + j + w_i], 0); + out[N * i + j] += data_at_idx * kernels[kernel_idx]; + } + } + } + } +} + +int check_result(int32_t *result) { + int err = 0; + for (int i = 0; i < OC; i++) + { + for (int j = 0; j < N; j++) + { + if (result[N * i + j] != EXPECTED[N * i + j]) + { + err++; + printf("exp %d got %d\n\r", EXPECTED[N * i + j], result[N * i + j]); + } + } + } + return err; +} + +int main() +{ + for (int i = 0; i < TILE; i++) + { + for (int j = 0; j < TILE; j++) + { + small_data_tile_a[i*TILE+j] = 0; + small_data_tile_b[i*TILE+j] = 0; + } + } + + conv1d_tile_lt_kw_reord(DATA, KERNELS, result); + printf("handwritten err: %d\n\r", check_result(result)); + + for (int i = 0; i < OC; i++) + { + for (int j = 0; j < N; j++) + { + result[i*N+j] = 0; + } + } + + exo_conv1d_tile_lt_kw(NULL, DATA, KERNELS, result); + printf("exo err: %d\n\r", check_result(result)); + + return 0; +}