Skip to content

Commit

Permalink
Add pytest for examples
Browse files Browse the repository at this point in the history
  • Loading branch information
yamaguchi1024 committed Nov 10, 2024
1 parent 74204ba commit d9772a3
Show file tree
Hide file tree
Showing 4 changed files with 403 additions and 0 deletions.
144 changes: 144 additions & 0 deletions tests/golden/test_examples/test_avx2_matmul.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@

#pragma once
#ifndef TEST_CASE_H
#define TEST_CASE_H

#ifdef __cplusplus
extern "C" {
#endif


#include <stdint.h>
#include <stdbool.h>

// Compiler feature macros adapted from Hedley (public domain)
// https://github.com/nemequ/hedley

#if defined(__has_builtin)
# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin)
#else
# define EXO_HAS_BUILTIN(builtin) (0)
#endif

#if EXO_HAS_BUILTIN(__builtin_assume)
# define EXO_ASSUME(expr) __builtin_assume(expr)
#elif EXO_HAS_BUILTIN(__builtin_unreachable)
# define EXO_ASSUME(expr) \
((void)((expr) ? 1 : (__builtin_unreachable(), 1)))
#else
# define EXO_ASSUME(expr) ((void)(expr))
#endif


#ifndef EXO_WIN_1F32
#define EXO_WIN_1F32
struct exo_win_1f32{
float * const data;
const int_fast32_t strides[1];
};
#endif
#ifndef EXO_WIN_1F32C
#define EXO_WIN_1F32C
struct exo_win_1f32c{
const float * const data;
const int_fast32_t strides[1];
};
#endif
// rank_k_reduce_6x16(
// K : size,
// A : f32[6, K] @DRAM,
// B : f32[K, 16] @DRAM,
// C : f32[6, 16] @DRAM
// )
void rank_k_reduce_6x16( void *ctxt, int_fast32_t K, const float* A, const float* B, float* C );

// rank_k_reduce_6x16_scheduled(
// K : size,
// A : f32[6, K] @DRAM,
// B : f32[K, 16] @DRAM,
// C : f32[6, 16] @DRAM
// )
void rank_k_reduce_6x16_scheduled( void *ctxt, int_fast32_t K, const float* A, const float* B, float* C );



#ifdef __cplusplus
}
#endif
#endif // TEST_CASE_H

#include "test_case.h"

#include <immintrin.h>
#include <stdio.h>
#include <stdlib.h>


/* relying on the following instruction..."
mm256_broadcast_ss(out,val)
{out_data} = _mm256_broadcast_ss(&{val_data});
*/

/* relying on the following instruction..."
mm256_fmadd_ps(dst,src1,src2)
{dst_data} = _mm256_fmadd_ps({src1_data}, {src2_data}, {dst_data});
*/

/* relying on the following instruction..."
mm256_loadu_ps(dst,src)
{dst_data} = _mm256_loadu_ps(&{src_data});
*/

/* relying on the following instruction..."
mm256_storeu_ps(dst,src)
_mm256_storeu_ps(&{dst_data}, {src_data});
*/
// rank_k_reduce_6x16(
// K : size,
// A : f32[6, K] @DRAM,
// B : f32[K, 16] @DRAM,
// C : f32[6, 16] @DRAM
// )
void rank_k_reduce_6x16( void *ctxt, int_fast32_t K, const float* A, const float* B, float* C ) {
for (int_fast32_t i = 0; i < 6; i++) {
for (int_fast32_t j = 0; j < 16; j++) {
for (int_fast32_t k = 0; k < K; k++) {
C[i * 16 + j] += A[i * K + k] * B[k * 16 + j];
}
}
}
}

// rank_k_reduce_6x16_scheduled(
// K : size,
// A : f32[6, K] @DRAM,
// B : f32[K, 16] @DRAM,
// C : f32[6, 16] @DRAM
// )
void rank_k_reduce_6x16_scheduled( void *ctxt, int_fast32_t K, const float* A, const float* B, float* C ) {
__m256 C_reg[6][2];
for (int_fast32_t i0 = 0; i0 < 6; i0++) {
for (int_fast32_t i2 = 0; i2 < 2; i2++) {
C_reg[i0][i2] = _mm256_loadu_ps(&C[(i0) * (16) + 8 * i2]);
}
}
for (int_fast32_t k = 0; k < K; k++) {
__m256 B_reg[2];
for (int_fast32_t io = 0; io < 2; io++) {
B_reg[io] = _mm256_loadu_ps(&B[(k) * (16) + 8 * io]);
}
for (int_fast32_t i = 0; i < 6; i++) {
__m256 A_reg;
A_reg = _mm256_broadcast_ss(&A[(i) * K + k]);
for (int_fast32_t jo = 0; jo < 2; jo++) {
C_reg[i][jo] = _mm256_fmadd_ps(A_reg, B_reg[jo], C_reg[i][jo]);
}
}
}
for (int_fast32_t i0 = 0; i0 < 6; i0++) {
for (int_fast32_t i2 = 0; i2 < 2; i2++) {
_mm256_storeu_ps(&C[(i0) * (16) + 8 * i2], C_reg[i0][i2]);
}
}
}

