diff --git a/NeoMathEngine/CMakeLists.txt b/NeoMathEngine/CMakeLists.txt
index 5fc1fd705..17cfc83c0 100644
--- a/NeoMathEngine/CMakeLists.txt
+++ b/NeoMathEngine/CMakeLists.txt
@@ -1,6 +1,6 @@
 cmake_minimum_required(VERSION 3.11 FATAL_ERROR)
 
-project(NeoMathEngine LANGUAGES CXX)
+project(NeoMathEngine LANGUAGES CXX C ASM)
 
 list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/../cmake)
 
 if(USE_FINE_OBJECTS)
diff --git a/NeoMathEngine/src/CMakeLists.txt b/NeoMathEngine/src/CMakeLists.txt
index cb9bfa87f..742b2eb40 100644
--- a/NeoMathEngine/src/CMakeLists.txt
+++ b/NeoMathEngine/src/CMakeLists.txt
@@ -170,7 +170,22 @@ if((DARWIN AND BUILD_ARCH MATCHES "^arm64.*") OR (ANDROID AND ANDROID_ABI MATCHE
     target_include_directories(${PROJECT_NAME} PRIVATE $)
 else()
     message(STATUS "USE X86 SOURCES")
+    if(CMAKE_SIZEOF_VOID_P EQUAL 8) # x64
+        if(WIN32)
+            set(CPU_X86_ASM_SOURCES
+                CPU/x86/CpuX86FmaCount/fma_shuffle_tpt.asm
+                CPU/x86/CpuX86FmaCount/fma_only_tpt.asm)
+        else() #if(CMAKE_CXX_COMPILER_ID MATCHES Clang OR CMAKE_CXX_COMPILER_ID MATCHES GNU OR CMAKE_CXX_COMPILER_ID MATCHES AppleClang)
+            target_compile_options(${PROJECT_NAME} PRIVATE
+                $<$<COMPILE_LANGUAGE:ASM>:-x$<SEMICOLON>assembler-with-cpp>)
+            set(CPU_X86_ASM_SOURCES
+                CPU/x86/CpuX86FmaCount/fma_shuffle_tpt.s
+                CPU/x86/CpuX86FmaCount/fma_only_tpt.s)
+        endif()
+    endif()
+
     set(CPU_X86_SOURCES
+        CPU/x86/CpuX86FmaCount.cpp
         CPU/x86/CpuX86MathEngineBlas.cpp
         CPU/x86/CpuX86MathEngineBlasMkl.cpp
         CPU/x86/CpuX86MathEngineDnn.cpp
@@ -180,6 +195,7 @@ else()
 
     target_sources(${PROJECT_NAME} PRIVATE
         ${CPU_X86_SOURCES}
+        ${CPU_X86_ASM_SOURCES}
 
         CPU/x86/CpuX86.h
         CPU/x86/CpuX86Functors.h
         CPU/x86/CpuX86MathEngineVectorMathPrivate.h
@@ -202,6 +218,20 @@ else()
         set_property(SOURCE ${CPU_AVX_SOURCES} PROPERTY COMPILE_OPTIONS $<$<COMPILE_LANGUAGE:CXX>:-mavx2 -mfma>)
     endif()
 
+    set(CPU_AVX512_SOURCES
+        CPU/x86/avx512/Avx512VectorFunctions.cpp
+    )
+    target_sources(${PROJECT_NAME} PRIVATE
+        ${CPU_AVX512_SOURCES}
+        CPU/x86/avx512/Avx512Functions.h
+    )
+    set_property(SOURCE ${CPU_AVX512_SOURCES} PROPERTY UNITY_GROUP 3)
+    if(WIN32)
+        set_property(SOURCE ${CPU_AVX512_SOURCES} PROPERTY COMPILE_OPTIONS /arch:AVX512)
+    else()
+        set_property(SOURCE ${CPU_AVX512_SOURCES} PROPERTY COMPILE_OPTIONS $<$<COMPILE_LANGUAGE:CXX>:-mfma -mavx -mavx2 -mavx512f -mavx512dq -mavx512vl>)
+    endif()
+
     if(NEOML_USE_AVX)
         target_sources(${PROJECT_NAME} PRIVATE
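The CPUInfo.h changes that follow hinge on runtime AVX512 detection. For reference, a minimal standalone sketch of an AVX512F check via CPUID leaf 7, sub-leaf 0 — illustrative only; the library's actual check is CCPUInfo::IsAvx512Available(), and a complete check must also confirm OS support for ZMM register state via XGETBV:

#ifdef _MSC_VER
#include <intrin.h>
#else
#include <cpuid.h>
#endif

// Returns true if the CPU reports the AVX512F feature bit (EBX bit 16 of
// CPUID leaf 7, sub-leaf 0). Hypothetical helper, not part of the library.
inline bool hasAvx512F()
{
#ifdef _MSC_VER
	int regs[4]{};
	__cpuidex( regs, 7, 0 );
	return ( regs[1] & ( 1 << 16 ) ) != 0; // EBX, bit 16 == AVX512F
#else
	unsigned int eax = 0, ebx = 0, ecx = 0, edx = 0;
	if( __get_cpuid_count( 7, 0, &eax, &ebx, &ecx, &edx ) == 0 ) {
		return false; // CPUID leaf 7 not supported
	}
	return ( ebx & ( 1u << 16 ) ) != 0; // EBX, bit 16 == AVX512F
#endif
}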
diff --git a/NeoMathEngine/src/CPU/CPUInfo.h b/NeoMathEngine/src/CPU/CPUInfo.h
index d7ed87ea4..444d6fae3 100644
--- a/NeoMathEngine/src/CPU/CPUInfo.h
+++ b/NeoMathEngine/src/CPU/CPUInfo.h
@@ -1,4 +1,4 @@
-/* Copyright © 2017-2023 ABBYY
+/* Copyright © 2017-2024 ABBYY
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -21,15 +21,22 @@ limitations under the License.
 #include <cpuid.h>
 #elif FINE_PLATFORM( FINE_WINDOWS )
 #include <intrin.h>
-#endif
+#endif // FINE_PLATFORM
 
-#endif // !FINE_ARCHITECTURE( FINE_ARM64 )
+#endif // !FINE_ARCHITECTURE( ARM )
 
 #include
 
+#if FINE_ARCHITECTURE( FINE_X64 )
+// Intel X86/X64 optimization manual
+// https://github.com/intel/optimization-manual/tree/main/chap18/ex25
+int Avx512FmaUnitCount();
+#else // !x64
+inline int Avx512FmaUnitCount() { return 0; }
+#endif // !x64
+
 // The structure with CPU information
-struct CCPUInfo {
+struct CCPUInfo final {
 	enum class TCpuArch {
 		Intel,
 		AMD,
@@ -117,7 +124,7 @@ struct CCPUInfo {
 	{
 #ifdef NEOML_USE_NEON
 		return 4;
-#else
+#else // !NEOML_USE_NEON
 		int floatAlignment = 4; // SSE alignment
 
 		Regs regs;
@@ -138,15 +145,20 @@ struct CCPUInfo {
 		}
 #elif FINE_PLATFORM(FINE_LINUX) || FINE_PLATFORM(FINE_DARWIN) || FINE_PLATFORM(FINE_ANDROID) || FINE_PLATFORM(FINE_IOS)
 		floatAlignment = 8;
-#else
+#else // ERROR FINE_PLATFORM
 #error "Platform isn't supported!"
-#endif
+#endif // ERROR FINE_PLATFORM
+		}
+		if( HasAvx512And2Fma ) {
+			floatAlignment = 16;
 		}
 		return floatAlignment;
-#endif // NEOML_USE_NEON
+#endif // !NEOML_USE_NEON
 	}
 
+	static const bool NEOMATHENGINE_API HasAvx512;
+	static const bool HasAvx512And2Fma;
 	static const bool HasAvxAndFma;
 	static const bool IsNotIntel;
 
@@ -184,9 +196,9 @@ struct CCPUInfo {
 
 #if FINE_PLATFORM(FINE_WINDOWS)
 	typedef int RegType;
-#else
+#else // !FINE_WINDOWS
 	typedef unsigned int RegType;
-#endif
+#endif // !FINE_WINDOWS
 
 	struct Regs {
 		RegType eax;
 		RegType ebx;
@@ -201,12 +213,12 @@ struct CCPUInfo {
 		__cpuid( ( RegType* )( &outRegs ), eax );
 #elif FINE_PLATFORM( FINE_LINUX ) || FINE_PLATFORM( FINE_DARWIN )
 		__get_cpuid( eax, &outRegs.eax, &outRegs.ebx, &outRegs.ecx, &outRegs.edx );
-#else
+#else // ERROR FINE_PLATFORM
 		( void ) eax;
-#endif
-#else
+#endif // ERROR FINE_PLATFORM
+#else // ERROR FINE_ARCHITECTURE
 		( void ) eax;
-#endif // !FINE_ARCHITECTURE( FINE_ARM64 )
+#endif // ERROR FINE_ARCHITECTURE
 	}
 
 	static void callCpuIdEx( Regs& outRegs, const RegType& eax, const RegType& ecx )
 	{
@@ -216,14 +228,14 @@ struct CCPUInfo {
 		__cpuidex((RegType*)( &outRegs ), eax, ecx );
 #elif FINE_PLATFORM( FINE_LINUX ) || FINE_PLATFORM( FINE_DARWIN )
 		__cpuid_count( eax, ecx, outRegs.eax, outRegs.ebx, outRegs.ecx, outRegs.edx );
-#else
+#else // ERROR FINE_PLATFORM
 		( void ) eax;
 		( void ) ecx;
-#endif
-#else
+#endif // ERROR FINE_PLATFORM
+#else // ERROR FINE_ARCHITECTURE
 		( void ) eax;
 		( void ) ecx;
-#endif // !FINE_ARCHITECTURE( FINE_ARM64 )
+#endif // ERROR FINE_ARCHITECTURE
 	}
 };
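Note that GetFloatAlignment() above returns a float count, not bytes: 4 floats = 16 bytes (SSE), 8 floats = 32 bytes (AVX), and, once AVX512 with two FMA ports is detected, 16 floats = 64 bytes — one full ZMM register. A sketch of an allocation that honors the reported width, assuming a C++17 runtime with std::aligned_alloc (MSVC would need _aligned_malloc instead); allocateForSimd is an illustrative helper, not library code:

#include <cstdlib>

// Allocate 'floatCount' floats aligned for the reported SIMD width.
// std::aligned_alloc requires the total size to be a multiple of the alignment,
// so the byte count is rounded up first.
inline float* allocateForSimd( int floatCount, int floatAlignment )
{
	const size_t alignment = floatAlignment * sizeof( float ); // e.g. 16 floats -> 64 bytes
	const size_t bytes = ( floatCount * sizeof( float ) + alignment - 1 ) / alignment * alignment;
	return static_cast<float*>( std::aligned_alloc( alignment, bytes ) );
}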
diff --git a/NeoMathEngine/src/CPU/CpuMathEngine.cpp b/NeoMathEngine/src/CPU/CpuMathEngine.cpp
index 15933d1be..e41dfc672 100644
--- a/NeoMathEngine/src/CPU/CpuMathEngine.cpp
+++ b/NeoMathEngine/src/CPU/CpuMathEngine.cpp
@@ -1,4 +1,4 @@
-/* Copyright © 2017-2023 ABBYY
+/* Copyright © 2017-2024 ABBYY
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -43,6 +43,8 @@ limitations under the License.
 
 #endif // NEOML_USE_MKL
 
+const bool NEOMATHENGINE_API CCPUInfo::HasAvx512 = CCPUInfo::IsAvx512Available();
+const bool CCPUInfo::HasAvx512And2Fma = CCPUInfo::HasAvx512 && Avx512FmaUnitCount() > 1;
 const bool CCPUInfo::HasAvxAndFma = CCPUInfo::IsAvxAndFmaAvailable();
 const bool CCPUInfo::IsNotIntel = CCPUInfo::GetCpuArch() != CCPUInfo::TCpuArch::Intel;
 
diff --git a/NeoMathEngine/src/CPU/x86/CpuX86.h b/NeoMathEngine/src/CPU/x86/CpuX86.h
index 3f0e19259..ac4cdd90d 100644
--- a/NeoMathEngine/src/CPU/x86/CpuX86.h
+++ b/NeoMathEngine/src/CPU/x86/CpuX86.h
@@ -1,4 +1,4 @@
-/* Copyright © 2017-2023 ABBYY
+/* Copyright © 2017-2024 ABBYY
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -35,6 +35,7 @@ limitations under the License.
 #include
 
 #include "avx2/Avx2Functions.h"
+#include "avx512/Avx512Functions.h"
 #include "../CPUInfo.h"
 
 namespace NeoML {
@@ -341,7 +342,10 @@ inline void dataCopy(float* dst, const float* src, int vectorSize)
 {
 	static_assert( sizeof(float) == sizeof(unsigned int), "Size of float isn't equal to size of unsigned int." );
 
-	if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
+	if( CCPUInfo::HasAvx512And2Fma && vectorSize >= NeoML::Avx512::VectorMathMinSize ) {
+		NeoML::Avx512::dataCopy( dst, src, vectorSize );
+		return;
+	} else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
 		NeoML::Avx2::dataCopy( dst, src, vectorSize );
 		return;
 	}
diff --git a/NeoMathEngine/src/CPU/x86/CpuX86FmaCount.cpp b/NeoMathEngine/src/CPU/x86/CpuX86FmaCount.cpp
new file mode 100644
index 000000000..b719c63cb
--- /dev/null
+++ b/NeoMathEngine/src/CPU/x86/CpuX86FmaCount.cpp
@@ -0,0 +1,111 @@
+/* Copyright © 2023-2024 ABBYY
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+--------------------------------------------------------------------------------------------------------------*/
+
+#include <common.h>
+#pragma hdrstop
+
+#include
+#include
+
+#if FINE_ARCHITECTURE( FINE_X64 )
+
+#include <cstdint>
+#ifdef _MSC_VER
+#include <intrin.h>
+#else
+#include <x86intrin.h>
+#endif
+#include <cstdio>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+void fma_shuffle_tpt( uint64_t loop_cnt );
+void fma_only_tpt( uint64_t loop_cnt );
+
+int64_t rdtsc( void )
+{
+	return __rdtsc();
+}
+
+int fma_unit_count( void )
+{
+	int i;
+	uint64_t fma_shuf_tpt_test[3];
+	uint64_t fma_shuf_tpt_test_min;
+	uint64_t fma_only_tpt_test[3];
+	uint64_t fma_only_tpt_test_min;
+	uint64_t start = 0;
+	int number_of_fma_units_per_core = 2;
+
+	/*********************************************************/
+	/* Step 1: Warmup                                        */
+	/*********************************************************/
+
+	fma_only_tpt( 100000 );
+
+	/*********************************************************/
+	/* Step 2: Execute FMA and Shuffle TPT Test              */
+	/*********************************************************/
+	for( i = 0; i < 3; ++i ) {
+		start = rdtsc();
+		fma_shuffle_tpt( 1000 );
+		fma_shuf_tpt_test[i] = rdtsc() - start;
+	}
+
+	/*********************************************************/
+	/* Step 3: Execute FMA only TPT Test                     */
+	/*********************************************************/
+	for( i = 0; i < 3; ++i ) {
+		start = rdtsc();
+		fma_only_tpt( 1000 );
+		fma_only_tpt_test[i] = rdtsc() - start;
+	}
+
+	/*********************************************************/
+	/* Step 4: Decide between 1-FMA and 2-FMA hardware       */
+	/*********************************************************/
+	fma_shuf_tpt_test_min = fma_shuf_tpt_test[0];
+	fma_only_tpt_test_min = fma_only_tpt_test[0];
+	for( i = 1; i < 3; ++i ) {
+		if( fma_shuf_tpt_test[i] < fma_shuf_tpt_test_min ) {
+			fma_shuf_tpt_test_min = fma_shuf_tpt_test[i];
+		}
+		if( fma_only_tpt_test[i] < fma_only_tpt_test_min ) {
+			fma_only_tpt_test_min = fma_only_tpt_test[i];
+		}
+	}
+
+	if( ( double( fma_shuf_tpt_test_min ) / fma_only_tpt_test_min ) < 1.5 ) {
+		number_of_fma_units_per_core = 1;
+	}
+
+	printf( " *** x64 AVX512 %d FMA units per core *** \n", number_of_fma_units_per_core );
+	return number_of_fma_units_per_core;
+}
+
+#ifdef __cplusplus
+}
+#endif
+
+//-------------------------------------------------------------------------------------------------
+
+int Avx512FmaUnitCount()
+{
+	return fma_unit_count();
+}
+
+#endif // x64
diff --git a/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_only_tpt.asm b/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_only_tpt.asm
new file mode 100644
index 000000000..bf853377f
--- /dev/null
+++ b/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_only_tpt.asm
@@ -0,0 +1,84 @@
+;
+; Copyright (C) 2021 by Intel Corporation
+;
+; Permission to use, copy, modify, and/or distribute this software for any
+; purpose with or without fee is hereby granted.
+;
+; THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH
+; REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY
+; AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT,
+; INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM
+; LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR
+; OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR
+; PERFORMANCE OF THIS SOFTWARE.
+; + + +; .globl fma_only_tpt + + ; void fma_only_tpt(uint64_t loop_cnt); + ; On entry: + ; rcx = loop_cnt + +_RDATA SEGMENT READ ALIGN(64) 'DATA' + +one_vec REAL8 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 + +_RDATA ENDS + + +.code +fma_only_tpt PROC public + + mov rdx, rsp + and rsp, -10h + sub rsp, 96 + vmovaps xmmword ptr[rsp], xmm6 + vmovaps xmmword ptr[rsp+16], xmm7 + vmovaps xmmword ptr[rsp+32], xmm8 + vmovaps xmmword ptr[rsp+48], xmm9 + vmovaps xmmword ptr[rsp+64], xmm10 + vmovaps xmmword ptr[rsp+80], xmm11 + + vmovups zmm0, ZMMWORD PTR one_vec + vmovups zmm1, ZMMWORD PTR one_vec + vmovups zmm2, ZMMWORD PTR one_vec + vmovups zmm3, ZMMWORD PTR one_vec + vmovups zmm4, ZMMWORD PTR one_vec + vmovups zmm5, ZMMWORD PTR one_vec + vmovups zmm6, ZMMWORD PTR one_vec + vmovups zmm7, ZMMWORD PTR one_vec + vmovups zmm8, ZMMWORD PTR one_vec + vmovups zmm9, ZMMWORD PTR one_vec + vmovups zmm10, ZMMWORD PTR one_vec + vmovups zmm11, ZMMWORD PTR one_vec + ; mov rcx, loops +loop1: + vfmadd231pd zmm0, zmm0, zmm0 + vfmadd231pd zmm1, zmm1, zmm1 + vfmadd231pd zmm2, zmm2, zmm2 + vfmadd231pd zmm3, zmm3, zmm3 + vfmadd231pd zmm4, zmm4, zmm4 + vfmadd231pd zmm5, zmm5, zmm5 + vfmadd231pd zmm6, zmm6, zmm6 + vfmadd231pd zmm7, zmm7, zmm7 + vfmadd231pd zmm8, zmm8, zmm8 + vfmadd231pd zmm9, zmm9, zmm9 + vfmadd231pd zmm10, zmm10, zmm10 + vfmadd231pd zmm11, zmm11, zmm11 + dec rcx + jg loop1 + + vzeroupper + + vmovaps xmm6, xmmword ptr[rsp] + vmovaps xmm7, xmmword ptr[rsp+16] + vmovaps xmm8, xmmword ptr[rsp+32] + vmovaps xmm9, xmmword ptr[rsp+48] + vmovaps xmm10, xmmword ptr[rsp+64] + vmovaps xmm11, xmmword ptr[rsp+80] + mov rsp, rdx + + ret +fma_only_tpt ENDP +end \ No newline at end of file diff --git a/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_only_tpt.s b/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_only_tpt.s new file mode 100644 index 000000000..59ca9de65 --- /dev/null +++ b/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_only_tpt.s @@ -0,0 +1,74 @@ +# +# Copyright (C) 2021 by Intel Corporation +# +# Permission to use, copy, modify, and/or distribute this software for any +# purpose with or without fee is hereby granted. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +# AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +# LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR +# OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +# PERFORMANCE OF THIS SOFTWARE. 
+# + + .intel_syntax noprefix + + .globl _fma_only_tpt + .globl fma_only_tpt + + # void fma_only_tpt(uint64_t loop_cnt); + # On entry: + # rdi = loop_cnt + + .text + +_fma_only_tpt: +fma_only_tpt: + + vmovups zmm0, one_vec[rip] + vmovups zmm1, one_vec[rip] + vmovups zmm2, one_vec[rip] + vmovups zmm3, one_vec[rip] + vmovups zmm4, one_vec[rip] + vmovups zmm5, one_vec[rip] + vmovups zmm6, one_vec[rip] + vmovups zmm7, one_vec[rip] + vmovups zmm8, one_vec[rip] + vmovups zmm9, one_vec[rip] + vmovups zmm10, one_vec[rip] + vmovups zmm11, one_vec[rip] + mov rdx, rdi # mov rdx, loops +loop1: + vfmadd231pd zmm0, zmm0, zmm0 + vfmadd231pd zmm1, zmm1, zmm1 + vfmadd231pd zmm2, zmm2, zmm2 + vfmadd231pd zmm3, zmm3, zmm3 + vfmadd231pd zmm4, zmm4, zmm4 + vfmadd231pd zmm5, zmm5, zmm5 + vfmadd231pd zmm6, zmm6, zmm6 + vfmadd231pd zmm7, zmm7, zmm7 + vfmadd231pd zmm8, zmm8, zmm8 + vfmadd231pd zmm9, zmm9, zmm9 + vfmadd231pd zmm10, zmm10, zmm10 + vfmadd231pd zmm11, zmm11, zmm11 + dec rdx + jg loop1 + + vzeroupper + + ret + +#ifdef __APPLE__ + .section __TEXT,__const +#else + .section .rodata +#endif + .p2align 6 +one_vec: + .double 1, 1, 1, 1, 1, 1, 1, 1 + +#if defined(__linux__) && defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_shuffle_tpt.asm b/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_shuffle_tpt.asm new file mode 100644 index 000000000..a96114a26 --- /dev/null +++ b/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_shuffle_tpt.asm @@ -0,0 +1,120 @@ +; +; Copyright (C) 2021 by Intel Corporation +; +; Permission to use, copy, modify, and/or distribute this software for any +; purpose with or without fee is hereby granted. +; +; THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +; REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +; AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +; INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +; LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR +; OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +; PERFORMANCE OF THIS SOFTWARE. 
+; + + +; .globl fma_shuffle_tpt + + ; void fma_shuffle_tpt(uint64_t loop_cnt); + ; On entry: + ; rcx = loop_cnt + + +_RDATA SEGMENT READ ALIGN(64) 'DATA' + +one_vec REAL8 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0 +shuf_vec DD 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + +_RDATA ENDS + + +.code +fma_shuffle_tpt PROC public + + mov rdx, rsp + and rsp, -10h + sub rsp, 160 + vmovaps xmmword ptr[rsp], xmm6 + vmovaps xmmword ptr[rsp+16], xmm7 + vmovaps xmmword ptr[rsp+32], xmm8 + vmovaps xmmword ptr[rsp+48], xmm9 + vmovaps xmmword ptr[rsp+64], xmm10 + vmovaps xmmword ptr[rsp+80], xmm11 + vmovaps xmmword ptr[rsp+96], xmm12 + vmovaps xmmword ptr[rsp+112], xmm13 + vmovaps xmmword ptr[rsp+128], xmm14 + vmovaps xmmword ptr[rsp+144], xmm15 + + vmovups zmm0, ZMMWORD PTR one_vec + vmovups zmm1, ZMMWORD PTR one_vec + vmovups zmm2, ZMMWORD PTR one_vec + vmovups zmm3, ZMMWORD PTR one_vec + vmovups zmm4, ZMMWORD PTR one_vec + vmovups zmm5, ZMMWORD PTR one_vec + vmovups zmm6, ZMMWORD PTR one_vec + vmovups zmm7, ZMMWORD PTR one_vec + vmovups zmm8, ZMMWORD PTR one_vec + vmovups zmm9, ZMMWORD PTR one_vec + vmovups zmm10, ZMMWORD PTR one_vec + vmovups zmm11, ZMMWORD PTR one_vec + vmovups zmm12, ZMMWORD PTR shuf_vec + vmovups zmm13, ZMMWORD PTR shuf_vec + vmovups zmm14, ZMMWORD PTR shuf_vec + vmovups zmm15, ZMMWORD PTR shuf_vec + vmovups zmm16, ZMMWORD PTR shuf_vec + vmovups zmm17, ZMMWORD PTR shuf_vec + vmovups zmm18, ZMMWORD PTR shuf_vec + vmovups zmm19, ZMMWORD PTR shuf_vec + vmovups zmm20, ZMMWORD PTR shuf_vec + vmovups zmm21, ZMMWORD PTR shuf_vec + vmovups zmm22, ZMMWORD PTR shuf_vec + vmovups zmm23, ZMMWORD PTR shuf_vec + vmovups zmm30, ZMMWORD PTR shuf_vec + ; mov rcx, loops +loop1: + vfmadd231pd zmm0, zmm0, zmm0 + vfmadd231pd zmm1, zmm1, zmm1 + vfmadd231pd zmm2, zmm2, zmm2 + vfmadd231pd zmm3, zmm3, zmm3 + vfmadd231pd zmm4, zmm4, zmm4 + vfmadd231pd zmm5, zmm5, zmm5 + vfmadd231pd zmm6, zmm6, zmm6 + vfmadd231pd zmm7, zmm7, zmm7 + vfmadd231pd zmm8, zmm8, zmm8 + vfmadd231pd zmm9, zmm9, zmm9 + vfmadd231pd zmm10, zmm10, zmm10 + vfmadd231pd zmm11, zmm11, zmm11 + vpermd zmm12, zmm30, zmm30 + vpermd zmm13, zmm30, zmm30 + vpermd zmm14, zmm30, zmm30 + vpermd zmm15, zmm30, zmm30 + vpermd zmm16, zmm30, zmm30 + vpermd zmm17, zmm30, zmm30 + vpermd zmm18, zmm30, zmm30 + vpermd zmm19, zmm30, zmm30 + vpermd zmm20, zmm30, zmm30 + vpermd zmm21, zmm30, zmm30 + vpermd zmm22, zmm30, zmm30 + vpermd zmm23, zmm30, zmm30 + dec rcx + jg loop1 + + vzeroupper + + vmovaps xmm6, xmmword ptr[rsp] + vmovaps xmm7, xmmword ptr[rsp+16] + vmovaps xmm8, xmmword ptr[rsp+32] + vmovaps xmm9, xmmword ptr[rsp+48] + vmovaps xmm10, xmmword ptr[rsp+64] + vmovaps xmm11, xmmword ptr[rsp+80] + vmovaps xmm12, xmmword ptr[rsp+96] + vmovaps xmm13, xmmword ptr[rsp+112] + vmovaps xmm14, xmmword ptr[rsp+128] + vmovaps xmm15, xmmword ptr[rsp+144] + + mov rsp, rdx + + ret +fma_shuffle_tpt ENDP +end \ No newline at end of file diff --git a/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_shuffle_tpt.s b/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_shuffle_tpt.s new file mode 100644 index 000000000..f69ba1d00 --- /dev/null +++ b/NeoMathEngine/src/CPU/x86/CpuX86FmaCount/fma_shuffle_tpt.s @@ -0,0 +1,101 @@ +# +# Copyright (C) 2021 by Intel Corporation +# +# Permission to use, copy, modify, and/or distribute this software for any +# purpose with or without fee is hereby granted. +# +# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH +# REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY +# AND FITNESS. 
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, DIRECT, +# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM +# LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR +# OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR +# PERFORMANCE OF THIS SOFTWARE. +# + + .intel_syntax noprefix + + .globl _fma_shuffle_tpt + .globl fma_shuffle_tpt + + # void fma_shuffle_tpt(uint64_t loop_cnt); + # On entry: + # rdi = loop_cnt + + .text + +_fma_shuffle_tpt: +fma_shuffle_tpt: + + vmovups zmm0, one_vec[rip] + vmovups zmm1, one_vec[rip] + vmovups zmm2, one_vec[rip] + vmovups zmm3, one_vec[rip] + vmovups zmm4, one_vec[rip] + vmovups zmm5, one_vec[rip] + vmovups zmm6, one_vec[rip] + vmovups zmm7, one_vec[rip] + vmovups zmm8, one_vec[rip] + vmovups zmm9, one_vec[rip] + vmovups zmm10, one_vec[rip] + vmovups zmm11, one_vec[rip] + vmovups zmm12, shuf_vec[rip] + vmovups zmm13, shuf_vec[rip] + vmovups zmm14, shuf_vec[rip] + vmovups zmm15, shuf_vec[rip] + vmovups zmm16, shuf_vec[rip] + vmovups zmm17, shuf_vec[rip] + vmovups zmm18, shuf_vec[rip] + vmovups zmm19, shuf_vec[rip] + vmovups zmm20, shuf_vec[rip] + vmovups zmm21, shuf_vec[rip] + vmovups zmm22, shuf_vec[rip] + vmovups zmm23, shuf_vec[rip] + vmovups zmm30, shuf_vec[rip] + mov rdx, rdi # mov rdx, loops +loop1: + vfmadd231pd zmm0, zmm0, zmm0 + vfmadd231pd zmm1, zmm1, zmm1 + vfmadd231pd zmm2, zmm2, zmm2 + vfmadd231pd zmm3, zmm3, zmm3 + vfmadd231pd zmm4, zmm4, zmm4 + vfmadd231pd zmm5, zmm5, zmm5 + vfmadd231pd zmm6, zmm6, zmm6 + vfmadd231pd zmm7, zmm7, zmm7 + vfmadd231pd zmm8, zmm8, zmm8 + vfmadd231pd zmm9, zmm9, zmm9 + vfmadd231pd zmm10, zmm10, zmm10 + vfmadd231pd zmm11, zmm11, zmm11 + vpermd zmm12, zmm30, zmm30 + vpermd zmm13, zmm30, zmm30 + vpermd zmm14, zmm30, zmm30 + vpermd zmm15, zmm30, zmm30 + vpermd zmm16, zmm30, zmm30 + vpermd zmm17, zmm30, zmm30 + vpermd zmm18, zmm30, zmm30 + vpermd zmm19, zmm30, zmm30 + vpermd zmm20, zmm30, zmm30 + vpermd zmm21, zmm30, zmm30 + vpermd zmm22, zmm30, zmm30 + vpermd zmm23, zmm30, zmm30 + dec rdx + jg loop1 + + vzeroupper + + ret + +#ifdef __APPLE__ + .section __TEXT,__const +#else + .section .rodata +#endif + .p2align 6 +one_vec: + .double 1, 1, 1, 1, 1, 1, 1, 1 +shuf_vec: + .4byte 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15 + +#if defined(__linux__) && defined(__ELF__) +.section .note.GNU-stack,"",%progbits +#endif diff --git a/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMathPrivate.h b/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMathPrivate.h index 784cac6ce..e95e456ff 100644 --- a/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMathPrivate.h +++ b/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMathPrivate.h @@ -1,4 +1,4 @@ -/* Copyright © 2017-2023 ABBYY +/* Copyright © 2017-2024 ABBYY Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
@@ -242,7 +242,10 @@ inline void channelwise1x7( const float* source, const float* filter0, const flo inline void vectorFill( float* result, float value, int vectorSize ) { - if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { + if( CCPUInfo::HasAvx512And2Fma && vectorSize >= NeoML::Avx512::VectorMathMinSize ) { + NeoML::Avx512::vectorFill( result, vectorSize, value ); + return; + } else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { NeoML::Avx2::vectorFill( result, vectorSize, value ); return; } @@ -319,7 +322,10 @@ inline void vectorFill( int* result, int value, int vectorSize ) inline void vectorFill0( float* result, int vectorSize ) { - if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { + if( CCPUInfo::HasAvx512And2Fma && vectorSize >= NeoML::Avx512::VectorMathMinSize ) { + NeoML::Avx512::vectorFill( result, vectorSize ); + return; + } else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { NeoML::Avx2::vectorFill( result, vectorSize ); return; } @@ -399,7 +405,10 @@ inline void vectorEltwiseMax( const float* first, const float* second, float* re inline void vectorAdd( const float* first, const float* second, float* result, int vectorSize ) { - if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { + if( CCPUInfo::HasAvx512And2Fma && vectorSize >= NeoML::Avx512::VectorMathMinSize ) { + NeoML::Avx512::vectorAdd( first, second, result, vectorSize ); + return; + } else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { NeoML::Avx2::vectorAdd( first, second, result, vectorSize ); return; } @@ -563,7 +572,10 @@ inline __m128i sse2Multiply4SignedInts( const __m128i& first, const __m128i& sec inline void vectorMultiply( const float* first, float* result, int vectorSize, float multiplier ) { - if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { + if( CCPUInfo::HasAvx512And2Fma && vectorSize >= NeoML::Avx512::VectorMathMinSize ) { + NeoML::Avx512::vectorMultiply( first, result, vectorSize, multiplier ); + return; + } else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { NeoML::Avx2::vectorMultiply( first, result, vectorSize, multiplier ); return; } @@ -661,7 +673,10 @@ inline void vectorEltwiseMultiply( const float* first, const float* second, floa inline void vectorEltwiseMultiply( const float* first, const float* second, float* result, int vectorSize ) { - if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { + if( CCPUInfo::HasAvx512And2Fma && vectorSize >= NeoML::Avx512::VectorMathMinSize ) { + NeoML::Avx512::vectorEltwiseMultiply( first, second, result, vectorSize ); + return; + } else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { NeoML::Avx2::vectorEltwiseMultiply( first, second, result, vectorSize ); return; } @@ -734,7 +749,10 @@ inline void vectorEltwiseMultiplyAdd( const float* first, const float* second, f return; } - if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { + if( CCPUInfo::HasAvx512And2Fma && vectorSize >= NeoML::Avx512::VectorMathMinSize ) { + NeoML::Avx512::vectorEltwiseMultiplyAdd( first, second, result, vectorSize ); + return; + } else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) { NeoML::Avx2::vectorEltwiseMultiplyAdd( first, second, result, vectorSize ); return; } @@ -786,7 +804,10 @@ inline void vectorEltwiseMultiplyAdd( const 
float* first, const float* second, f
 inline void vectorReLU( const float* first, float* result, int vectorSize )
 {
-	if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
+	if( CCPUInfo::HasAvx512And2Fma && vectorSize >= NeoML::Avx512::VectorMathMinSize ) {
+		NeoML::Avx512::vectorReLU( first, result, vectorSize );
+		return;
+	} else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
 		NeoML::Avx2::vectorReLU( first, result, vectorSize );
 		return;
 	}
@@ -829,7 +850,10 @@ inline void vectorReLU( const float* first, float* result, int vectorSize )
 
 inline void vectorReLU( const float* first, float* result, int vectorSize, float threshold )
 {
-	if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
+	if( CCPUInfo::HasAvx512And2Fma && vectorSize >= NeoML::Avx512::VectorMathMinSize ) {
+		NeoML::Avx512::vectorReLU( first, result, vectorSize, threshold );
+		return;
+	} else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
 		NeoML::Avx2::vectorReLU( first, result, vectorSize, threshold );
 		return;
 	}
@@ -875,7 +899,10 @@ inline void vectorReLU( const float* first, float* result, int vectorSize, float
 
 inline void vectorAddValue( const float* first, float* result, int vectorSize, float value )
 {
-	if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
+	if( CCPUInfo::HasAvx512And2Fma && vectorSize >= NeoML::Avx512::VectorMathMinSize ) {
+		NeoML::Avx512::vectorAddValue( first, result, vectorSize, value );
+		return;
+	} else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
 		NeoML::Avx2::vectorAddValue( first, result, vectorSize, value );
 		return;
 	}
@@ -1179,29 +1206,29 @@ inline void vectorTanh( const float* first, float* result, int vectorSize )
 {
 #ifdef NEOML_USE_MLAS
 	MlasComputeTanh( first, result, static_cast<size_t>( vectorSize ) );
-#else
+#else // !NEOML_USE_MLAS
 	for( int i = 0; i < vectorSize; ++i ) {
 		result[i] = -1.f + 2 / ( 1.f + ExponentFunc( -2 * first[i] ) );
 	}
-#endif
+#endif // !NEOML_USE_MLAS
 }
 
 inline void vectorExp( const float* first, float* result, int vectorSize )
 {
 #ifdef NEOML_USE_MLAS
 	MlasComputeExp( first, result, static_cast<size_t>( vectorSize ) );
-#else
+#else // !NEOML_USE_MLAS
 	for( int i = 0; i < vectorSize; ++i ) {
 		result[i] = ExponentFunc( first[i] );
 	}
-#endif
+#endif // !NEOML_USE_MLAS
 }
 
 inline void vectorSigmoid( const float* first, float* result, int vectorSize )
 {
 #ifdef NEOML_USE_MLAS
 	MlasComputeLogistic( first, result, static_cast<size_t>( vectorSize ) );
-#else
+#else // !NEOML_USE_MLAS
 	int sseSize;
 	int nonSseSize;
 	checkSse( vectorSize, sseSize, nonSseSize );
@@ -1223,7 +1250,7 @@ inline void vectorSigmoid( const float* first, float* result, int vectorSize )
 		*result = *result / ( *result + 1 );
 		++result;
 	}
-#endif
+#endif // !NEOML_USE_MLAS
 }
 
 //------------------------------------------------------------------------------------------------------------
@@ -1238,7 +1265,10 @@ inline __m128 vectorHSwishWorker( const __m128& first, const __m128& three,
 
 inline void vectorHSwish( const float* first, float* result, int vectorSize )
 {
-	if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
+	if( CCPUInfo::HasAvx512And2Fma && vectorSize >= NeoML::Avx512::VectorMathMinSize ) {
+		NeoML::Avx512::vectorHSwish( first, result, vectorSize );
+		return;
+	} else if( CCPUInfo::HasAvxAndFma && vectorSize >= NeoML::Avx2::VectorMathMinSize ) {
 		NeoML::Avx2::vectorHSwish( first, result, vectorSize );
 		return;
 	}
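Both the AVX2 and AVX512 implementations in the files that follow handle the ragged tail — a vectorSize that is not a multiple of the block size — with masked loads and stores instead of a scalar epilogue. Reduced to a standalone sketch with illustrative names (copyTail is not library code; compile with AVX512F enabled):

#include <immintrin.h>

// Copy exactly n floats ( 0 < n < 16 ) with one masked load/store pair.
// The k-mask has the low n bits set; masked-out lanes are neither read
// nor written, so there is no out-of-bounds access past the tail.
inline void copyTail( float* dst, const float* src, unsigned int n )
{
	const __mmask16 mask = _cvtu32_mask16( ( 1u << n ) - 1u );
	const __m512 data = _mm512_mask_loadu_ps( _mm512_setzero_ps(), mask, src );
	_mm512_mask_storeu_ps( dst, mask, data );
}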
diff --git a/NeoMathEngine/src/CPU/x86/avx/src/AvxMathEngine.cpp b/NeoMathEngine/src/CPU/x86/avx/src/AvxMathEngine.cpp
index 69d0d9ffb..d78adaa71 100644
--- a/NeoMathEngine/src/CPU/x86/avx/src/AvxMathEngine.cpp
+++ b/NeoMathEngine/src/CPU/x86/avx/src/AvxMathEngine.cpp
@@ -1,4 +1,4 @@
-/* Copyright © 2017-2023 ABBYY
+/* Copyright © 2017-2024 ABBYY
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -79,7 +79,7 @@ CConvolutionDesc* CAvxMathEngine::InitBlobConvolution( const CBlobDesc& source,
 	int strideHeight, int strideWidth, int dilationHeight, int dilationWidth,
 	const CBlobDesc& filter, const CBlobDesc& result ) const
 {
-	if( !CCPUInfo::IsAvx512Available()
+	if( !CCPUInfo::HasAvx512
 		&& CBlobConvolutionFabric::IsBlobConvolutionAvailable( source.ObjectCount() * source.Height() * source.Width(),
 			filter.BatchWidth() , filter.Height(), filter.Width() ) ) {
diff --git a/NeoMathEngine/src/CPU/x86/avx2/Avx2VectorFunctions.cpp b/NeoMathEngine/src/CPU/x86/avx2/Avx2VectorFunctions.cpp
index dcebc95d0..f2173e0bb 100644
--- a/NeoMathEngine/src/CPU/x86/avx2/Avx2VectorFunctions.cpp
+++ b/NeoMathEngine/src/CPU/x86/avx2/Avx2VectorFunctions.cpp
@@ -1,4 +1,4 @@
-/* Copyright © 2017-2023 ABBYY
+/* Copyright © 2017-2024 ABBYY
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -24,10 +24,20 @@ limitations under the License.
 
 static constexpr int AvxBlockSize = 8;
 
-static constexpr int avxIOMask[2 * AvxBlockSize - 2] = { -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0 };
+//#define NEOML_USE_AVX_MASK // needs AVX512 (F+DQ+VL) for: _cvtu32_mask8, _mm256_mask_storeu_ps, _mm256_mask_loadu_ps
+#ifdef NEOML_USE_AVX_MASK
+
+#define AVX_IO_MASK( N ) \
+	_cvtu32_mask8( ( 1u << N ) - 1u )
+
+#else // !NEOML_USE_AVX_MASK
+static_assert( sizeof( int ) == sizeof( float ), "Avx2: invalid size int != float" );
+static constexpr int avxIOMask[2 * ( AvxBlockSize - 1 )]{ -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0 };
 
 #define AVX_IO_MASK( N ) \
 	_mm256_lddqu_si256( reinterpret_cast<const __m256i*>( avxIOMask + AvxBlockSize - 1 - N ) )
+#endif // !NEOML_USE_AVX_MASK
+
 
 #define AVX_LOAD_32_FLOATS(varPrefix, srcPtr) \
 	__m256 varPrefix##0 = _mm256_loadu_ps( srcPtr + 0 * AvxBlockSize ); \
@@ -63,8 +73,13 @@ void dataCopy( float* dst, const float* src, int vectorSize )
 	}
 
 	if( vectorSize > 0 ) {
+#ifdef NEOML_USE_AVX_MASK
+		const __mmask8 mask = AVX_IO_MASK( vectorSize );
+		_mm256_mask_storeu_ps( dst, mask, _mm256_mask_loadu_ps( _mm256_setzero_ps(), mask, src ) );
+#else // !NEOML_USE_AVX_MASK
 		const __m256i mask = AVX_IO_MASK( vectorSize );
 		_mm256_maskstore_ps( dst, mask, _mm256_maskload_ps( src, mask ) );
+#endif // !NEOML_USE_AVX_MASK
 	}
 }
 
@@ -88,7 +103,12 @@ void vectorFill( float* result, int vectorSize, float value )
 	}
 
 	if( vectorSize > 0 ) {
+#ifdef NEOML_USE_AVX_MASK
+		const __mmask8 mask = AVX_IO_MASK( vectorSize );
+		_mm256_mask_storeu_ps( result, mask, valueSimd );
+#else // !NEOML_USE_AVX_MASK
 		_mm256_maskstore_ps( result, AVX_IO_MASK( vectorSize ), valueSimd );
+#endif // !NEOML_USE_AVX_MASK
 	}
 }
 
@@ -118,10 +138,19 @@ void vectorAdd( const float* first, const float* second, float* result, int vect
 	}
 
 	if( vectorSize > 0 ) {
+#ifdef NEOML_USE_AVX_MASK
+		const __m256 zeroSimd = _mm256_setzero_ps();
+		const __mmask8 mask = AVX_IO_MASK( vectorSize );
+
+		const __m256 firstSimd = _mm256_mask_loadu_ps( zeroSimd, mask, first );
+		const __m256 secondSimd = _mm256_mask_loadu_ps( zeroSimd, mask,
second ); + _mm256_mask_storeu_ps( result, mask, _mm256_add_ps( firstSimd, secondSimd ) ); +#else // !NEOML_USE_AVX_MASK const __m256i mask = AVX_IO_MASK( vectorSize ); const __m256 firstSimd = _mm256_maskload_ps( first, mask ); const __m256 secondSimd = _mm256_maskload_ps( second, mask ); _mm256_maskstore_ps( result, mask, _mm256_add_ps( firstSimd, secondSimd ) ); +#endif // !NEOML_USE_AVX_MASK } } @@ -189,10 +218,19 @@ void vectorEltwiseMultiply( const float* first, const float* second, float* resu } if( vectorSize > 0 ) { +#ifdef NEOML_USE_AVX_MASK + const __m256 zeroSimd = _mm256_setzero_ps(); + const __mmask8 mask = AVX_IO_MASK( vectorSize ); + + const __m256 firstSimd = _mm256_mask_loadu_ps( zeroSimd, mask, first ); + const __m256 secondSimd = _mm256_mask_loadu_ps( zeroSimd, mask, second ); + _mm256_mask_storeu_ps( result, mask, _mm256_mul_ps( firstSimd, secondSimd ) ); +#else // !NEOML_USE_AVX_MASK const __m256i mask = AVX_IO_MASK( vectorSize ); const __m256 firstSimd = _mm256_maskload_ps( first, mask ); const __m256 secondSimd = _mm256_maskload_ps( second, mask ); _mm256_maskstore_ps( result, mask, _mm256_mul_ps( firstSimd, secondSimd ) ); +#endif // !NEOML_USE_AVX_MASK } } @@ -223,11 +261,21 @@ void vectorEltwiseMultiplyAdd( const float* first, const float* second, float* r } if( vectorSize > 0 ) { +#ifdef NEOML_USE_AVX_MASK + const __m256 zeroSimd = _mm256_setzero_ps(); + const __mmask8 mask = AVX_IO_MASK( vectorSize ); + + const __m256 firstSimd = _mm256_mask_loadu_ps( zeroSimd, mask, first ); + const __m256 secondSimd = _mm256_mask_loadu_ps( zeroSimd, mask, second ); + const __m256 resultSimd = _mm256_mask_loadu_ps( zeroSimd, mask, result ); + _mm256_mask_storeu_ps( result, mask, _mm256_fmadd_ps( firstSimd, secondSimd, resultSimd ) ); +#else // !NEOML_USE_AVX_MASK const __m256i mask = AVX_IO_MASK( vectorSize ); const __m256 firstSimd = _mm256_maskload_ps( first, mask ); const __m256 secondSimd = _mm256_maskload_ps( second, mask ); const __m256 resultSimd = _mm256_maskload_ps( result, mask ); _mm256_maskstore_ps( result, mask, _mm256_fmadd_ps( firstSimd, secondSimd, resultSimd ) ); +#endif // !NEOML_USE_AVX_MASK } } @@ -244,8 +292,13 @@ void vectorReLU( const float* first, float* result, int vectorSize ) } if( vectorSize > 0 ) { +#ifdef NEOML_USE_AVX_MASK + const __mmask8 mask = AVX_IO_MASK( vectorSize ); + _mm256_mask_storeu_ps( result, mask, _mm256_max_ps( _mm256_mask_loadu_ps( zeroSimd, mask, first ), zeroSimd ) ); +#else // !NEOML_USE_AVX_MASK const __m256i mask = AVX_IO_MASK( vectorSize ); _mm256_maskstore_ps( result, mask, _mm256_max_ps( _mm256_maskload_ps( first, mask ), zeroSimd ) ); +#endif // !NEOML_USE_AVX_MASK } } @@ -263,9 +316,15 @@ void vectorReLU( const float* first, float* result, int vectorSize, float thresh } if( vectorSize > 0 ) { +#ifdef NEOML_USE_AVX_MASK + const __mmask8 mask = AVX_IO_MASK( vectorSize ); + const __m256 firstSimd = _mm256_mask_loadu_ps( zeroSimd, mask, first ); + _mm256_mask_storeu_ps( result, mask, _mm256_min_ps( _mm256_max_ps( firstSimd, zeroSimd ), thresholdSimd ) ); +#else // !NEOML_USE_AVX_MASK const __m256i mask = AVX_IO_MASK( vectorSize ); const __m256 firstSimd = _mm256_maskload_ps( first, mask ); _mm256_maskstore_ps( result, mask, _mm256_min_ps( _mm256_max_ps( firstSimd, zeroSimd ), thresholdSimd ) ); +#endif // !NEOML_USE_AVX_MASK } } @@ -287,11 +346,16 @@ void vectorHSwish( const float* first, float* result, int vectorSize ) } if( vectorSize > 0 ) { +#ifdef NEOML_USE_AVX_MASK + const __mmask8 mask = AVX_IO_MASK( vectorSize ); + 
__m256 firstSimd = _mm256_mask_loadu_ps( _mm256_setzero_ps(), mask, first );
+		__m256 middlePart = _mm256_max_ps( _mm256_add_ps( firstSimd, threeSimd ), zeroSimd );
+		middlePart = _mm256_mul_ps( _mm256_mul_ps( firstSimd, oneSixthSimd ), middlePart );
+		_mm256_mask_storeu_ps( result, mask, _mm256_min_ps( middlePart, _mm256_max_ps( firstSimd, threeSimd ) ) );
+#else // !NEOML_USE_AVX_MASK
 		const __m256i mask = AVX_IO_MASK( vectorSize );
 		__m256 firstSimd = _mm256_maskload_ps( first, mask );
 		__m256 middlePart = _mm256_max_ps( _mm256_add_ps( firstSimd, threeSimd ), zeroSimd );
 		middlePart = _mm256_mul_ps( _mm256_mul_ps( firstSimd, oneSixthSimd ), middlePart );
 		_mm256_maskstore_ps( result, mask, _mm256_min_ps( middlePart, _mm256_max_ps( firstSimd, threeSimd ) ) );
+#endif // !NEOML_USE_AVX_MASK
 	}
 }
@@ -299,4 +363,4 @@ void vectorHSwish( const float* first, float* result, int vectorSize )
 
 } // namespace NeoML
 
-#endif
+#endif // NEOML_USE_SSE
diff --git a/NeoMathEngine/src/CPU/x86/avx512/Avx512Functions.h b/NeoMathEngine/src/CPU/x86/avx512/Avx512Functions.h
new file mode 100644
index 000000000..065baf5c3
--- /dev/null
+++ b/NeoMathEngine/src/CPU/x86/avx512/Avx512Functions.h
@@ -0,0 +1,53 @@
+/* Copyright © 2023-2024 ABBYY
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+--------------------------------------------------------------------------------------------------------------*/
+
+#pragma once
+
+#include
+
+#ifdef NEOML_USE_SSE
+
+namespace NeoML {
+
+namespace Avx512 {
+
+// The minimum vector size recommended for using AVX512 vector functions
+static constexpr int VectorMathMinSize = 16;
+
+void dataCopy( float* dst, const float* src, int vectorSize );
+
+void vectorFill( float* result, int vectorSize, float value = 0.f );
+
+void vectorAdd( const float* first, const float* second, float* result, int vectorSize );
+
+void vectorAddValue( const float* first, float* result, int vectorSize, float value );
+
+void vectorMultiply( const float* first, float* result, int vectorSize, float multiplier );
+
+void vectorEltwiseMultiply( const float* first, const float* second, float* result, int vectorSize );
+
+void vectorEltwiseMultiplyAdd( const float* first, const float* second, float* result, int vectorSize );
+
+void vectorReLU( const float* first, float* result, int vectorSize );
+
+void vectorReLU( const float* first, float* result, int vectorSize, float threshold );
+
+void vectorHSwish( const float* first, float* result, int vectorSize );
+
+} // namespace Avx512
+
+} // namespace NeoML
+
+#endif // NEOML_USE_SSE
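The header above is the entire surface the SSE dispatch code sees; everything in the .cpp below is internal. A hedged usage sketch, assuming a translation unit compiled alongside these sources (scaleAndBias is illustrative, not part of the library):

#include "Avx512Functions.h"

// Scale then bias a buffer of any length; each call handles full 16-float
// blocks plus the masked tail internally, so no size rounding is needed here.
void scaleAndBias( const float* in, float* out, int n, float scale, float bias )
{
	NeoML::Avx512::vectorMultiply( in, out, n, scale );
	NeoML::Avx512::vectorAddValue( out, out, n, bias ); // in-place is safe for these elementwise kernels
}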
diff --git a/NeoMathEngine/src/CPU/x86/avx512/Avx512VectorFunctions.cpp b/NeoMathEngine/src/CPU/x86/avx512/Avx512VectorFunctions.cpp
new file mode 100644
index 000000000..6b32241fd
--- /dev/null
+++ b/NeoMathEngine/src/CPU/x86/avx512/Avx512VectorFunctions.cpp
@@ -0,0 +1,407 @@
+/* Copyright © 2023-2024 ABBYY
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+--------------------------------------------------------------------------------------------------------------*/
+
+#include <common.h>
+
+#ifdef NEOML_USE_SSE
+
+#include "Avx512Functions.h"
+
+#include <immintrin.h>
+#include
+
+namespace NeoML {
+
+namespace Avx512 {
+
+static constexpr int AvxBlockSize = 16;
+
+#define AVX512_IO_MASK( N ) \
+	_cvtu32_mask16( ( 1u << N ) - 1u )
+
+
+#ifdef AVX512_64FLOATS
+#define AVX512_LOAD_64_FLOATS( varPrefix, srcPtr ) \
+	__m512 varPrefix##0 = _mm512_loadu_ps( srcPtr + 0 * AvxBlockSize ); \
+	__m512 varPrefix##1 = _mm512_loadu_ps( srcPtr + 1 * AvxBlockSize ); \
+	__m512 varPrefix##2 = _mm512_loadu_ps( srcPtr + 2 * AvxBlockSize ); \
+	__m512 varPrefix##3 = _mm512_loadu_ps( srcPtr + 3 * AvxBlockSize )
+
+#define AVX512_STORE_64_FLOATS( varPrefix, dstPtr ) \
+	_mm512_storeu_ps( dstPtr + 0 * AvxBlockSize, varPrefix##0 ); \
+	_mm512_storeu_ps( dstPtr + 1 * AvxBlockSize, varPrefix##1 ); \
+	_mm512_storeu_ps( dstPtr + 2 * AvxBlockSize, varPrefix##2 ); \
+	_mm512_storeu_ps( dstPtr + 3 * AvxBlockSize, varPrefix##3 )
+#endif //AVX512_64FLOATS
+
+
+//---------------------------------------------------------------------------------
+
+void dataCopy( float* dst, const float* src, int vectorSize )
+{
+#ifdef AVX512_64FLOATS
+	while( vectorSize >= 4 * AvxBlockSize ) {
+		AVX512_LOAD_64_FLOATS( data, src );
+		AVX512_STORE_64_FLOATS( data, dst );
+		dst += 4 * AvxBlockSize;
+		src += 4 * AvxBlockSize;
+		vectorSize -= 4 * AvxBlockSize;
+	}
+#endif //AVX512_64FLOATS
+
+	while( vectorSize >= AvxBlockSize ) {
+		_mm512_storeu_ps( dst, _mm512_loadu_ps( src ) );
+		dst += AvxBlockSize;
+		src += AvxBlockSize;
+		vectorSize -= AvxBlockSize;
+	}
+
+	if( vectorSize > 0 ) {
+		const __mmask16 mask = AVX512_IO_MASK( vectorSize );
+		_mm512_mask_storeu_ps( dst, mask, _mm512_mask_loadu_ps( _mm512_setzero_ps(), mask, src ) );
+	}
+}
+
+void vectorFill( float* result, int vectorSize, float value )
+{
+	const __m512 valueSimd = _mm512_set1_ps( value );
+#ifdef AVX512_64FLOATS
+	while( vectorSize >= 4 * AvxBlockSize ) {
+		_mm512_storeu_ps( result + 0 * AvxBlockSize, valueSimd );
+		_mm512_storeu_ps( result + 1 * AvxBlockSize, valueSimd );
+		_mm512_storeu_ps( result + 2 * AvxBlockSize, valueSimd );
+		_mm512_storeu_ps( result + 3 * AvxBlockSize, valueSimd );
+		result += 4 * AvxBlockSize;
+		vectorSize -= 4 * AvxBlockSize;
+	}
+#endif //AVX512_64FLOATS
+
+	while( vectorSize >= AvxBlockSize ) {
+		_mm512_storeu_ps( result, valueSimd );
+		result += AvxBlockSize;
+		vectorSize -= AvxBlockSize;
+	}
+
+	if( vectorSize > 0 ) {
+		_mm512_mask_storeu_ps( result, AVX512_IO_MASK( vectorSize ), valueSimd );
+	}
+}
+
+void vectorAdd( const float* first, const float* second, float* result, int vectorSize )
+{
+#ifdef AVX512_64FLOATS
+	while( vectorSize >= 4 * AvxBlockSize ) {
+		AVX512_LOAD_64_FLOATS( first, first );
+		AVX512_LOAD_64_FLOATS( second, second );
+		first0 = _mm512_add_ps( first0, second0 );
+		first1 = _mm512_add_ps( first1, second1 );
+		first2 = _mm512_add_ps( first2, second2 );
+		first3 = _mm512_add_ps( first3, second3 );
+		AVX512_STORE_64_FLOATS( first, result );
+		first += 4 * AvxBlockSize;
+		second += 4 * AvxBlockSize;
+		result += 4 * AvxBlockSize;
+		vectorSize -= 4 * AvxBlockSize;
+	}
+#endif //AVX512_64FLOATS
+
+	while( vectorSize >= AvxBlockSize ) {
+		_mm512_storeu_ps( result,
+			_mm512_add_ps( _mm512_loadu_ps( first ), _mm512_loadu_ps( second ) ) );
+		first += AvxBlockSize;
+		second += AvxBlockSize;
+		result +=
AvxBlockSize; + vectorSize -= AvxBlockSize; + } + + if( vectorSize > 0 ) { + const __m512 zeroSimd = _mm512_setzero_ps(); // copy data from here, where mask bits are false + const __mmask16 mask = AVX512_IO_MASK( vectorSize ); + + const __m512 firstSimd = _mm512_mask_loadu_ps( zeroSimd, mask, first ); + const __m512 secondSimd = _mm512_mask_loadu_ps( zeroSimd, mask, second ); + _mm512_mask_storeu_ps( result, mask, _mm512_add_ps( firstSimd, secondSimd ) ); + } +} + +void vectorAddValue( const float* first, float* result, int vectorSize, float value ) +{ + const __m512 valueSimd = _mm512_set1_ps( value ); +#ifdef AVX512_64FLOATS + while( vectorSize >= 4 * AvxBlockSize ) { + AVX512_LOAD_64_FLOATS( first, first ); + first0 = _mm512_add_ps( first0, valueSimd ); + first1 = _mm512_add_ps( first1, valueSimd ); + first2 = _mm512_add_ps( first2, valueSimd ); + first3 = _mm512_add_ps( first3, valueSimd ); + AVX512_STORE_64_FLOATS( first, result ); + first += 4 * AvxBlockSize; + result += 4 * AvxBlockSize; + vectorSize -= 4 * AvxBlockSize; + } +#endif //AVX512_64FLOATS + + while( vectorSize >= AvxBlockSize ) { + _mm512_storeu_ps( result, + _mm512_add_ps( _mm512_loadu_ps( first ), valueSimd ) ); + first += AvxBlockSize; + result += AvxBlockSize; + vectorSize -= AvxBlockSize; + } + + if( vectorSize > 0 ) { + const __mmask16 mask = AVX512_IO_MASK( vectorSize ); + _mm512_mask_storeu_ps( result, mask, + _mm512_add_ps( _mm512_mask_loadu_ps( _mm512_setzero_ps(), mask, first ), valueSimd ) ); + } +} + +void vectorMultiply( const float* first, float* result, int vectorSize, float multiplier ) +{ + const __m512 multSimd = _mm512_set1_ps( multiplier ); +#ifdef AVX512_64FLOATS + while( vectorSize >= 4 * AvxBlockSize ) { + AVX512_LOAD_64_FLOATS( first, first ); + first0 = _mm512_mul_ps( first0, multSimd ); + first1 = _mm512_mul_ps( first1, multSimd ); + first2 = _mm512_mul_ps( first2, multSimd ); + first3 = _mm512_mul_ps( first3, multSimd ); + AVX512_STORE_64_FLOATS( first, result ); + first += 4 * AvxBlockSize; + result += 4 * AvxBlockSize; + vectorSize -= 4 * AvxBlockSize; + } +#endif //AVX512_64FLOATS + + while( vectorSize >= AvxBlockSize ) { + _mm512_storeu_ps( result, + _mm512_mul_ps( _mm512_loadu_ps( first ), multSimd ) ); + first += AvxBlockSize; + result += AvxBlockSize; + vectorSize -= AvxBlockSize; + } + + if( vectorSize > 0 ) { + const __mmask16 mask = AVX512_IO_MASK( vectorSize ); + _mm512_mask_storeu_ps( result, mask, + _mm512_mul_ps( _mm512_mask_loadu_ps( _mm512_setzero_ps(), mask, first ), multSimd ) ); + } +} + +void vectorEltwiseMultiply( const float* first, const float* second, float* result, int vectorSize ) +{ +#ifdef AVX512_64FLOATS + while( vectorSize >= 4 * AvxBlockSize ) { + AVX512_LOAD_64_FLOATS( first, first ); + AVX512_LOAD_64_FLOATS( second, second ); + first0 = _mm512_mul_ps( first0, second0 ); + first1 = _mm512_mul_ps( first1, second1 ); + first2 = _mm512_mul_ps( first2, second2 ); + first3 = _mm512_mul_ps( first3, second3 ); + AVX512_STORE_64_FLOATS( first, result ); + first += 4 * AvxBlockSize; + second += 4 * AvxBlockSize; + result += 4 * AvxBlockSize; + vectorSize -= 4 * AvxBlockSize; + } +#endif //AVX512_64FLOATS + + while( vectorSize >= AvxBlockSize ) { + _mm512_storeu_ps( result, + _mm512_mul_ps( _mm512_loadu_ps( first ), _mm512_loadu_ps( second ) ) ); + first += AvxBlockSize; + second += AvxBlockSize; + result += AvxBlockSize; + vectorSize -= AvxBlockSize; + } + + if( vectorSize > 0 ) { + const __m512 zeroSimd = _mm512_setzero_ps(); // copy data from here, where mask bits are 
false + const __mmask16 mask = AVX512_IO_MASK( vectorSize ); + + const __m512 firstSimd = _mm512_mask_loadu_ps( zeroSimd, mask, first ); + const __m512 secondSimd = _mm512_mask_loadu_ps( zeroSimd, mask, second ); + _mm512_mask_storeu_ps( result, mask, _mm512_mul_ps( firstSimd, secondSimd ) ); + } +} + +void vectorEltwiseMultiplyAdd( const float* first, const float* second, float* result, int vectorSize ) +{ +#ifdef AVX512_64FLOATS + while( vectorSize >= 4 * AvxBlockSize ) { + AVX512_LOAD_64_FLOATS( first, first ); + AVX512_LOAD_64_FLOATS( second, second ); + AVX512_LOAD_64_FLOATS( result, result ); + result0 = _mm512_fmadd_ps( first0, second0, result0 ); + result1 = _mm512_fmadd_ps( first1, second1, result1 ); + result2 = _mm512_fmadd_ps( first2, second2, result2 ); + result3 = _mm512_fmadd_ps( first3, second3, result3 ); + AVX512_STORE_64_FLOATS( result, result ); + first += 4 * AvxBlockSize; + second += 4 * AvxBlockSize; + result += 4 * AvxBlockSize; + vectorSize -= 4 * AvxBlockSize; + } +#endif //AVX512_64FLOATS + + while( vectorSize >= AvxBlockSize ) { + _mm512_storeu_ps( result, + _mm512_fmadd_ps( _mm512_loadu_ps( first ), _mm512_loadu_ps( second ), _mm512_loadu_ps( result ) ) ); + first += AvxBlockSize; + second += AvxBlockSize; + result += AvxBlockSize; + vectorSize -= AvxBlockSize; + } + + if( vectorSize > 0 ) { + const __m512 zeroSimd = _mm512_setzero_ps(); // copy data from here, where mask bits are false + const __mmask16 mask = AVX512_IO_MASK( vectorSize ); + + const __m512 firstSimd = _mm512_mask_loadu_ps( zeroSimd, mask, first ); + const __m512 secondSimd = _mm512_mask_loadu_ps( zeroSimd, mask, second ); + const __m512 resultSimd = _mm512_mask_loadu_ps( zeroSimd, mask, result ); + _mm512_mask_storeu_ps( result, mask, _mm512_fmadd_ps( firstSimd, secondSimd, resultSimd ) ); + } +} + +void vectorReLU( const float* first, float* result, int vectorSize ) +{ + const __m512 zeroSimd = _mm512_setzero_ps(); +#ifdef AVX512_64FLOATS + while( vectorSize >= 4 * AvxBlockSize ) { + AVX512_LOAD_64_FLOATS( first, first ); + __m512 result0 = _mm512_max_ps( first0, zeroSimd ); + __m512 result1 = _mm512_max_ps( first1, zeroSimd ); + __m512 result2 = _mm512_max_ps( first2, zeroSimd ); + __m512 result3 = _mm512_max_ps( first3, zeroSimd ); + AVX512_STORE_64_FLOATS( result, result ); + first += 4 * AvxBlockSize; + result += 4 * AvxBlockSize; + vectorSize -= 4 * AvxBlockSize; + } +#endif //AVX512_64FLOATS + + while( vectorSize >= AvxBlockSize ) { + _mm512_storeu_ps( result, + _mm512_max_ps( _mm512_loadu_ps( first ), zeroSimd ) ); + first += AvxBlockSize; + result += AvxBlockSize; + vectorSize -= AvxBlockSize; + } + + if( vectorSize > 0 ) { + const __mmask16 mask = AVX512_IO_MASK( vectorSize ); + _mm512_mask_storeu_ps( result, mask, + _mm512_max_ps( _mm512_mask_loadu_ps( zeroSimd, mask, first ), zeroSimd ) ); + } +} + +void vectorReLU( const float* first, float* result, int vectorSize, float threshold ) +{ + const __m512 zeroSimd = _mm512_setzero_ps(); + const __m512 thresholdSimd = _mm512_set1_ps( threshold ); +#ifdef AVX512_64FLOATS + while( vectorSize >= 4 * AvxBlockSize ) { + AVX512_LOAD_64_FLOATS( first, first ); + __m512 result0 = _mm512_min_ps( _mm512_max_ps( first0, zeroSimd ), thresholdSimd ); + __m512 result1 = _mm512_min_ps( _mm512_max_ps( first1, zeroSimd ), thresholdSimd ); + __m512 result2 = _mm512_min_ps( _mm512_max_ps( first2, zeroSimd ), thresholdSimd ); + __m512 result3 = _mm512_min_ps( _mm512_max_ps( first3, zeroSimd ), thresholdSimd ); + AVX512_STORE_64_FLOATS( result, result ); + 
first += 4 * AvxBlockSize; + result += 4 * AvxBlockSize; + vectorSize -= 4 * AvxBlockSize; + } +#endif //AVX512_64FLOATS + + while( vectorSize >= AvxBlockSize ) { + _mm512_storeu_ps( result, + _mm512_min_ps( _mm512_max_ps( _mm512_loadu_ps( first ), zeroSimd ), thresholdSimd ) ); + first += AvxBlockSize; + result += AvxBlockSize; + vectorSize -= AvxBlockSize; + } + + if( vectorSize > 0 ) { + const __mmask16 mask = AVX512_IO_MASK( vectorSize ); + const __m512 firstSimd = _mm512_mask_loadu_ps( zeroSimd, mask, first ); + _mm512_mask_storeu_ps( result, mask, _mm512_min_ps( _mm512_max_ps( firstSimd, zeroSimd ), thresholdSimd ) ); + } +} + +void vectorHSwish( const float* first, float* result, int vectorSize ) +{ + const __m512 zeroSimd = _mm512_setzero_ps(); + const __m512 threeSimd = _mm512_set1_ps( 3.f ); + const __m512 oneSixthSimd = _mm512_set1_ps( 1.f / 6.f ); + + //for( int i = 0; i < vectorSize; ++i ) { + // if( *first <= -3 ) *result = ( 0 ); + // else if( *first >= 3 ) *result = ( *first ); + // else *result = ( *first * 1 / 6 ) * ( *first + 3 ); + // ++result; + // ++first; + //} + +#ifdef AVX512_64FLOATS + while( vectorSize >= 4 * AvxBlockSize ) { + AVX512_LOAD_64_FLOATS( data, first ); + __m512 middlePart0 = _mm512_max_ps( _mm512_add_ps( data0, threeSimd ), zeroSimd ); + __m512 middlePart1 = _mm512_max_ps( _mm512_add_ps( data1, threeSimd ), zeroSimd ); + __m512 middlePart2 = _mm512_max_ps( _mm512_add_ps( data2, threeSimd ), zeroSimd ); + __m512 middlePart3 = _mm512_max_ps( _mm512_add_ps( data3, threeSimd ), zeroSimd ); + + middlePart0 = _mm512_mul_ps( _mm512_mul_ps( data0, oneSixthSimd ), middlePart0 ); + middlePart1 = _mm512_mul_ps( _mm512_mul_ps( data1, oneSixthSimd ), middlePart1 ); + middlePart2 = _mm512_mul_ps( _mm512_mul_ps( data2, oneSixthSimd ), middlePart2 ); + middlePart3 = _mm512_mul_ps( _mm512_mul_ps( data3, oneSixthSimd ), middlePart3 ); + + data0 = _mm512_min_ps( _mm512_max_ps( data0, threeSimd ), middlePart0 ); + data1 = _mm512_min_ps( _mm512_max_ps( data1, threeSimd ), middlePart1 ); + data2 = _mm512_min_ps( _mm512_max_ps( data2, threeSimd ), middlePart2 ); + data3 = _mm512_min_ps( _mm512_max_ps( data3, threeSimd ), middlePart3 ); + AVX512_STORE_64_FLOATS( data, result ); + first += 4 * AvxBlockSize; + result += 4 * AvxBlockSize; + vectorSize -= 4 * AvxBlockSize; + } +#endif //AVX512_64FLOATS + + while( vectorSize >= AvxBlockSize ) { + __m512 firstSimd = _mm512_loadu_ps( first ); + __m512 middlePart = _mm512_max_ps( _mm512_add_ps( firstSimd, threeSimd ), zeroSimd ); + middlePart = _mm512_mul_ps( _mm512_mul_ps( firstSimd, oneSixthSimd ), middlePart ); + _mm512_storeu_ps( result, _mm512_min_ps( middlePart, _mm512_max_ps( firstSimd, threeSimd ) ) ); + + first += AvxBlockSize; + result += AvxBlockSize; + vectorSize -= AvxBlockSize; + } + + if( vectorSize > 0 ) { + const __mmask16 mask = AVX512_IO_MASK( vectorSize ); + __m512 firstSimd = _mm512_mask_loadu_ps( zeroSimd, mask, first ); + __m512 middlePart = _mm512_max_ps( _mm512_add_ps( firstSimd, threeSimd ), zeroSimd ); + middlePart = _mm512_mul_ps( _mm512_mul_ps( firstSimd, oneSixthSimd ), middlePart ); + _mm512_mask_storeu_ps( result, mask, + _mm512_min_ps( middlePart, _mm512_max_ps( firstSimd, threeSimd ) ) ); + } +} + +} // namespace Avx512 + +} // namespace NeoML + +#endif // NEOML_USE_SSE diff --git a/NeoMathEngine/test/src/inference/CMakeLists.txt b/NeoMathEngine/test/src/inference/CMakeLists.txt index 50b89ac88..e3ca20d7c 100644 --- a/NeoMathEngine/test/src/inference/CMakeLists.txt +++ 
b/NeoMathEngine/test/src/inference/CMakeLists.txt
@@ -108,6 +108,8 @@ target_sources(${PROJECT_NAME} INTERFACE
 	${CMAKE_CURRENT_SOURCE_DIR}/VectorTanhDiffOpTest.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/VectorTanhDiffTest.cpp
 	${CMAKE_CURRENT_SOURCE_DIR}/VectorTanhTest.cpp
+
+	${CMAKE_CURRENT_SOURCE_DIR}/VectorBenchmark.cpp
 )
 
 target_include_directories(${PROJECT_NAME} INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
diff --git a/NeoMathEngine/test/src/inference/VectorBenchmark.cpp b/NeoMathEngine/test/src/inference/VectorBenchmark.cpp
new file mode 100644
index 000000000..73fa7f091
--- /dev/null
+++ b/NeoMathEngine/test/src/inference/VectorBenchmark.cpp
@@ -0,0 +1,246 @@
+/* Copyright © 2023-2024 ABBYY
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+	http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+--------------------------------------------------------------------------------------------------------------*/
+
+#include <TestFixture.h>
+#include <fstream>
+#include <memory>
+#include <vector>
+
+using namespace NeoML;
+using namespace NeoMLTest;
+
+namespace NeoMLTest {
+
+static const char* vectorFunctionsNames[]{
+	"VectorCopy",
+	"VectorFill",
+	"VectorAdd",
+	"VectorAddVal",
+	"VectorMultiply",
+	"VectorEltwiseMultiply",
+	"VectorEltwiseMultAdd",
+	"VectorReLU(0)",
+	"VectorReLU(Threshold)",
+	"VectorHSwish"
+};
+
+//------------------------------------------------------------------------------------------------------------
+
+class VectorBenchmarkParams final {
+public:
+	int Function = -2;
+	int TestCount = -1;
+	int VectorSize = -1;
+	IPerformanceCounters* Counters = nullptr;
+	std::ofstream FOut{};
+
+	VectorBenchmarkParams( int function, int testCount, int vectorSize,
+		const CInterval& valuesInterval, int seed );
+	~VectorBenchmarkParams() { delete Counters; }
+
+	void SetNextSeedForFunction( int function, int seed );
+
+	CFloatWrapper& GetInputBuffer() { return *inputBuf; }
+	CFloatWrapper& GetSecondBuffer() { return *secondBuf; }
+	CFloatWrapper& GetResultBuffer() { return *resultBuf; }
+	CFloatWrapper& GetZeroVal() { return *zeroBuf; }
+	CFloatWrapper& GetMulVal() { return *mulBuf; }
+
+private:
+	const CInterval& valuesInterval;
+	CRandom random;
+
+	std::vector<float> input;
+	std::vector<float> second;
+	std::vector<float> result;
+
+	std::unique_ptr<CFloatWrapper> inputBuf = nullptr;
+	std::unique_ptr<CFloatWrapper> secondBuf = nullptr;
+	std::unique_ptr<CFloatWrapper> resultBuf = nullptr;
+	std::unique_ptr<CFloatWrapper> zeroBuf = nullptr;
+	std::unique_ptr<CFloatWrapper> mulBuf = nullptr;
+};
+
+VectorBenchmarkParams::VectorBenchmarkParams( int function, int testCount, int vectorSize,
+	const CInterval& valuesInterval, int seed ) :
+	Function( function ),
+	TestCount( testCount ),
+	VectorSize( vectorSize ),
+	Counters( MathEngine().CreatePerformanceCounters() ),
+	FOut( std::ofstream( "VectorBenchmarkTest.csv", std::ios::app ) ),
+	valuesInterval( valuesInterval )
+{
+	FOut << "\n---------------------------" << std::endl;
+	input.resize( vectorSize );
+	second.resize( vectorSize );
+	result.resize( vectorSize );
+	float zero = 0;
+	zeroBuf.reset( new CFloatWrapper( MathEngine(), &zero, 1 ) );
+
+	SetNextSeedForFunction( function, seed );
+}
+
+void VectorBenchmarkParams::SetNextSeedForFunction( int
+void VectorBenchmarkParams::SetNextSeedForFunction( int function, int seed )
+{
+	Function = function;
+	random = CRandom( seed );
+	for( int i = 0; i < VectorSize; ++i ) {
+		input[i] = static_cast<float>( random.Uniform( valuesInterval.Begin, valuesInterval.End ) );
+		second[i] = static_cast<float>( random.Uniform( valuesInterval.Begin, valuesInterval.End ) );
+		result[i] = 0;
+	}
+	float multiplier = static_cast<float>( random.Uniform( 1, valuesInterval.End ) );
+	mulBuf.reset( new CFloatWrapper( MathEngine(), &multiplier, 1 ) );
+	CConstFloatHandle mulHandle = *mulBuf;
+	ASSERT_EXPR( mulHandle.GetValueAt( 0 ) > 0 );
+
+	inputBuf.reset( new CFloatWrapper( MathEngine(), input.data(), VectorSize ) );
+	secondBuf.reset( new CFloatWrapper( MathEngine(), second.data(), VectorSize ) );
+	resultBuf.reset( new CFloatWrapper( MathEngine(), result.data(), VectorSize ) );
+}
+
+//------------------------------------------------------------------------------------------------------------
+
+static double vectorBenchmark( VectorBenchmarkParams& params )
+{
+	CFloatWrapper& input = params.GetInputBuffer();
+	CFloatWrapper& second = params.GetSecondBuffer();
+	CFloatWrapper& result = params.GetResultBuffer();
+	CFloatWrapper& zeroVal = params.GetZeroVal();
+	CFloatWrapper& mulVal = params.GetMulVal();
+	const int vectorSize = params.VectorSize;
+
+	if( params.Function == -1 ) { // warm-up
+		MathEngine().VectorCopy( result, input, vectorSize );
+		MathEngine().VectorFill( result, vectorSize, mulVal );
+		MathEngine().VectorAdd( input, second, result, vectorSize );
+		MathEngine().VectorAddValue( input, result, vectorSize, mulVal );
+		MathEngine().VectorMultiply( input, second, vectorSize, mulVal );
+		MathEngine().VectorEltwiseMultiply( input, second, result, vectorSize );
+		MathEngine().VectorEltwiseMultiplyAdd( input, second, result, vectorSize );
+		MathEngine().VectorReLU( input, result, vectorSize, zeroVal ); //Threshold == 0
+		MathEngine().VectorReLU( input, result, vectorSize, mulVal ); //Threshold > 0
+		MathEngine().VectorHSwish( input, result, vectorSize );
+		return 0;
+	}
+
+	params.Counters->Synchronise();
+
+	for( int i = 0; i < params.TestCount; ++i ) {
+		switch( params.Function ) {
+			case 0: MathEngine().VectorCopy( result, input, vectorSize ); break;
+			case 1: MathEngine().VectorFill( result, vectorSize, mulVal ); break;
+			case 2: MathEngine().VectorAdd( input, second, result, vectorSize ); break;
+			case 3: MathEngine().VectorAddValue( input, result, vectorSize, mulVal ); break;
+			case 4: MathEngine().VectorMultiply( input, second, vectorSize, mulVal ); break;
+			case 5: MathEngine().VectorEltwiseMultiply( input, second, result, vectorSize ); break;
+			case 6: MathEngine().VectorEltwiseMultiplyAdd( input, second, result, vectorSize ); break;
+			case 7: MathEngine().VectorReLU( input, result, vectorSize, zeroVal ); break; //Threshold == 0
+			case 8: MathEngine().VectorReLU( input, result, vectorSize, mulVal ); break; //Threshold > 0
+			case 9: MathEngine().VectorHSwish( input, result, vectorSize ); break;
+			default:
+				ASSERT_EXPR( false );
+		}
+	}
+
+	params.Counters->Synchronise();
+	const double time = double( ( *params.Counters )[0].Value ) / 1000000 / params.TestCount; // average time in milliseconds
+	params.FOut << time << ",";
+	return time;
+}
+
+} // namespace NeoMLTest
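
The harness above reduces to one timing idiom: a pair of Counters->Synchronise() calls bracketing the hot loop, with counter [0] averaged over the iteration count. A standalone sketch of that pattern follows (not part of the patch; it assumes the NeoMLTest environment, where MathEngine() returns the engine under test, and that counter [0] reports elapsed time in nanoseconds, as the division by 1000000 above implies):

// Sketch of the measurement pattern used by vectorBenchmark().
#include <TestFixture.h>
#include <functional>
#include <memory>

static double measureAverageMs( int runCount, const std::function<void()>& op )
{
	std::unique_ptr<NeoML::IPerformanceCounters> counters(
		NeoMLTest::MathEngine().CreatePerformanceCounters() );
	counters->Synchronise(); // first call starts the measured interval
	for( int i = 0; i < runCount; ++i ) {
		op(); // the primitive being benchmarked
	}
	counters->Synchronise(); // second call captures the elapsed counters
	return double( ( *counters )[0].Value ) / 1000000 / runCount; // ns -> ms, averaged per run
}

A call such as measureAverageMs( 10000, [&]{ MathEngine().VectorHSwish( input, result, size ); } ) would reproduce one cell of the CSV that vectorBenchmark() writes.
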
+
+//------------------------------------------------------------------------------------------------------------
+
+class CMathEngineVectorBenchmarkTest : public CTestFixtureWithParams {
+};
+
+INSTANTIATE_TEST_CASE_P( CMathEngineVectorBenchmarkTestInstantiation, CMathEngineVectorBenchmarkTest,
+	::testing::Values(
+		CTestParams(
+			"TestCount = 10000;"
+			"RepeatCount = 10;"
+			"VectorSize = 16;"
+			"VectorValues = (-128..128);"
+		),
+		CTestParams(
+			"TestCount = 10000;"
+			"RepeatCount = 10;"
+			"VectorSize = 25;"
+			"VectorValues = (-128..128);"
+		),
+		CTestParams(
+			"TestCount = 10000;"
+			"RepeatCount = 10;"
+			"VectorSize = 32;"
+			"VectorValues = (-128..128);"
+		),
+		CTestParams(
+			"TestCount = 10000;"
+			"RepeatCount = 10;"
+			"VectorSize = 64;"
+			"VectorValues = (-128..128);"
+		),
+		CTestParams(
+			"TestCount = 10000;"
+			"RepeatCount = 10;"
+			"VectorSize = 69;"
+			"VectorValues = (-128..128);"
+		),
+		CTestParams(
+			"TestCount = 10000;"
+			"RepeatCount = 10;"
+			"VectorSize = 100000;"
+			"VectorValues = (-50..50);"
+		),
+		CTestParams(
+			"TestCount = 10000;"
+			"RepeatCount = 10;"
+			"VectorSize = 999989;"
+			"VectorValues = (-10..10);"
+		),
+		CTestParams(
+			"TestCount = 1000;"
+			"RepeatCount = 10;"
+			"VectorSize = 1179648;"
+			"VectorValues = (-1..1);"
+		)
+	)
+);
+
+TEST_P( CMathEngineVectorBenchmarkTest, DISABLED_Random )
+{
+	CTestParams testParams = GetParam();
+	const int testCount = testParams.GetValue<int>( "TestCount" );
+	const int repeatCount = testParams.GetValue<int>( "RepeatCount" );
+	const int vectorSize = testParams.GetValue<int>( "VectorSize" );
+	const CInterval valuesInterval = testParams.GetInterval( "VectorValues" );
+
+	VectorBenchmarkParams params( /*warm-up*/-1, testCount, vectorSize, valuesInterval, 282 );
+	vectorBenchmark( params );
+
+	for( int function = 0; function < 10; ++function ) {
+		params.FOut << std::endl << vectorFunctionsNames[function] << ",";
+
+		double timeSum = 0;
+		for( int test = 0; test < repeatCount; ++test ) {
+			const int seed = 282 + test * 10000 + test % 3;
+			params.SetNextSeedForFunction( function, seed );
+			timeSum += vectorBenchmark( params );
+		}
+		GTEST_LOG_( INFO ) << vectorFunctionsNames[function] << "\t" << timeSum;
+	}
+}
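
Two practical notes on running the benchmark. GoogleTest skips any test whose name starts with DISABLED_, so this fixture stays out of regular test passes; it executes only when the test binary is launched with --gtest_also_run_disabled_tests (optionally narrowed with --gtest_filter='*VectorBenchmark*'). Results then land in two places: VectorBenchmarkTest.csv in the working directory accumulates one comma-separated row of average times per primitive (the stream is opened with std::ios::app, so repeated runs append rather than overwrite), while GTEST_LOG_( INFO ) echoes the summed per-function times to the console.
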
diff --git a/NeoMathEngine/test/src/learn/SubVectorFromMatrixColumnsTest.cpp b/NeoMathEngine/test/src/learn/SubVectorFromMatrixColumnsTest.cpp
index a79853452..0991d6d1c 100644
--- a/NeoMathEngine/test/src/learn/SubVectorFromMatrixColumnsTest.cpp
+++ b/NeoMathEngine/test/src/learn/SubVectorFromMatrixColumnsTest.cpp
@@ -1,4 +1,4 @@
-/* Copyright © 2017-2020 ABBYY Production LLC
+/* Copyright © 2017-2024 ABBYY
 
 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
@@ -31,10 +31,10 @@ static void subVectorFromMatrixColumnsTestImpl( const CTestParams& params, int s
 	CREATE_FILL_FLOAT_ARRAY( getMatrix, valuesInterval.Begin, valuesInterval.End, height * width, random )
 	CREATE_FILL_FLOAT_ARRAY( vector, valuesInterval.Begin, valuesInterval.End, height, random )
 
-	std::vector<float> expectedMatrix;
-	expectedMatrix = getMatrix;
+	std::vector<float> expectedMatrix = getMatrix; // copy vector
 
-	MathEngine().SubVectorFromMatrixColumns( CARRAY_FLOAT_WRAPPER( expectedMatrix ), CARRAY_FLOAT_WRAPPER( getMatrix ), height, width, CARRAY_FLOAT_WRAPPER( vector ) );
+	MathEngine().SubVectorFromMatrixColumns( CARRAY_FLOAT_WRAPPER( expectedMatrix ),
+		CARRAY_FLOAT_WRAPPER( getMatrix ), height, width, CARRAY_FLOAT_WRAPPER( vector ) );
 
 	for( int h = 0; h < height; ++h ) {
 		for( int w = 0; w < width; ++w ) {
@@ -47,7 +47,6 @@ static void subVectorFromMatrixColumnsTestImpl( const CTestParams& params, int s
 	}
 }
 
-//---------------------------------------------------------------------------------------------------------------------
 
 class CSubVectorFromMatrixColumnsTest : public CTestFixtureWithParams {
@@ -58,19 +57,13 @@ INSTANTIATE_TEST_CASE_P( CSubVectorFromMatrixColumnsTestInstantiation, CSubVecto
 	CTestParams(
 		"Height = (1..50);"
 		"Width = (1..50);"
-		"BatchSize = (1..5);"
-		"VectorSize = (1..20);"
 		"Values = (-1..1);"
-		"Channels = (1..5);"
 		"TestCount = 100;"
 	),
 	CTestParams(
 		"Height = (100..500);"
 		"Width = (100..500);"
-		"BatchSize = (1..5);"
-		"VectorSize = (30..50);"
 		"Values = (-1..1);"
-		"Channels = (1..5);"
 		"TestCount = 5;"
 	)
 )
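
The hunk context elides the verification loop, so for orientation: the test fills a height x width matrix and a vector of length height, then checks the primitive against plain column-wise subtraction. A reference sketch of those semantics (the helper name is hypothetical, and the behavior is inferred from the shapes used in the test rather than taken from the library):

// Reference semantics the test asserts: the vector (of length height) is
// subtracted from every column of the matrix, i.e. vector[h] is subtracted
// from each element of row h.
static void subVectorFromMatrixColumnsRef( const float* matrix, int height, int width,
	const float* vector, float* result )
{
	for( int h = 0; h < height; ++h ) {
		for( int w = 0; w < width; ++w ) {
			result[h * width + w] = matrix[h * width + w] - vector[h];
		}
	}
}
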