-
Notifications
You must be signed in to change notification settings - Fork 28
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
74204ba
commit d9772a3
Showing
4 changed files
with
403 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,144 @@ | ||
|
||
#pragma once | ||
#ifndef TEST_CASE_H | ||
#define TEST_CASE_H | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
|
||
#include <stdint.h> | ||
#include <stdbool.h> | ||
|
||
// Compiler feature macros adapted from Hedley (public domain) | ||
// https://github.com/nemequ/hedley | ||
|
||
// EXO_HAS_BUILTIN(name): forwards to the compiler's __has_builtin probe when
// that operator exists; on compilers without it every probe evaluates to 0.
#ifdef __has_builtin
#  define EXO_HAS_BUILTIN(name) __has_builtin(name)
#else
#  define EXO_HAS_BUILTIN(name) (0)
#endif

// EXO_ASSUME(cond): tells the optimizer that `cond` holds at this point.
//   - with __builtin_assume:      handed straight to the builtin;
//   - with __builtin_unreachable: the false branch is marked unreachable, so
//                                 a false `cond` is undefined behavior;
//   - otherwise:                  `cond` is merely evaluated and discarded.
// The expansion is always a void expression.
#if EXO_HAS_BUILTIN(__builtin_assume)
#  define EXO_ASSUME(cond) __builtin_assume(cond)
#elif EXO_HAS_BUILTIN(__builtin_unreachable)
#  define EXO_ASSUME(cond) \
    ((cond) ? (void)0 : __builtin_unreachable())
#else
#  define EXO_ASSUME(cond) ((void)(cond))
#endif
|
||
|
||
#ifndef EXO_WIN_1F32
#define EXO_WIN_1F32
// Writable view over float data with a single stride entry.
// NOTE(review): nothing visible in this file uses this type, so the stride
// unit (elements vs. bytes) should be confirmed against the code generator.
struct exo_win_1f32 {
    float *const data;              // base pointer; the pointer itself is const
    const int_fast32_t strides[1];  // one stride value
};
#endif
#ifndef EXO_WIN_1F32C
#define EXO_WIN_1F32C
// Read-only counterpart of exo_win_1f32: the pointed-to floats are const.
// NOTE(review): nothing visible in this file uses this type, so the stride
// unit (elements vs. bytes) should be confirmed against the code generator.
struct exo_win_1f32c {
    const float *const data;        // base pointer to immutable data
    const int_fast32_t strides[1];  // one stride value
};
#endif
// rank_k_reduce_6x16( | ||
// K : size, | ||
// A : f32[6, K] @DRAM, | ||
// B : f32[K, 16] @DRAM, | ||
// C : f32[6, 16] @DRAM | ||
// ) | ||
void rank_k_reduce_6x16( void *ctxt, int_fast32_t K, const float* A, const float* B, float* C ); | ||
|
||
// rank_k_reduce_6x16_scheduled( | ||
// K : size, | ||
// A : f32[6, K] @DRAM, | ||
// B : f32[K, 16] @DRAM, | ||
// C : f32[6, 16] @DRAM | ||
// ) | ||
void rank_k_reduce_6x16_scheduled( void *ctxt, int_fast32_t K, const float* A, const float* B, float* C ); | ||
|
||
|
||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
#endif // TEST_CASE_H | ||
|
||
#include "test_case.h" | ||
|
||
#include <immintrin.h> | ||
#include <stdio.h> | ||
#include <stdlib.h> | ||
|
||
|
||
/* relying on the following instruction... | ||
mm256_broadcast_ss(out,val) | ||
{out_data} = _mm256_broadcast_ss(&{val_data}); | ||
*/ | ||
|
||
/* relying on the following instruction... | ||
mm256_fmadd_ps(dst,src1,src2) | ||
{dst_data} = _mm256_fmadd_ps({src1_data}, {src2_data}, {dst_data}); | ||
*/ | ||
|
||
/* relying on the following instruction... | ||
mm256_loadu_ps(dst,src) | ||
{dst_data} = _mm256_loadu_ps(&{src_data}); | ||
*/ | ||
|
||
/* relying on the following instruction... | ||
mm256_storeu_ps(dst,src) | ||
_mm256_storeu_ps(&{dst_data}, {src_data}); | ||
*/ | ||
/*
 * rank_k_reduce_6x16 -- reference kernel computing C += A * B.
 *
 *   K : size (shared inner dimension)
 *   A : f32[6, K]   @DRAM, read only, indexed as A[i * K + k]
 *   B : f32[K, 16]  @DRAM, read only, indexed as B[k * 16 + j]
 *   C : f32[6, 16]  @DRAM, updated in place, indexed as C[i * 16 + j]
 *
 * ctxt is accepted but never read. Each C element is accumulated directly in
 * memory, in increasing k order.
 */
