diff --git a/build/bli_config.h.in b/build/bli_config.h.in
index fa6bbbe12e..dce87a3fa6 100644
--- a/build/bli_config.h.in
+++ b/build/bli_config.h.in
@@ -183,5 +183,10 @@
 #define BLIS_DISABLE_COMPLEX_RETURN_INTEL
 #endif

+#if @enable_amd_offload@
+#define BLIS_ENABLE_AMD_OFFLOAD
+#else
+#define BLIS_DISABLE_AMD_OFFLOAD
+#endif

 #endif
diff --git a/build/config.mk.in b/build/config.mk.in
index 7ef8c6bd00..92138003ac 100644
--- a/build/config.mk.in
+++ b/build/config.mk.in
@@ -208,5 +208,8 @@ LIBPTHREAD := @libpthread@
 # Whether we should use AMD-customized versions of certain framework files.
 ENABLE_AMD_FRAME_TWEAKS := @enable_amd_frame_tweaks@

+# Whether offloading to AMD accelerators should be attempted
+ENABLE_AMD_OFFLOAD := @enable_amd_offload@
+
 # end of ifndef CONFIG_MK_INCLUDED conditional block
 endif
diff --git a/common.mk b/common.mk
index 6661f84c51..ae784e92dc 100644
--- a/common.mk
+++ b/common.mk
@@ -552,6 +552,11 @@ ifeq ($(DEBUG_TYPE),sde)
 LDFLAGS := $(filter-out $(LIBMEMKIND),$(LDFLAGS))
 endif

+ifeq ($(ENABLE_AMD_OFFLOAD),yes)
+LDFLAGS += -ldl
+LDFLAGS += -L/opt/rocm/lib -lamdhip64 -lrocblas
+endif
+
 # Specify the shared library's 'soname' field.
 # NOTE: The flag for creating shared objects is different for Linux and OS X.
 ifeq ($(OS_NAME),Darwin)
@@ -1134,6 +1139,11 @@ ifeq ($(MK_ENABLE_CBLAS),yes)
 CINCFLAGS += -I$(CBLAS_H_DIRPATH)
 endif

+# If AMD offloading is enabled, we also add the ROCm include directory
+ifeq ($(ENABLE_AMD_OFFLOAD),yes)
+CINCFLAGS += -I/opt/rocm/include -D__HIP_PLATFORM_AMD__=1
+endif
+
 # Obtain a list of header paths in the configured addons. Then add -I to each
 # header path.
 CADDONINCFLAGS := $(strip $(patsubst %, -I%, $(ADDON_HDR_DIRPATHS)))
diff --git a/configure b/configure
index a6018edab2..2a27aa7668 100755
--- a/configure
+++ b/configure
@@ -285,6 +285,15 @@ print_usage()
 	echo " which are determined by the BLIS subconfiguration used at"
 	echo " runtime.) By default, these customized files are disabled."
 	echo " "
+	echo " --enable-amd-offload, --disable-amd-offload"
+	echo " "
+	echo " Enable conditional offloading of some Level-3 BLAS calls"
+	echo " to AMD accelerators such as MI100, MI200."
+	echo " Enabling this option requires ROCm to be installed and"
+	echo " uses rocBLAS as a backend."
+	echo " Introduces rocblas-dev and hip-dev as dependencies."
+	echo " By default, the offloading paths are disabled."
+	echo " "
 	echo " -a NAME --enable-addon=NAME"
 	echo " "
 	echo " Enable the code provided by an addon. An addon consists"
@@ -2469,6 +2478,7 @@ main()
 	enable_mixed_dt_extra_mem='yes'
 	enable_sup_handling='yes'
 	enable_amd_frame_tweaks='no'
+	enable_amd_offload='no'
 	enable_memkind='' # The default memkind value is determined later on.
 	enable_trsm_preinversion='yes'
 	force_version='no'
@@ -2687,6 +2697,12 @@ main()
 					disable-amd-frame-tweaks)
 						enable_amd_frame_tweaks='no'
 						;;
+					enable-amd-offload)
+						enable_amd_offload='yes'
+						;;
+					disable-amd-offload)
+						enable_amd_offload='no'
+						;;
 					with-memkind)
 						enable_memkind='yes'
 						;;
@@ -3616,6 +3632,29 @@ main()
 		echo "${script_name}: AMD-specific framework files will not be considered."
 	fi

+	# Check whether anything should be offloaded to AMD accelerators
+	enable_amd_offload_01=0
+	if [ "x${enable_amd_offload}" = "xyes" ]; then
+		echo "${script_name}: Offloading to AMD accelerators will be considered."
+		echo "${script_name}: checking for ROCm installation and availability."
+
+		# Make sure there's a ROCm installation present
+		# use rocm_agent_enumerator to see if there's a gfx != gfx000
+		gfxs=`rocm_agent_enumerator 2> /dev/null`
+		if [ -z "$gfxs" ]; then
+			echo "${script_name}: rocm_agent_enumerator returns no agents."
+			enable_amd_offload='no'
+		elif [[ "$gfxs" =~ "gfx9" ]] || [[ "$gfxs" =~ "gfx10" ]]; then
+			echo "${script_name}: found AMD accelerator(s)."
+			enable_amd_offload_01=1
+		else
+			echo "${script_name}: Illegal rocm_agent_enumerator output. $gfxs"
+			enable_amd_offload='no'
+		fi
+	else
+		echo "${script_name}: Offloading to AMD accelerators will not be considered."
+	fi
+
 	# Check if addons were given.
 	if [ -n "${addon_flag}" ]; then