75 changes: 75 additions & 0 deletions tests/golden/test_examples/test_cursors.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@

#pragma once
#ifndef TEST_CASE_H
#define TEST_CASE_H

#ifdef __cplusplus
extern "C" {
#endif


#include <stdint.h>
#include <stdbool.h>

// Compiler feature macros adapted from Hedley (public domain)
// https://github.com/nemequ/hedley

#if defined(__has_builtin)
# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin)
#else
# define EXO_HAS_BUILTIN(builtin) (0)
#endif

#if EXO_HAS_BUILTIN(__builtin_assume)
# define EXO_ASSUME(expr) __builtin_assume(expr)
#elif EXO_HAS_BUILTIN(__builtin_unreachable)
# define EXO_ASSUME(expr) \
((void)((expr) ? 1 : (__builtin_unreachable(), 1)))
#else
# define EXO_ASSUME(expr) ((void)(expr))
#endif



// gemv(
// M : size,
// N : size,
// A : f32[M, N] @DRAM,
// x : f32[N] @DRAM,
// y : f32[M] @DRAM
// )
void gemv( void *ctxt, int_fast32_t M, int_fast32_t N, const float* A, const float* x, float* y );



#ifdef __cplusplus
}
#endif
#endif // TEST_CASE_H

#include "test_case.h"

#include <stdio.h>
#include <stdlib.h>

// gemv(
// M : size,
// N : size,
// A : f32[M, N] @DRAM,
// x : f32[N] @DRAM,
// y : f32[M] @DRAM
// )
void gemv( void *ctxt, int_fast32_t M, int_fast32_t N, const float* A, const float* x, float* y ) {
EXO_ASSUME(M % 8 == 0);
EXO_ASSUME(N % 8 == 0);
for (int_fast32_t io = 0; io < ((M) / (8)); io++) {
for (int_fast32_t jo = 0; jo < ((N) / (8)); jo++) {
for (int_fast32_t ii = 0; ii < 8; ii++) {
for (int_fast32_t ji = 0; ji < 8; ji++) {
y[8 * io + ii] += A[(8 * io + ii) * N + 8 * jo + ji] * x[8 * jo + ji];
}
}
}
}
}

146 changes: 146 additions & 0 deletions tests/golden/test_examples/test_rvm_conv1d.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@

#pragma once
#ifndef TEST_CASE_H
#define TEST_CASE_H

#ifdef __cplusplus
extern "C" {
#endif


#include <stdint.h>
#include <stdbool.h>

// Compiler feature macros adapted from Hedley (public domain)
// https://github.com/nemequ/hedley

#if defined(__has_builtin)
# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin)
#else
# define EXO_HAS_BUILTIN(builtin) (0)
#endif

#if EXO_HAS_BUILTIN(__builtin_assume)
# define EXO_ASSUME(expr) __builtin_assume(expr)
#elif EXO_HAS_BUILTIN(__builtin_unreachable)
# define EXO_ASSUME(expr) \
((void)((expr) ? 1 : (__builtin_unreachable(), 1)))
#else
# define EXO_ASSUME(expr) ((void)(expr))
#endif


#ifndef EXO_WIN_2I32
#define EXO_WIN_2I32
struct exo_win_2i32{
int32_t * const data;
const int_fast32_t strides[2];
};
#endif
#ifndef EXO_WIN_2I32C
#define EXO_WIN_2I32C
struct exo_win_2i32c{
const int32_t * const data;
const int_fast32_t strides[2];
};
#endif
// exo_conv1d_tile_lt_kw(
// data : i32[4, 16] @DRAM,
// kernels : i32[16, 4, 4] @DRAM,
// out : i32[16, 16] @DRAM
// )
void exo_conv1d_tile_lt_kw( void *ctxt, const int32_t* data, const int32_t* kernels, int32_t* out );



#ifdef __cplusplus
}
#endif
#endif // TEST_CASE_H

#include "test_case.h"

#include <stdio.h>
#include <stdlib.h>

#include <stdio.h>
#include <stdlib.h>