void rank_k_reduce_6x16( void *ctxt, int_fast32_t K, const float* A, const float* B, float* C ) {
    (void)ctxt;
    for (int_fast32_t row = 0; row < 6; row++) {
        const int_fast32_t a_base = row * K;   /* start of row `row` in A */
        const int_fast32_t c_base = row * 16;  /* start of row `row` in C */
        for (int_fast32_t col = 0; col < 16; col++) {
            for (int_fast32_t depth = 0; depth < K; depth++) {
                C[c_base + col] += A[a_base + depth] * B[depth * 16 + col];
            }
        }
    }
}
|
||
// rank_k_reduce_6x16_scheduled( | ||
// K : size, | ||
// A : f32[6, K] @DRAM, | ||
// B : f32[K, 16] @DRAM, | ||
// C : f32[6, 16] @DRAM | ||
// ) | ||
void rank_k_reduce_6x16_scheduled( void *ctxt, int_fast32_t K, const float* A, const float* B, float* C ) { | ||
__m256 C_reg[6][2]; | ||
for (int_fast32_t i0 = 0; i0 < 6; i0++) { | ||
for (int_fast32_t i2 = 0; i2 < 2; i2++) { | ||
C_reg[i0][i2] = _mm256_loadu_ps(&C[(i0) * (16) + 8 * i2]); | ||
} | ||
} | ||
for (int_fast32_t k = 0; k < K; k++) { | ||
__m256 B_reg[2]; | ||
for (int_fast32_t io = 0; io < 2; io++) { | ||
B_reg[io] = _mm256_loadu_ps(&B[(k) * (16) + 8 * io]); | ||
} | ||
for (int_fast32_t i = 0; i < 6; i++) { | ||
__m256 A_reg; | ||
A_reg = _mm256_broadcast_ss(&A[(i) * K + k]); | ||
for (int_fast32_t jo = 0; jo < 2; jo++) { | ||
C_reg[i][jo] = _mm256_fmadd_ps(A_reg, B_reg[jo], C_reg[i][jo]); | ||
} | ||
} | ||
} | ||
for (int_fast32_t i0 = 0; i0 < 6; i0++) { | ||
for (int_fast32_t i2 = 0; i2 < 2; i2++) { | ||
_mm256_storeu_ps(&C[(i0) * (16) + 8 * i2], C_reg[i0][i2]); | ||
} | ||
} | ||
} | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
|
||
#pragma once | ||
#ifndef TEST_CASE_H | ||
#define TEST_CASE_H | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
|
||
#include <stdint.h> | ||
#include <stdbool.h> | ||
|
||
// Compiler feature macros adapted from Hedley (public domain) | ||
// https://github.com/nemequ/hedley | ||
|
||
// EXO_HAS_BUILTIN(name): forwards to the compiler's __has_builtin probe when
// that operator exists; on compilers without it every probe evaluates to 0.
#ifdef __has_builtin
#  define EXO_HAS_BUILTIN(name) __has_builtin(name)
#else
#  define EXO_HAS_BUILTIN(name) (0)
#endif

// EXO_ASSUME(cond): tells the optimizer that `cond` holds at this point.
//   - with __builtin_assume:      handed straight to the builtin;
//   - with __builtin_unreachable: the false branch is marked unreachable, so
//                                 a false `cond` is undefined behavior;
//   - otherwise:                  `cond` is merely evaluated and discarded.
// The expansion is always a void expression.
#if EXO_HAS_BUILTIN(__builtin_assume)
#  define EXO_ASSUME(cond) __builtin_assume(cond)
#elif EXO_HAS_BUILTIN(__builtin_unreachable)
#  define EXO_ASSUME(cond) \
    ((cond) ? (void)0 : __builtin_unreachable())