@@ -3871,6 +3910,7 @@ main()
 		| sed -e "s/@enable_blas@/${enable_blas}/g" \
 		| sed -e "s/@enable_cblas@/${enable_cblas}/g" \
 		| sed -e "s/@enable_amd_frame_tweaks@/${enable_amd_frame_tweaks}/g" \
+		| sed -e "s/@enable_amd_offload@/${enable_amd_offload}/g" \
 		| sed -e "s/@enable_memkind@/${enable_memkind}/g" \
 		| sed -e "s/@pragma_omp_simd@/${pragma_omp_simd}/g" \
 		| sed -e "s/@addon_list@/${addon_list}/g" \
@@ -3910,6 +3950,7 @@ main()
 		| sed -e "s/@enable_sandbox@/${enable_sandbox_01}/g" \
 		| sed -e "s/@enable_shared@/${enable_shared_01}/g" \
 		| sed -e "s/@complex_return_intel@/${complex_return_intel01}/g" \
+		| sed -e "s/@enable_amd_offload@/${enable_amd_offload_01}/g" \
 		> "${bli_config_h_out_path}"

 	# -- Instantiate bli_addon.h file from template ----------------------------
diff --git a/frame/3/bli_l3_oapi_ex.c b/frame/3/bli_l3_oapi_ex.c
index 20b0294eb0..fdb8e828a5 100644
--- a/frame/3/bli_l3_oapi_ex.c
+++ b/frame/3/bli_l3_oapi_ex.c
@@ -33,6 +33,9 @@
 */

 #include "blis.h"
+#ifdef BLIS_ENABLE_AMD_OFFLOAD
+#include "../base/bli_offloader.h"
+#endif

 //
 // Define object-based interfaces (expert).
@@ -73,6 +76,21 @@ void PASTEMAC(gemm,BLIS_OAPI_EX_SUF)
 			return;
 		}
 	}
+#ifdef BLIS_ENABLE_AMD_OFFLOAD
+	// check if we should offload - since attempting to offload and fail
+	// incurs a non-trivial cost, we only want to fail and fall through
+	// in rare cases
+	const bool do_offload = bli_do_offload_gemmex( alpha, a, b, beta, c);
+	if ( do_offload )
+	{
+		// attempts to offload
+		const err_t result = bli_offload_gemmex( alpha, a, b, beta, c);
+		if ( result == BLIS_SUCCESS )
+		{
+			return;
+		}
+	}
+#endif

 	// Initialize a local runtime with global settings if necessary. Note
 	// that in the case that a runtime is passed in, we make a local copy.
diff --git a/frame/base/bli_init.c b/frame/base/bli_init.c
index f1baa2c217..3be7f87df2 100644
--- a/frame/base/bli_init.c
+++ b/frame/base/bli_init.c
@@ -34,6 +34,9 @@
 */

 #include "blis.h"
+#ifdef BLIS_ENABLE_AMD_OFFLOAD
+#include "bli_offloader.h"
+#endif

 // -----------------------------------------------------------------------------

@@ -87,7 +90,11 @@ int bli_init_apis( void )
 	bli_pack_init();
 	bli_memsys_init();

-	return 0;
+#ifdef BLIS_ENABLE_AMD_OFFLOAD
+    bli_offloader_init();
+#endif
+
+    return 0;
 }

 int bli_finalize_apis( void )
@@ -99,6 +106,10 @@ int bli_finalize_apis( void )
 	bli_ind_finalize();
 	bli_gks_finalize();

+#ifdef BLIS_ENABLE_AMD_OFFLOAD
+	bli_offloader_finalize();
+#endif
+
 	return 0;
 }

