Dynamic Dispatch for Kernels + Support MKL-based kernels w/ Fallback #122

Draft · wants to merge 3 commits into base: main
37 changes: 37 additions & 0 deletions kernels/avx/matmul_avx.h
@@ -0,0 +1,37 @@
#ifndef MATMUL_OPERATOR_AVX_H
#define MATMUL_OPERATOR_AVX_H

#include "matmul.h"
#include <iostream>

namespace matmul {

class MatmulOperatorAVX : public MatmulOperator {
public:
void mat_mul_accelerator_transposed_fastover_column(const struct matmul_params* params) override;
void mat_mul_accelerator_transposed_fastover_column_bias(const struct matmul_params* params) override;

// int8 operations
void mat_mul_accelerator_int8_fast_32unroll_over_column(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_batch(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32_batch(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32_over_column(const struct matmul_params* params) override;

void mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_params* params) override;

void mat_mul_accelerator_int4_fast(const struct matmul_params* params) override;
void mat_mul_accelerator_int4_fast_no_offset(const struct matmul_params* params) override;
};

// Meyers singleton: returns the single process-wide AVX backend instance.
inline MatmulOperator& CreateMatmulOperatorAVX() {
static MatmulOperatorAVX instance;
return instance;
}

} // namespace matmul

#endif
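
Each backend header introduced by this PR exposes a singleton accessor like CreateMatmulOperatorAVX() above. As a minimal sketch of where this is headed per the PR title, the accessors could feed a single dispatch point with the MKL-based kernels slotted in as one option; CreateMatmulOperator(), the QM_* macros, matmul_mkl.h, and CreateMatmulOperatorMKL() below are illustrative assumptions, not part of this diff:

// Sketch only: a possible single dispatch point built on this PR's
// per-backend accessors. QM_CUDA/QM_MKL, matmul_mkl.h, and
// CreateMatmulOperatorMKL() are hypothetical names, not in this diff.
#if defined(QM_CUDA)
#include "matmul_cuda.h"
#elif defined(QM_MKL)
#include "matmul_mkl.h"  // hypothetical MKL backend header
#else
#include "matmul_avx.h"
#endif

namespace matmul {

// Returns the best available backend for this build configuration.
inline MatmulOperator& CreateMatmulOperator() {
#if defined(QM_CUDA)
    return CreateMatmulOperatorCUDA();
#elif defined(QM_MKL)
    return CreateMatmulOperatorMKL();  // hypothetical MKL-based kernels
#else
    return CreateMatmulOperatorAVX();  // portable CPU fallback
#endif
}

}  // namespace matmul

A call site can then stay backend-agnostic, e.g. matmul::CreateMatmulOperator().mat_mul_accelerator_transposed_fastover_column(params);
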
7 changes: 4 additions & 3 deletions kernels/avx/matmul_avx_fp32.cc
@@ -4,8 +4,9 @@
#include <pthread.h>
#include <stdio.h>
#include <xmmintrin.h> // intel SSE intrinsic
+#include <iostream>

-#include "../matmul.h"
+#include "matmul_avx.h"

namespace matmul {

@@ -60,7 +61,7 @@ void *mat_mul_transposed_fastover_column_func(void *args) {
return NULL;
}

-void MatmulOperator::mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params) {
int i, j, k;

int num_thread = params->opt_params.num_thread;
@@ -112,7 +113,7 @@ void fp32_ref_matmul_bias(const struct matmul_params *params) {
}
}

-void MatmulOperator::mat_mul_accelerator_transposed_fastover_column_bias(const struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_transposed_fastover_column_bias(const struct matmul_params *params) {
fp32_ref_matmul_bias(params);
}

6 changes: 3 additions & 3 deletions kernels/avx/matmul_avx_int4.cc
@@ -3,7 +3,7 @@

#include <cassert>

#include "../matmul.h"
#include "matmul_avx.h"

static inline __m256i bytes_from_nibbles_32(const uint8_t *rsi) {
// Load 16 bytes from memory
@@ -675,7 +675,7 @@ static void *fast_zp_no_offset_over_column_func_v5(void *args) {
}

namespace matmul {
-void MatmulOperator::mat_mul_accelerator_int4_fast(const struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_int4_fast(const struct matmul_params *params) {
const int num_thread = params->opt_params.num_thread;
int i, j, k;
pthread_t thread_pool[num_thread];
@@ -693,7 +693,7 @@ void MatmulOperator::mat_mul_accelerator_int4_fast(const struct matmul_params *p
for (j = 0; j < num_thread; j++) pthread_join(thread_pool[j], NULL);
};

-void MatmulOperator::mat_mul_accelerator_int4_fast_no_offset(const struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_int4_fast_no_offset(const struct matmul_params *params) {
const int num_thread = params->opt_params.num_thread;
int i, j, k;
pthread_t thread_pool[num_thread];
22 changes: 12 additions & 10 deletions kernels/avx/matmul_avx_int8.cc
@@ -5,7 +5,7 @@
#include <cstdlib>
#include <iostream>

#include "../matmul.h"
#include "matmul_avx.h"

inline void assign_8int32(int *ptr, int &acc) {
acc = (ptr[0] + ptr[1] + ptr[2] + ptr[3] + ptr[4] + ptr[5] + ptr[6] + ptr[7]);
@@ -381,7 +381,7 @@ void *mat_mul_accelerator_int8_thread_func_2x2_32unroll(void *args) {
return NULL;
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll(const struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_int8_fast_2x2_32unroll(const struct matmul_params *params) {
int j, num_thread = params->opt_params.num_thread;

assert(params->A.column % 64 == 0);
@@ -478,7 +478,7 @@ void *mat_mul_accelerator_int8_fast_32unroll_over_column_thread_func(void *args)
return NULL;
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_32unroll_over_column(const struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_int8_fast_32unroll_over_column(const struct matmul_params *params) {
int j, num_thread = params->opt_params.num_thread;

if (num_thread > params->C.column) num_thread = params->C.column;
@@ -610,7 +610,7 @@ void *mat_mul_accelerator_int8_thread_func_2x2_32unroll_nobias(void *args) {
return NULL;
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias(const struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias(const struct matmul_params *params) {
int j, num_thread = params->opt_params.num_thread;

assert((params->C.column) % 2 == 0);
@@ -681,7 +681,7 @@ void *mat_mul_accelerator_int8_thread_func_2x2_32unroll_nobias_batch(void *args)
return NULL;
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_batch(const struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_batch(const struct matmul_params *params) {
int j, num_thread = params->opt_params.num_thread;

assert((params->C.column) % 2 == 0);
@@ -791,7 +791,7 @@ void *mat_mul_accelerator_int8_thread_func_2x2_32unroll_nobias_ofp32(void *args)
return NULL;
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32(const struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32(const struct matmul_params *params) {
int j, num_thread = params->opt_params.num_thread;

assert(params->A.column % 32 == 0);
@@ -851,7 +851,7 @@ void *mat_mul_accelerator_int8_thread_func_2x2_32unroll_nobias_ofp32_batch(void
return NULL;
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32_batch(const struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32_batch(const struct matmul_params *params) {
int j, num_thread = params->opt_params.num_thread;

assert(params->A.column % 32 == 0);
@@ -940,7 +940,7 @@ void *mat_mul_accelerator_int8_thread_func_2x2_32unroll_bfp32_ofp32(void *args)
return NULL;
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32(const struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32(const struct matmul_params *params) {
int j, num_thread = params->opt_params.num_thread;

assert(params->A.column % 64 == 0);
@@ -1211,8 +1211,9 @@ void *mat_mul_accelerator_int8_thread_func_2x2_32unroll_bfp32_ofp32_over_column(
return NULL;
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32_over_column(
+void MatmulOperatorAVX::mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32_over_column(
    const struct matmul_params *params) {
+
int j, num_thread = params->opt_params.num_thread;

if (num_thread > params->C.column) num_thread = params->C.column;
@@ -1241,4 +1242,5 @@ void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32_over
}
}

-} // namespace matmul
+
+}
4 changes: 2 additions & 2 deletions kernels/avx/matmul_avx_int8_int4.cc
@@ -4,7 +4,7 @@
#include <cassert>
#include <cmath>

#include "../matmul.h"
#include "matmul_avx.h"

#include "pthread_pool.h"

@@ -322,7 +322,7 @@ static void quantize_fp_to_int8_block_size32(float *x, int size, int8_t *qx, flo

namespace matmul {

-void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_params *params) {
+void MatmulOperatorAVX::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_params *params) {
// const int num_thread = 4;
const int num_thread = params->opt_params.num_thread;
int i, j, k;
8 changes: 4 additions & 4 deletions kernels/cuda/gemv_cuda.cu
@@ -17,7 +17,7 @@
#include <iostream>
#include <stdio.h>

#include "../matmul.h"
#include "matmul_cuda.h"
#include "ops/linear.h"

// #include <cuda_runtime.h>
@@ -210,7 +210,7 @@ namespace matmul{
Returns:
out_feats: tensor of shape [B, OC];
*/
-void MatmulOperator::gemv_forward_cuda(const struct matmul_params *params)
+void MatmulOperatorCUDA::gemv_forward_cuda(const struct matmul_params *params)
{
const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;

@@ -259,11 +259,11 @@ namespace matmul{
PROFILE_END("gemv_forward_cuda");
}

-void MatmulOperator::mat_mul_accelerator_int4_fast(const struct matmul_params *params) {
+void MatmulOperatorCUDA::mat_mul_accelerator_int4_fast(const struct matmul_params *params) {
// TODO: remove this
};

-void MatmulOperator::mat_mul_accelerator_int4_fast_no_offset(const struct matmul_params *params) {
+void MatmulOperatorCUDA::mat_mul_accelerator_int4_fast_no_offset(const struct matmul_params *params) {
// TODO: remove this
};

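
The two empty int4 bodies above (flagged // TODO: remove this) exist only so MatmulOperatorCUDA satisfies the pure-virtual interface it inherits. A minimal sketch of one alternative, assuming the base class were allowed to supply non-pure defaults; this is hypothetical and not what the diff does:

// Sketch only: non-pure virtuals with a fallback body in the base class
// would let a backend omit kernels it never uses. Hypothetical; this PR
// keeps the interface pure and stubs the unused CUDA kernels instead.
struct matmul_params;  // defined in matmul.h

class MatmulOperatorBase {
   public:
    virtual ~MatmulOperatorBase() = default;
    // Fallback body instead of `= 0`, so a derived backend may skip the override.
    virtual void mat_mul_accelerator_int4_fast(const struct matmul_params* params) {
        (void)params;  // e.g. route to a reference kernel or report "unsupported"
    }
    // ... remaining kernel entry points ...
};
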
38 changes: 38 additions & 0 deletions kernels/cuda/matmul_cuda.h
@@ -0,0 +1,38 @@
#ifndef MATMUL_OPERATOR_CUDA_H
#define MATMUL_OPERATOR_CUDA_H

#include "matmul.h"
#include <iostream>

namespace matmul {

class MatmulOperatorCUDA : public MatmulOperator {
public:
void mat_mul_accelerator_transposed_fastover_column(const struct matmul_params* params) override;

// int8 operations
void mat_mul_accelerator_int8_fast_32unroll_over_column(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_batch(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32_batch(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32(const struct matmul_params* params) override;
void mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32_over_column(const struct matmul_params* params) override;

void mat_mul_accelerator_int4_fast(const struct matmul_params* params) override;
void mat_mul_accelerator_int4_fast_no_offset(const struct matmul_params* params) override;

void gemv_forward_cuda(const struct matmul_params* params) override;
void naive_mat_mul_fp16_int4(const struct matmul_params* params) override;
};

// Declared static inline so every translation unit that includes this header
// (both .cc and .cu files) gets its own internal-linkage copy, avoiding
// duplicate-symbol linker errors; note that each such TU then also holds its
// own function-local instance.
static inline MatmulOperator& CreateMatmulOperatorCUDA() {
static MatmulOperatorCUDA instance;
return instance;
}

} // namespace matmul

#endif
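
Since both accessors return the same MatmulOperator& interface, a caller can also pick a backend at runtime and fall back to the CPU path when no device is present. A minimal sketch; SelectMatmulOperator() and its boolean probe are illustrative, not part of this diff:

// Sketch only: runtime backend selection with a CPU fallback through the
// common MatmulOperator interface. SelectMatmulOperator() is hypothetical.
#include "matmul_avx.h"
#include "matmul_cuda.h"

namespace matmul {

inline MatmulOperator& SelectMatmulOperator(bool has_cuda_device) {
    // Virtual dispatch makes the backends interchangeable at the call site.
    return has_cuda_device ? CreateMatmulOperatorCUDA()
                           : CreateMatmulOperatorAVX();
}

}  // namespace matmul

Call sites then hold only a matmul::MatmulOperator& and every kernel entry point above resolves through the vtable.
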
4 changes: 2 additions & 2 deletions kernels/cuda/matmul_int4.cu
@@ -1,11 +1,11 @@
#include <cstdlib>
#include <iostream>

#include "../matmul.h"
#include "matmul_cuda.h"

namespace matmul {

-void MatmulOperator::naive_mat_mul_fp16_int4(const struct matmul_params *params) {
+void MatmulOperatorCUDA::naive_mat_mul_fp16_int4(const struct matmul_params *params) {
const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
const int block_size = params->block_size;
// CHECK_MATRICES_int4weight(A, B, C);
5 changes: 3 additions & 2 deletions kernels/cuda/matmul_ref_fp32.cc
@@ -4,7 +4,7 @@
#include <cmath>
#include <cstdlib>

#include "../matmul.h"
#include "matmul_cuda.h"

namespace matmul {
void fp32_ref_matmul(const struct matmul_params *params) {
@@ -28,7 +28,8 @@ void fp32_ref_matmul(const struct matmul_params *params) {
}
}

-void MatmulOperator::mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params) {
+void MatmulOperatorCUDA::mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params) {
+    std::cout << "mat_mul_accelerator_transposed_fastover_column, fp32" << std::endl;
fp32_ref_matmul(params);
}

18 changes: 9 additions & 9 deletions kernels/cuda/matmul_ref_int8.cc
@@ -4,7 +4,7 @@
#include <cmath>
#include <cstdlib>

#include "../matmul.h"
#include "matmul_cuda.h"

namespace matmul {
void int8_ref_matmul(const struct matmul_params *params) {
@@ -157,35 +157,35 @@ void int8_ref_matmul_nobias_ofp32_batch(const struct matmul_params *params) {
}
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll(const struct matmul_params *params) {
+void MatmulOperatorCUDA::mat_mul_accelerator_int8_fast_2x2_32unroll(const struct matmul_params *params) {
int8_ref_matmul(params);
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias(const struct matmul_params *params) {
+void MatmulOperatorCUDA::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias(const struct matmul_params *params) {
int8_ref_matmul_nobias(params);
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_batch(const struct matmul_params *params) {
+void MatmulOperatorCUDA::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_batch(const struct matmul_params *params) {
int8_ref_matmul_nobias_batch(params);
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_32unroll_over_column(const struct matmul_params *params) {
+void MatmulOperatorCUDA::mat_mul_accelerator_int8_fast_32unroll_over_column(const struct matmul_params *params) {
int8_ref_matmul(params);
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32(const struct matmul_params *params) {
+void MatmulOperatorCUDA::mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32(const struct matmul_params *params) {
int8_ref_matmul_bfp32_ofp32(params);
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32(const struct matmul_params *params) {
+void MatmulOperatorCUDA::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32(const struct matmul_params *params) {
int8_ref_matmul_nobias_ofp32(params);
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32_batch(const struct matmul_params *params) {
+void MatmulOperatorCUDA::mat_mul_accelerator_int8_fast_2x2_32unroll_nobias_ofp32_batch(const struct matmul_params *params) {
int8_ref_matmul_nobias_ofp32_batch(params);
}

-void MatmulOperator::mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32_over_column(
+void MatmulOperatorCUDA::mat_mul_accelerator_int8_fast_2x2_32unroll_bfp32_ofp32_over_column(
const struct matmul_params *params) {
int8_ref_matmul_bfp32_ofp32(params);
}