#else
#  define EXO_ASSUME(cond) ((void)(cond))
#endif
|
||
|
||
|
||
// gemv( | ||
// M : size, | ||
// N : size, | ||
// A : f32[M, N] @DRAM, | ||
// x : f32[N] @DRAM, | ||
// y : f32[M] @DRAM | ||
// ) | ||
void gemv( void *ctxt, int_fast32_t M, int_fast32_t N, const float* A, const float* x, float* y ); | ||
|
||
|
||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
#endif // TEST_CASE_H | ||
|
||
#include "test_case.h" | ||
|
||
#include <stdio.h> | ||
#include <stdlib.h> | ||
|
||
/*
 * gemv -- matrix-vector product accumulated in place: y += A * x.
 *
 *   M : size (rows of A, length of y)
 *   N : size (columns of A, length of x)
 *   A : f32[M, N] @DRAM, read only, indexed as A[row * N + col]
 *   x : f32[N]    @DRAM, read only
 *   y : f32[M]    @DRAM, updated in place
 *
 * Preconditions: M and N are multiples of 8. Both are stated through
 * EXO_ASSUME, so depending on how that macro expands a violation may be
 * undefined behavior. The iteration space is walked in 8x8 tiles; ctxt is
 * accepted but never read.
 */
void gemv( void *ctxt, int_fast32_t M, int_fast32_t N, const float* A, const float* x, float* y ) {
    EXO_ASSUME(M % 8 == 0);
    EXO_ASSUME(N % 8 == 0);
    (void)ctxt;

    const int_fast32_t row_tiles = M / 8;
    const int_fast32_t col_tiles = N / 8;

    for (int_fast32_t rt = 0; rt < row_tiles; rt++) {
        for (int_fast32_t ct = 0; ct < col_tiles; ct++) {
            for (int_fast32_t r = 0; r < 8; r++) {
                const int_fast32_t row = 8 * rt + r;
                for (int_fast32_t c = 0; c < 8; c++) {
                    const int_fast32_t col = 8 * ct + c;
                    y[row] += A[row * N + col] * x[col];
                }
            }
        }
    }
}
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
|
||
#pragma once | ||
#ifndef TEST_CASE_H | ||
#define TEST_CASE_H | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
|
||
#include <stdint.h> | ||
#include <stdbool.h> | ||
|
||
// Compiler feature macros adapted from Hedley (public domain) | ||
// https://github.com/nemequ/hedley | ||
|
||
// EXO_HAS_BUILTIN(name): forwards to the compiler's __has_builtin probe when
// that operator exists; on compilers without it every probe evaluates to 0.
#ifdef __has_builtin
#  define EXO_HAS_BUILTIN(name) __has_builtin(name)
#else
#  define EXO_HAS_BUILTIN(name) (0)
#endif

// EXO_ASSUME(cond): tells the optimizer that `cond` holds at this point.
//   - with __builtin_assume:      handed straight to the builtin;
//   - with __builtin_unreachable: the false branch is marked unreachable, so
//                                 a false `cond` is undefined behavior;
//   - otherwise:                  `cond` is merely evaluated and discarded.
// The expansion is always a void expression.
#if EXO_HAS_BUILTIN(__builtin_assume)
#  define EXO_ASSUME(cond) __builtin_assume(cond)
#elif EXO_HAS_BUILTIN(__builtin_unreachable)
#  define EXO_ASSUME(cond) \
    ((cond) ? (void)0 : __builtin_unreachable())