diff --git a/frame/base/bli_offloader.c b/frame/base/bli_offloader.c
new file mode 100644
index 0000000000..86f61cfc0a
--- /dev/null
+++ b/frame/base/bli_offloader.c
@@ -0,0 +1,602 @@
+#include "blis.h"
+#ifdef BLIS_ENABLE_AMD_OFFLOAD
+#include "bli_offloader.h"
+#include <limits.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <hip/hip_runtime_api.h>
+#include <rocblas/rocblas.h>
+
+// The global rntm_t structure. (The definition resides in bli_rntm.c.)
+extern rntm_t global_rntm;
+
+// A mutex to allow synchronous access to global_rntm. (The definition
+// resides in bli_rntm.c.)
+extern bli_pthread_mutex_t global_rntm_mutex;
+
+// Mark every gemm flavor as "never offload" and park all thresholds at
+// LLONG_MAX so that no problem size can qualify.
+static void bli_offloader_disable_all ( offload_t* config )
+{
+    config->never_offload_dgemm = true;
+    config->never_offload_sgemm = true;
+    config->never_offload_zgemm = true;
+    config->never_offload_cgemm = true;
+    config->offload_sgemm_thresh = LLONG_MAX;
+    config->offload_dgemm_thresh = LLONG_MAX;
+    config->offload_cgemm_thresh = LLONG_MAX;
+    config->offload_zgemm_thresh = LLONG_MAX;
+}
+
+// Read one per-datatype threshold variable. An unset variable, or one that
+// does not start with a number, disables offloading for that gemm flavor.
+static int64_t bli_offloader_thresh_from_env
+     (
+       const char* env_name,
+       const char* gemm_name,
+       bool*       never_offload
+     )
+{
+    const char* s_thresh = getenv ( env_name );
+    long long   thresh   = LLONG_MAX;
+
+    if ( s_thresh != NULL )
+    {
+        char* end = NULL;
+        thresh = strtoll ( s_thresh, &end, 10 );
+        if ( end == s_thresh )
+        {
+            fprintf ( stderr, "Invalid value for %s: %s\n", env_name, s_thresh );
+            thresh = LLONG_MAX;
+        }
+    }
+
+    if ( thresh == LLONG_MAX )
+    {
+        fprintf ( stdout, "Never offloading %ss.\n", gemm_name );
+        *never_offload = true;
+    }
+    else
+    {
+        fprintf ( stdout, "Offloading all %ss with at least M*N >= %lld\n", gemm_name, thresh );
+        *never_offload = false;
+    }
+
+    return thresh;
+}
+
+// Set up the offloader state of the global runtime.
+void bli_offloader_init ( void )
+{
+    bli_offloader_init_rntm_from_env ( &global_rntm );
+}
+
+// Populate rntm->offloader_state from the environment:
+//   BLIS_OFFLOAD=never      nothing is offloaded (also the default)
+//   BLIS_OFFLOAD=always     every eligible gemm is offloaded
+//   BLIS_OFFLOAD=threshold  per-datatype cutoffs on M*N of C are read from
+//                           BLIS_OFFLOAD_{S,D,C,Z}GEMM_THRESH
+// A rocBLAS handle is only created for "always" and "threshold".
+void bli_offloader_init_rntm_from_env ( rntm_t* rntm )
+{
+    // allocate struct
+    rntm->offloader_state = malloc ( sizeof ( offload_t ) );
+    offload_t* config = rntm->offloader_state;
+    if ( config == NULL )
+    {
+        // the query and finalize routines treat a NULL state as "never offload"
+        fprintf ( stderr, "Couldn't allocate offloader state. Offloading never.\n" );
+        return;
+    }
+    config->rocblas = NULL;
+
+    const char* s_eng = getenv ( "BLIS_OFFLOAD" );
+    s_eng = ( s_eng == NULL ) ? "never" : s_eng;
+    if ( strcmp ( s_eng, "never" ) == 0 )
+    {
+        fprintf ( stdout, "Never attempting to offload.\n" );
+        bli_offloader_disable_all ( config );
+        return;
+    }
+    else if ( strcmp ( s_eng, "always" ) == 0 )
+    {
+        fprintf ( stdout, "Always attempting to offload.\n" );
+        config->never_offload_dgemm = false;
+        config->never_offload_sgemm = false;
+        config->never_offload_zgemm = false;
+        config->never_offload_cgemm = false;
+        config->offload_sgemm_thresh = 0;
+        config->offload_dgemm_thresh = 0;
+        config->offload_cgemm_thresh = 0;
+        config->offload_zgemm_thresh = 0;
+        // still initialize rocBLAS handle
+    }
+    else if ( strcmp ( s_eng, "threshold" ) == 0 )
+    {
+        config->offload_sgemm_thresh = bli_offloader_thresh_from_env
+            ( "BLIS_OFFLOAD_SGEMM_THRESH", "sgemm", &config->never_offload_sgemm );
+        config->offload_dgemm_thresh = bli_offloader_thresh_from_env
+            ( "BLIS_OFFLOAD_DGEMM_THRESH", "dgemm", &config->never_offload_dgemm );
+        config->offload_cgemm_thresh = bli_offloader_thresh_from_env
+            ( "BLIS_OFFLOAD_CGEMM_THRESH", "cgemm", &config->never_offload_cgemm );
+        config->offload_zgemm_thresh = bli_offloader_thresh_from_env
+            ( "BLIS_OFFLOAD_ZGEMM_THRESH", "zgemm", &config->never_offload_zgemm );
+
+        // still initialize rocBLAS handle
+    }
+    else
+    {
+        fprintf ( stderr, "Unknown BLIS_OFFLOAD selection: %s . Offloading never.\n", s_eng );
+        bli_offloader_disable_all ( config );
+        return;
+    }
+
+    const rocblas_status stat = rocblas_create_handle ( & ( config->rocblas ) );
+    if ( stat != rocblas_status_success )
+    {
+        // without a handle every offload attempt would fail: switch offloading off
+        fprintf ( stderr, "Couldn't create rocBLAS handle w/ error %d\n", stat );
+        config->rocblas = NULL;
+        bli_offloader_disable_all ( config );
+        return;
+    }
+    const rocblas_status stat_p = rocblas_set_pointer_mode ( config->rocblas,
+                                                             rocblas_pointer_mode_host );
+    if ( stat_p != rocblas_status_success )
+    {
+        fprintf ( stderr, "Couldn't set rocBLAS pointer mode to host w/ error %d\n", stat_p );
+    }
+}
+
+// Tear down the offloader state of the global runtime.
+void bli_offloader_finalize ( void )
+{
+    bli_offloader_finalize_rntm_from_env ( &global_rntm );
+}
+
+// Destroy the rocBLAS handle, if one was created, and release the state
+// allocated by bli_offloader_init_rntm_from_env().
+void bli_offloader_finalize_rntm_from_env ( rntm_t* rntm )
+{
+    if ( rntm->offloader_state != NULL && rntm->offloader_state->rocblas != NULL )
+    {
+        // just destroy rocblas handle
+        const rocblas_status stat = rocblas_destroy_handle ( rntm->offloader_state->rocblas );
+        if ( stat != rocblas_status_success )
+        {
+            fprintf ( stderr, "Couldn't destroy rocBLAS handle w/ error %d\n", stat );
+        }
+    }
+
+    // free struct itself; clear the pointer so a stale state is never reused
+    free ( rntm->offloader_state );
+    rntm->offloader_state = NULL;
+}
+
+// Decide, using the global runtime, whether this gemm should be offloaded.
+bool bli_do_offload_gemmex
+     (
+       const obj_t* alpha,
+       const obj_t* a,
+       const obj_t* b,
+       const obj_t* beta,
+       const obj_t* c
+     )
+{
+    return bli_do_offload_gemmex_rntm_from_env ( &global_rntm, alpha, a, b, beta, c );
+}
+
+// Cheap eligibility test that runs before any device work is attempted.
+// Returns true only if offloading is enabled for C's datatype, no operand
+// is integer, every operand has unit row stride, no complex operand is
+// conjugated without being transposed, and M*N of C reaches the threshold
+// configured for C's datatype.
+bool bli_do_offload_gemmex_rntm_from_env
+     (
+       rntm_t* rntm,
+       const obj_t* alpha,
+       const obj_t* a,
+       const obj_t* b,
+       const obj_t* beta,
+       const obj_t* c
+     )
+{
+
+    offload_t* config = rntm->offloader_state;
+
+    // never offload anything (a NULL state means initialization failed)
+    if ( config == NULL ||
+         ( config->never_offload_dgemm && config->never_offload_sgemm &&
+           config->never_offload_cgemm && config->never_offload_zgemm ) )
+    {
+        return false;
+    }
+
+    // figure out if any operand is integer and reject (for now)
+    // NOTE: rocBLAS supports f16, f16 cmpl, f32, f32 cmpl, f64, f64 cmpl, i8, u8, i32,
+    // i32 cmpl, u32 compl, bf16, bf16 cmpl as data type settings
+    // (not in all combinations)
+    if ( bli_obj_is_int ( a ) || bli_obj_is_int ( b ) || bli_obj_is_int ( c ) )
+    {
+        return false;
+    }
+
+    const inc_t rs_a = bli_obj_row_stride ( a );
+    const inc_t rs_b = bli_obj_row_stride ( b );
+    const inc_t rs_c = bli_obj_row_stride ( c );
+    // do not offload if any row stride is != 1 (as rocBLAS only supports col strides)
+    if ( rs_a != 1 || rs_b != 1 || rs_c != 1 )
+    {
+        return false;
+    }
+
+    // reject right away if offloading is switched off for C's datatype; the
+    // complex cases use the same scomplex test as the threshold lookup below
+    const bool is_float_c = bli_obj_is_float ( c );
+    const bool is_compl_c = bli_obj_is_complex ( c );
+    if ( !is_compl_c && is_float_c && config->never_offload_sgemm )
+    {
+        return false;
+    }
+    else if ( !is_compl_c && !is_float_c && config->never_offload_dgemm )
+    {
+        return false;
+    }
+    else if ( is_compl_c && bli_obj_is_scomplex ( c ) && config->never_offload_cgemm )
+    {
+        return false;
+    }
+    else if ( is_compl_c && !bli_obj_is_scomplex ( c ) && config->never_offload_zgemm )
+    {
+        return false;
+    }
+
+    // compare M*N of the result matrix C against the datatype specific cutoff
+    const dim_t m_c = bli_obj_length ( c );
+    const dim_t n_c = bli_obj_width ( c );
+    const size_t mul = m_c * n_c;
+
+    if ( !is_compl_c )
+    {
+        return ( is_float_c ) ? ( mul >= config->offload_sgemm_thresh ) : ( mul >= config->offload_dgemm_thresh );
+    }
+    else
+    {
+        // make sure we're not conjugate AND not transpose
+        if ( bli_obj_has_conj( a ) && !bli_obj_has_trans( a ) ) return false;
+        if ( bli_obj_has_conj( b ) && !bli_obj_has_trans( b ) ) return false;
+
+        return ( bli_obj_is_scomplex( c ) ) ? ( mul >= config->offload_cgemm_thresh ) : ( mul >= config->offload_zgemm_thresh );
+    }
+}
+
+
+// Attempt to offload this gemm using the global runtime.
+err_t bli_offload_gemmex
+     (
+       const obj_t* alpha,
+       const obj_t* a,
+       const obj_t* b,
+       const obj_t* beta,
+       const obj_t* c
+     )
+{
+    return bli_offload_gemmex_rntm_from_env ( &global_rntm, alpha, a, b, beta, c );
+
+}
+
+// Run the gemm on the device through rocblas_gemm_ex. Operands that
+// hipPointerGetAttributes does not report as device memory are staged through
+// temporary device buffers, and a staged C is copied back afterwards.
+// Returns BLIS_SUCCESS if the product was computed; on any other value the
+// caller is expected to fall back to the CPU implementation.
+err_t bli_offload_gemmex_rntm_from_env
+     (
+       rntm_t* rntm,
+       const obj_t* alpha,
+       const obj_t* a,
+       const obj_t* b,
+       const obj_t* beta,
+       const obj_t* c
+     )
+{
+
+    offload_t* config = rntm->offloader_state;
+
+    // never offload anything (a NULL state means initialization failed)
+    if ( config == NULL ||
+         ( config->never_offload_dgemm && config->never_offload_sgemm &&
+           config->never_offload_cgemm && config->never_offload_zgemm ) )
+    {
+        return BLIS_FAILURE;
+    }
+
+    // figure out if any operand is integer and reject
+    if ( bli_obj_is_int ( a ) || bli_obj_is_int ( b ) || bli_obj_is_int ( c ) )
+    {
+        return BLIS_EXPECTED_NONINTEGER_DATATYPE;
+    }
+
+    const inc_t rs_a = bli_obj_row_stride ( a );
+    const inc_t rs_b = bli_obj_row_stride ( b );
+    const inc_t rs_c = bli_obj_row_stride ( c );
+    // do not offload if any row stride is != 1 (as rocBLAS only supports col strides)
+    if ( rs_a != 1 || rs_b != 1 || rs_c != 1 )
+    {
+        return BLIS_INVALID_ROW_STRIDE;
+    }
+
+    // are any of the matrices complex
+    const bool is_compl_a = bli_obj_is_complex ( a );
+    const bool is_compl_b = bli_obj_is_complex ( b );
+    const bool is_compl_c = bli_obj_is_complex ( c );
+
+    // the real-domain switches only apply to real C; complex C is filtered
+    // by its threshold below (LLONG_MAX when offloading is switched off)
+    const bool is_float_a = bli_obj_is_float ( a );
+    const bool is_float_b = bli_obj_is_float ( b );
+    const bool is_float_c = bli_obj_is_float ( c );
+    if ( !is_compl_c && is_float_c && config->never_offload_sgemm )
+    {
+        return BLIS_FAILURE;
+    }
+    else if ( !is_compl_c && !is_float_c && config->never_offload_dgemm )
+    {
+        return BLIS_FAILURE;
+    }
+
+    const inc_t lda = bli_obj_col_stride ( a );
+    const inc_t ldb = bli_obj_col_stride ( b );
+    const inc_t ldc = bli_obj_col_stride ( c );
+    const dim_t m_a = bli_obj_length ( a );
+    const dim_t n_a = bli_obj_width ( a );
+    // const dim_t m_b = bli_obj_length ( b );
+    const dim_t n_b = bli_obj_width ( b );
+    const dim_t m_c = bli_obj_length ( c );
+    const dim_t n_c = bli_obj_width ( c );
+    const size_t mul = m_c * n_c;
+
+    bool should_offload;
+    if ( !is_compl_c )
+    {
+        should_offload = ( is_float_c ) ? ( mul >= config->offload_sgemm_thresh ) : ( mul >= config->offload_dgemm_thresh );
+    }
+    else
+    {
+        // make sure we're not conjugate AND not transpose
+        if ( bli_obj_has_conj( a ) && !bli_obj_has_trans( a ) ) should_offload = false;
+        else if ( bli_obj_has_conj( b ) && !bli_obj_has_trans( b ) ) should_offload = false;
+        else should_offload = ( bli_obj_is_scomplex( c ) ) ? ( mul >= config->offload_cgemm_thresh ) : ( mul >= config->offload_zgemm_thresh );
+    }
+    if ( !should_offload )
+    {
+        return BLIS_NONCONFORMAL_DIMENSIONS;
+    }
+
+    // we should offload: gather some dimensions and pointers
+    void *A = bli_obj_buffer_at_off ( a ); // pointer to elements of Matrix A
+    void *B = bli_obj_buffer_at_off ( b ); // pointer to elements of Matrix B
+    void *C = bli_obj_buffer_at_off ( c ); // pointer to elements of Matrix C
+
+    const bool is_trans_a = bli_obj_has_trans ( a );
+    const bool is_trans_b = bli_obj_has_trans ( b );
+
+    const size_t buff_size_a = lda * n_a * bli_obj_elem_size ( a );
+    const size_t buff_size_b = ldb * n_b * bli_obj_elem_size ( b );
+    const size_t buff_size_c = ldc * n_c * bli_obj_elem_size ( c );
+
+    // inspect pointers for memory location of buffers
+    hipPointerAttribute_t attr;
+    const hipError_t err_insp_a = hipPointerGetAttributes(&attr, A);
+    bool copy_a = true;
+    if ( err_insp_a == hipSuccess )
+    {
+        copy_a = ( attr.memoryType != hipMemoryTypeDevice );
+    }
+    const hipError_t err_insp_b = hipPointerGetAttributes(&attr, B);
+    bool copy_b = true;
+    if ( err_insp_b == hipSuccess )
+    {
+        copy_b = ( attr.memoryType != hipMemoryTypeDevice );
+    }
+    const hipError_t err_insp_c = hipPointerGetAttributes(&attr, C);
+    bool copy_c = true;
+    if ( err_insp_c == hipSuccess )
+    {
+        copy_c = ( attr.memoryType != hipMemoryTypeDevice );
+    }
+
+    // if applicable: allocate buffers on device and copy data
+    // note: we cannot assume the CPU buffers to be pinned and hence most likely the copies will be synchronous
+    // a buffer that is already on the device is used in place; owns_* records
+    // which buffers were allocated here and must be released at cleanup
+    void* dev_buff_a = A;
+    void* dev_buff_b = B;
+    void* dev_buff_c = C;
+    bool owns_a = false;
+    bool owns_b = false;
+    bool owns_c = false;
+    err_t result = BLIS_FAILURE;
+
+    hipStream_t stream;
+    rocblas_get_stream( config->rocblas, &stream );
+
+    if ( copy_a )
+    {
+        const hipError_t err_a = hipMalloc ( &dev_buff_a, buff_size_a );
+        if ( err_a != hipSuccess )
+        {
+            fprintf ( stderr, "Failure to allocate device buffer A of size %zu: %d\n", buff_size_a, err_a );
+            goto cleanup;
+        }
+        owns_a = true;
+        const hipError_t err_cpa = hipMemcpy ( dev_buff_a, A, buff_size_a, hipMemcpyHostToDevice );
+        if ( err_cpa != hipSuccess )
+        {
+            fprintf ( stderr, "Failure to hipMemcpy A to device: %d\n", err_cpa );
+            goto cleanup;
+        }
+    }
+
+    if ( copy_b )
+    {
+        const hipError_t err_b = hipMalloc ( &dev_buff_b, buff_size_b );
+        if ( err_b != hipSuccess )
+        {
+            fprintf ( stderr, "Failure to allocate device buffer B of size %zu: %d\n", buff_size_b, err_b );
+            goto cleanup;
+        }
+        owns_b = true;
+        const hipError_t err_cpb = hipMemcpy ( dev_buff_b, B, buff_size_b, hipMemcpyHostToDevice );
+        if ( err_cpb != hipSuccess )
+        {
+            fprintf ( stderr, "Failure to hipMemcpy B to device: %d\n", err_cpb );
+            goto cleanup;
+        }
+    }
+
+    if ( copy_c )
+    {
+        const hipError_t err_c = hipMalloc ( &dev_buff_c, buff_size_c );
+        if ( err_c != hipSuccess )
+        {
+            fprintf ( stderr, "Failure to allocate device buffer C of size %zu: %d\n", buff_size_c, err_c );
+            goto cleanup;
+        }
+        owns_c = true;
+
+        // is beta zero?
+        const bool is_beta_non_zero = !bli_obj_equals ( beta, &BLIS_ZERO );
+
+        if ( is_beta_non_zero || ldc != m_c ) // only if the result buffer is m*n sized AND beta == 0.0 we can eschew the copy
+        {
+            const hipError_t err_cpc = hipMemcpy ( dev_buff_c, C, buff_size_c, hipMemcpyHostToDevice );
+            if ( err_cpc != hipSuccess )
+            {
+                fprintf ( stderr, "Failure to hipMemcpy C to device: %d\n", err_cpc );
+                goto cleanup;
+            }
+        }
+    }
+
+    // call rocblas
+    rocblas_operation trans_a = is_trans_a ? rocblas_operation_transpose : rocblas_operation_none;
+    rocblas_operation trans_b = is_trans_b ? rocblas_operation_transpose : rocblas_operation_none;
+    if ( is_compl_a && bli_obj_has_conj( a ) )
+        trans_a = rocblas_operation_conjugate_transpose;
+    if ( is_compl_b && bli_obj_has_conj( b ) )
+        trans_b = rocblas_operation_conjugate_transpose;
+
+    rocblas_datatype a_type;
+    rocblas_datatype b_type;
+    rocblas_datatype c_type;
+    if ( is_compl_a )
+        a_type = ( bli_obj_is_scomplex( a ) ) ? rocblas_datatype_f32_c : rocblas_datatype_f64_c;
+    else
+        a_type = ( is_float_a ) ? rocblas_datatype_f32_r : rocblas_datatype_f64_r;
+    if ( is_compl_b )
+        b_type = ( bli_obj_is_scomplex( b ) ) ? rocblas_datatype_f32_c : rocblas_datatype_f64_c;
+    else
+        b_type = ( is_float_b ) ? rocblas_datatype_f32_r : rocblas_datatype_f64_r;
+    if ( is_compl_c )
+        c_type = ( bli_obj_is_scomplex( c ) ) ? rocblas_datatype_f32_c : rocblas_datatype_f64_c;
+    else
+        c_type = ( is_float_c ) ? rocblas_datatype_f32_r : rocblas_datatype_f64_r;
+
+    rocblas_datatype compute_type;
+    if ( !is_compl_a && !is_compl_b && !is_compl_c)
+        compute_type = ( is_float_a && is_float_b && is_float_c ) ? rocblas_datatype_f32_r : rocblas_datatype_f64_r;
+    else
+        compute_type = ( bli_obj_is_scomplex( a ) && bli_obj_is_scomplex( b ) && bli_obj_is_scomplex( c ) ) ? rocblas_datatype_f32_c : rocblas_datatype_f64_c;
+
+    const num_t dt_exec = bli_obj_dt ( c );
+    void* restrict alpha_f = bli_obj_buffer_for_1x1 ( dt_exec, alpha );
+    void* restrict beta_f = bli_obj_buffer_for_1x1 ( dt_exec, beta );
+
+
+    const size_t k = is_trans_a ? m_a : n_a;
+    const rocblas_gemm_algo algo = rocblas_gemm_algo_standard;
+    const int32_t solution_index = 0;
+    const uint32_t flags = 0;
+    const rocblas_status roc_err = rocblas_gemm_ex ( config->rocblas,
+                                                     trans_a,
+                                                     trans_b,
+                                                     m_c,
+                                                     n_c,
+                                                     k,
+                                                     alpha_f,
+                                                     dev_buff_a,
+                                                     a_type,
+                                                     lda,
+                                                     dev_buff_b,
+                                                     b_type,
+                                                     ldb,
+                                                     beta_f,
+                                                     dev_buff_c,
+                                                     c_type,
+                                                     ldc,
+                                                     dev_buff_c,
+                                                     c_type,
+                                                     ldc,
+                                                     compute_type,
+                                                     algo,
+                                                     solution_index,
+                                                     flags );
+    if ( roc_err != rocblas_status_success )
+    {
+        fprintf ( stderr, "Failure to call rocblas_gemm_ex: %d\n", roc_err );
+        goto cleanup;
+    }
+
+    if ( copy_c )
+    {
+        // copy result back synchronously
+        const hipError_t err_cpr = hipMemcpy ( C, dev_buff_c, buff_size_c, hipMemcpyDeviceToHost );
+        if ( err_cpr != hipSuccess )
+        {
+            fprintf ( stderr, "Failure to hipMemcpy C from device: %d\n", err_cpr );
+            goto cleanup;
+        }
+    }
+    else
+    {
+        // only synchronize on the rocBLAS stream to ensure data correctness
+        hipStreamSynchronize( stream );
+    }
+
+    // C now holds the result: from here on failures are only reported, since
+    // falling back to the CPU path would apply beta to C a second time
+    result = BLIS_SUCCESS;
+
+cleanup:
+    // if applicable: free intermediate buffers (on success and failure alike)
+    if ( owns_a )
+    {
+        const hipError_t err_fa = hipFree ( dev_buff_a );
+        if ( err_fa != hipSuccess )
+        {
+            fprintf ( stderr, "Failure to free device buffer A: %d\n", err_fa );
+        }
+    }
+    if ( owns_b )
+    {
+        const hipError_t err_fb = hipFree ( dev_buff_b );
+        if ( err_fb != hipSuccess )
+        {
+            fprintf ( stderr, "Failure to free device buffer B: %d\n", err_fb );
+        }
+    }
+    if ( owns_c )
+    {
+        const hipError_t err_fc = hipFree ( dev_buff_c );
+        if ( err_fc != hipSuccess )
+        {
+            fprintf ( stderr, "Failure to free device buffer C: %d\n", err_fc );
+        }
+    }
+
+    return result;
+}
+
+#endif
diff --git 
a/frame/base/bli_offloader.h b/frame/base/bli_offloader.h
new file mode 100644
index 0000000000..10124f94f1
--- /dev/null
+++ b/frame/base/bli_offloader.h
@@ -0,0 +1,62 @@
+#ifdef BLIS_ENABLE_AMD_OFFLOAD
+#ifndef BLI_OFFLOADER_H
+#define BLI_OFFLOADER_H
+#include <hip/hip_runtime_api.h>
+#include <rocblas/rocblas.h>
+#include "blis.h"
+
+// Set up the offloader state of the global runtime.
+void bli_offloader_init ( void );
+
+// Populate rntm->offloader_state from the BLIS_OFFLOAD* environment variables.
+void bli_offloader_init_rntm_from_env ( rntm_t* rntm );
+
+// Tear down the offloader state of the global runtime.
+void bli_offloader_finalize ( void );
+
+// Destroy the rocBLAS handle (if any) and free rntm->offloader_state.
+void bli_offloader_finalize_rntm_from_env ( rntm_t* rntm );
+
+// Returns true if this gemm qualifies for offloading (global runtime).
+bool bli_do_offload_gemmex
+     (
+       const obj_t* alpha,
+       const obj_t* a,
+       const obj_t* b,
+       const obj_t* beta,
+       const obj_t* c
+     );
+
+// Returns true if this gemm qualifies for offloading under the given runtime.
+bool bli_do_offload_gemmex_rntm_from_env
+     (
+       rntm_t* rntm,
+       const obj_t* alpha,
+       const obj_t* a,
+       const obj_t* b,
+       const obj_t* beta,
+       const obj_t* c
+     );
+
+// Offloads the gemm via rocBLAS (global runtime); BLIS_SUCCESS on success.
+err_t bli_offload_gemmex
+     (
+       const obj_t* alpha,
+       const obj_t* a,
+       const obj_t* b,
+       const obj_t* beta,
+       const obj_t* c
+     );
+
+// Offloads the gemm via rocBLAS under the given runtime; BLIS_SUCCESS on success.
+err_t bli_offload_gemmex_rntm_from_env
+     (
+       rntm_t* rntm,
+       const obj_t* alpha,
+       const obj_t* a,
+       const obj_t* b,
+       const obj_t* beta,
+       const obj_t* c
+     );
+
+#endif // BLI_OFFLOADER_H
+#endif // BLIS_ENABLE_AMD_OFFLOAD
diff --git a/frame/base/bli_rntm.h b/frame/base/bli_rntm.h
index 2a39f8894c..fe68259bbe 100644
--- a/frame/base/bli_rntm.h
+++ b/frame/base/bli_rntm.h
@@ -53,7 +53,9 @@ typedef struct rntm_s

 	pool_t* sba_pool;
 	pba_t* pba;