// exo_conv1d_tile_lt_kw(
// data : i32[4, 16] @DRAM,
// kernels : i32[16, 4, 4] @DRAM,
// out : i32[16, 16] @DRAM
// )
void exo_conv1d_tile_lt_kw( void *ctxt, const int32_t* data, const int32_t* kernels, int32_t* out ) {
for (int_fast32_t ioo = 0; ioo < 1; ioo++) {
for (int_fast32_t jo = 0; jo < 4; jo++) {
#define out_tile_0 "m7"
#define out_tile_1 "m6"
#define out_tile_2 "m5"
#define out_tile_3 "m4"
asm volatile("mzero "out_tile_0);
asm volatile("mzero "out_tile_1);
asm volatile("mzero "out_tile_2);
asm volatile("mzero "out_tile_3);
for (int_fast32_t c = 0; c < 4; c++) {
static int32_t y[4 * 4];
for (int_fast32_t ji = 0; ji < 4; ji++) {
for (int_fast32_t r = 0; r < 4; r++) {
if (ji + r + 4 * jo < 16) {
y[ji * 4 + r] = data[c * 16 + ji + r + 4 * jo];
} else {
y[ji * 4 + r] = ((int32_t) 0);
}
}
}
#define kernel_tile_0 "m3"
#define kernel_tile_1 "m2"
#define kernel_tile_2 "m1"
#define data_tile "m0"
asm volatile("mld.w "data_tile", (%1), %0" :: "r"(4*(((struct exo_win_2i32c){ &y[0], { 4, 1 } }).strides[0])), "r"(&y[0]));
asm volatile("mld.w "kernel_tile_0", (%1), %0" :: "r"(4*(((struct exo_win_2i32c){ &kernels[(16 * ioo) * (16) + (c) * 4], { 16, 1 } }).strides[0])), "r"(&kernels[(16 * ioo) * (16) + (c) * 4]));
asm volatile("mmasa.w "out_tile_0", "data_tile", "kernel_tile_0);
asm volatile("mld.w "kernel_tile_1", (%1), %0" :: "r"(4*(((struct exo_win_2i32c){ &kernels[(4 + 16 * ioo) * (16) + (c) * 4], { 16, 1 } }).strides[0])), "r"(&kernels[(4 + 16 * ioo) * (16) + (c) * 4]));
asm volatile("mmasa.w "out_tile_1", "data_tile", "kernel_tile_1);
#undef kernel_tile_1
asm volatile("mld.w "kernel_tile_2", (%1), %0" :: "r"(4*(((struct exo_win_2i32c){ &kernels[(8 + 16 * ioo) * (16) + (c) * 4], { 16, 1 } }).strides[0])), "r"(&kernels[(8 + 16 * ioo) * (16) + (c) * 4]));
asm volatile("mmasa.w "out_tile_2", "data_tile", "kernel_tile_2);
#undef kernel_tile_2
asm volatile("mld.w "kernel_tile_0", (%1), %0" :: "r"(4*(((struct exo_win_2i32c){ &kernels[(12 + 16 * ioo) * (16) + (c) * 4], { 16, 1 } }).strides[0])), "r"(&kernels[(12 + 16 * ioo) * (16) + (c) * 4]));
asm volatile("mmasa.w "out_tile_3", "data_tile", "kernel_tile_0);
#undef data_tile
#undef kernel_tile_0
}
asm volatile("mst.w "out_tile_0", (%1), %0" :: "r"(4*(((struct exo_win_2i32){ &out[(16 * ioo) * (16) + 4 * jo], { 16, 1 } }).strides[0])), "r"(&out[(16 * ioo) * (16) + 4 * jo]));
#undef out_tile_0
asm volatile("mst.w "out_tile_1", (%1), %0" :: "r"(4*(((struct exo_win_2i32){ &out[(4 + 16 * ioo) * (16) + 4 * jo], { 16, 1 } }).strides[0])), "r"(&out[(4 + 16 * ioo) * (16) + 4 * jo]));
#undef out_tile_1
asm volatile("mst.w "out_tile_2", (%1), %0" :: "r"(4*(((struct exo_win_2i32){ &out[(8 + 16 * ioo) * (16) + 4 * jo], { 16, 1 } }).strides[0])), "r"(&out[(8 + 16 * ioo) * (16) + 4 * jo]));
#undef out_tile_2
asm volatile("mst.w "out_tile_3", (%1), %0" :: "r"(4*(((struct exo_win_2i32){ &out[(12 + 16 * ioo) * (16) + 4 * jo], { 16, 1 } }).strides[0])), "r"(&out[(12 + 16 * ioo) * (16) + 4 * jo]));
#undef out_tile_3
}
}
}


/* relying on the following instruction..."
rvm_mld(dst,src)
asm volatile("mld.w "{dst_int}", (%1), %0" :: "r"(4*({src}.strides[0])), "r"(&{src_data}));
*/

/* relying on the following instruction..."
rvm_mmasa(md,ms1,ms2)
asm volatile("mmasa.w "{md_int}", "{ms1_int}", "{ms2_int});
*/

/* relying on the following instruction..."
rvm_mst(src,dst)
asm volatile("mst.w "{src_int}", (%1), %0" :: "r"(4*({dst}.strides[0])), "r"(&{dst_data}));
*/

/* relying on the following instruction..."
rvm_mzero(dst)
asm volatile("mzero "{dst_int});
*/
Loading

0 comments on commit d9772a3

Please sign in to comment.