#else
#  define EXO_ASSUME(cond) ((void)(cond))
#endif
|
||
|
||
#ifndef EXO_WIN_2I32
#define EXO_WIN_2I32
// Writable view over int32_t data with two stride entries.
// In the code visible in this file, strides[0] is multiplied by 4 before being
// passed to the store instruction, which suggests strides are counted in
// elements -- TODO confirm.
struct exo_win_2i32 {
    int32_t *const data;            // base pointer; the pointer itself is const
    const int_fast32_t strides[2];  // two stride values
};
#endif
#ifndef EXO_WIN_2I32C
#define EXO_WIN_2I32C
// Read-only counterpart of exo_win_2i32: the pointed-to int32_t data is const.
// In the code visible in this file, strides[0] is multiplied by 4 before being
// passed to the load instruction, which suggests strides are counted in
// elements -- TODO confirm.
struct exo_win_2i32c {
    const int32_t *const data;      // base pointer to immutable data
    const int_fast32_t strides[2];  // two stride values
};
#endif
// exo_conv1d_tile_lt_kw( | ||
// data : i32[4, 16] @DRAM, | ||
// kernels : i32[16, 4, 4] @DRAM, | ||
// out : i32[16, 16] @DRAM | ||
// ) | ||
void exo_conv1d_tile_lt_kw( void *ctxt, const int32_t* data, const int32_t* kernels, int32_t* out ); | ||
|
||
|
||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
#endif // TEST_CASE_H | ||
|
||
#include "test_case.h" | ||
|
||
#include <stdio.h> | ||
#include <stdlib.h> | ||
|
||
#include <stdio.h> | ||
#include <stdlib.h> | ||
|
||
|
||
/*
 * exo_conv1d_tile_lt_kw -- 1-D convolution kernel written with tile-register
 * inline assembly (mzero / mld.w / mmasa.w / mst.w).
 *
 *   data    : i32[4, 16]     @DRAM, read only
 *   kernels : i32[16, 4, 4]  @DRAM, read only (addressed here with a row
 *                                   stride of 16 elements)
 *   out     : i32[16, 16]    @DRAM, written
 *
 * For each group of four output columns (jo) the four tiles m7..m4 are
 * zeroed, updated once per c in 0..3, and then stored to rows 0-3, 4-7, 8-11
 * and 12-15 of `out` at column 4 * jo. Per c, a 4x4 scratch tile
 * y[ji][r] = data[c][4 * jo + ji + r] is built, with 0 wherever the column
 * index would reach 16.
 *
 * NOTE(review): the exact arithmetic of mmasa.w (presumably a tile
 * multiply-accumulate into its first operand) is not visible in this file --
 * confirm against the ISA manual.
 *
 * The scratch tile is a function-local static, so this function is not
 * reentrant. ctxt is accepted but never read.
 */