-
+#ifdef BLIS_ENABLE_AMD_OFFLOAD
+	offload_t* offloader_state;
+#endif
 } rntm_t;
 */

@@ -304,6 +306,21 @@ BLIS_INLINE void bli_rntm_clear_l3_sup( rntm_t* rntm )
 // of the public "set" accessors, each of which guarantees that the rntm_t
 // will be in a good state upon return.

+#ifndef BLIS_ENABLE_AMD_OFFLOAD
+#define BLIS_RNTM_INITIALIZER \
+        { \
+          .auto_factor = TRUE, \
+          .num_threads = -1, \
+          .thrloop = { -1, -1, -1, -1, -1, -1 }, \
+          .pack_a = FALSE, \
+          .pack_b = FALSE, \
+          .l3_sup = TRUE, \
+          .sba_pool = NULL, \
+          .pba = NULL, \
+        } \
+
+#else
+
 #define BLIS_RNTM_INITIALIZER \
         { \
           .auto_factor = TRUE, \
@@ -314,8 +331,11 @@ BLIS_INLINE void bli_rntm_clear_l3_sup( rntm_t* rntm )
           .l3_sup = TRUE, \
           .sba_pool = NULL, \
           .pba = NULL, \
+          .offloader_state = NULL, \
         } \

+#endif
+
 BLIS_INLINE void bli_rntm_init( rntm_t* rntm )
 {
 	bli_rntm_set_auto_factor_only( TRUE, rntm );
diff --git a/frame/include/bli_type_defs.h b/frame/include/bli_type_defs.h
index 08c7ddc4a6..783b4ab1e1 100644
--- a/frame/include/bli_type_defs.h
+++ b/frame/include/bli_type_defs.h
@@ -1035,6 +1035,22 @@ typedef struct pba_s

 } pba_t;