void exo_conv1d_tile_lt_kw( void *ctxt, const int32_t* data, const int32_t* kernels, int32_t* out ) {
    (void)ctxt;

    /* Stride operands for mld.w / mst.w: element stride times 4 (presumably
     * sizeof(int32_t), i.e. a pitch in bytes -- TODO confirm). */
    const int_fast32_t scratch_pitch = 4 * 4;   /* y: 4 elements per row */
    const int_fast32_t wide_pitch = 4 * 16;     /* kernels / out: 16 per row */

    for (int_fast32_t ioo = 0; ioo < 1; ioo++) {
        for (int_fast32_t jo = 0; jo < 4; jo++) {
#define out_tile_0 "m7"
#define out_tile_1 "m6"
#define out_tile_2 "m5"
#define out_tile_3 "m4"
            __asm__ volatile("mzero " out_tile_0);
            __asm__ volatile("mzero " out_tile_1);
            __asm__ volatile("mzero " out_tile_2);
            __asm__ volatile("mzero " out_tile_3);

            for (int_fast32_t c = 0; c < 4; c++) {
                static int32_t y[4 * 4];

                /* Build the 4x4 scratch tile, zero-padding past column 15. */
                for (int_fast32_t ji = 0; ji < 4; ji++) {
                    for (int_fast32_t r = 0; r < 4; r++) {
                        const int_fast32_t src_col = ji + r + 4 * jo;
                        y[ji * 4 + r] = (src_col < 16) ? data[c * 16 + src_col]
                                                       : ((int32_t) 0);
                    }
                }

#define kernel_tile_0 "m3"
#define kernel_tile_1 "m2"
#define kernel_tile_2 "m1"
#define data_tile "m0"
                /*
                 * The loads read memory through a pointer that is passed only
                 * as a register operand. The "memory" clobber makes the
                 * compiler complete the stores into y (and any pending stores
                 * to kernels) before the instruction runs; without it those
                 * stores could legally be sunk or removed.
                 */
                __asm__ volatile("mld.w " data_tile ", (%1), %0"
                                 :: "r"(scratch_pitch), "r"(&y[0])
                                 : "memory");

                __asm__ volatile("mld.w " kernel_tile_0 ", (%1), %0"
                                 :: "r"(wide_pitch), "r"(&kernels[(16 * ioo) * 16 + 4 * c])
                                 : "memory");
                __asm__ volatile("mmasa.w " out_tile_0 ", " data_tile ", " kernel_tile_0);

                __asm__ volatile("mld.w " kernel_tile_1 ", (%1), %0"
                                 :: "r"(wide_pitch), "r"(&kernels[(4 + 16 * ioo) * 16 + 4 * c])
                                 : "memory");
                __asm__ volatile("mmasa.w " out_tile_1 ", " data_tile ", " kernel_tile_1);

                __asm__ volatile("mld.w " kernel_tile_2 ", (%1), %0"
                                 :: "r"(wide_pitch), "r"(&kernels[(8 + 16 * ioo) * 16 + 4 * c])
                                 : "memory");
                __asm__ volatile("mmasa.w " out_tile_2 ", " data_tile ", " kernel_tile_2);

                /* kernel_tile_0 (m3) is reused for the fourth kernel tile. */
                __asm__ volatile("mld.w " kernel_tile_0 ", (%1), %0"
                                 :: "r"(wide_pitch), "r"(&kernels[(12 + 16 * ioo) * 16 + 4 * c])
                                 : "memory");
                __asm__ volatile("mmasa.w " out_tile_3 ", " data_tile ", " kernel_tile_0);
#undef data_tile
#undef kernel_tile_2
#undef kernel_tile_1
#undef kernel_tile_0
            }

            /*
             * The stores write `out` through a register-only pointer operand;
             * the "memory" clobber tells the compiler that memory changed.
             */
            __asm__ volatile("mst.w " out_tile_0 ", (%1), %0"
                             :: "r"(wide_pitch), "r"(&out[(16 * ioo) * 16 + 4 * jo])
                             : "memory");
            __asm__ volatile("mst.w " out_tile_1 ", (%1), %0"
                             :: "r"(wide_pitch), "r"(&out[(4 + 16 * ioo) * 16 + 4 * jo])
                             : "memory");
            __asm__ volatile("mst.w " out_tile_2 ", (%1), %0"
                             :: "r"(wide_pitch), "r"(&out[(8 + 16 * ioo) * 16 + 4 * jo])
                             : "memory");
            __asm__ volatile("mst.w " out_tile_3 ", (%1), %0"
                             :: "r"(wide_pitch), "r"(&out[(12 + 16 * ioo) * 16 + 4 * jo])
                             : "memory");
#undef out_tile_3
#undef out_tile_2
#undef out_tile_1
#undef out_tile_0
        }
    }
}
|
||
|
||
/* relying on the following instruction... | ||
rvm_mld(dst,src) | ||
asm volatile("mld.w "{dst_int}", (%1), %0" :: "r"(4*({src}.strides[0])), "r"(&{src_data})); | ||
*/ | ||
|
||
/* relying on the following instruction... | ||
rvm_mmasa(md,ms1,ms2) | ||
asm volatile("mmasa.w "{md_int}", "{ms1_int}", "{ms2_int}); | ||
*/ | ||
|
||
/* relying on the following instruction... | ||
rvm_mst(src,dst) | ||
asm volatile("mst.w "{src_int}", (%1), %0" :: "r"(4*({dst}.strides[0])), "r"(&{dst_data})); | ||
*/ | ||
|
||
/* relying on the following instruction... | ||
rvm_mzero(dst) | ||
asm volatile("mzero "{dst_int}); | ||
*/ |
Oops, something went wrong.