+// -- optional: offloader state --
+
+#ifdef BLIS_ENABLE_AMD_OFFLOAD
+typedef struct offload_s
+{
+	bool never_offload_dgemm; // true: real double gemm is never offloaded
+	bool never_offload_sgemm; // true: real float gemm is never offloaded
+	bool never_offload_zgemm; // true: double complex gemm is never offloaded
+	bool never_offload_cgemm; // true: single complex gemm is never offloaded
+	struct _rocblas_handle* rocblas; // rocBLAS handle; NULL if none was created
+	int64_t offload_sgemm_thresh; // offload if M*N of C is >= this value
+	int64_t offload_dgemm_thresh; // offload if M*N of C is >= this value
+	int64_t offload_cgemm_thresh; // offload if M*N of C is >= this value
+	int64_t offload_zgemm_thresh; // offload if M*N of C is >= this value
+} offload_t;
+#endif

 // -- Memory object type --

@@ -1449,6 +1465,11 @@ typedef struct rntm_s
 	// The packing block allocator, which is attached in the l3 thread decorator.
 	pba_t* pba;

+#ifdef BLIS_ENABLE_AMD_OFFLOAD
+	// if offloading is enabled - this contains the offloader state
+	offload_t* offloader_state;
+#endif
+
 } rntm_t;