diff --git a/NeoMathEngine/include/NeoMathEngine/SimdMathEngine.h b/NeoMathEngine/include/NeoMathEngine/SimdMathEngine.h index 22bdbfc1f..30909b19b 100644 --- a/NeoMathEngine/include/NeoMathEngine/SimdMathEngine.h +++ b/NeoMathEngine/include/NeoMathEngine/SimdMathEngine.h @@ -48,6 +48,33 @@ class ISimdMathEngine : public CCrtAllocatedObject { virtual void Exp( float* dst, const float* src, size_t dataSize, bool isMultithread = true ) = 0; virtual void RunOnceRestOfLstm( CMathEngineLstmDesc* desc, const CConstFloatHandle& inputStateBackLink, const CFloatHandle& outputStateBackLink, const CFloatHandle& outputMainBackLink, bool isMultithread = true ) = 0; + + using vectorAddFunc = void (*)( const float* first, const float* second, float* result, int vectorSize ); + using alignedVectorAdd = void (*)( const float* first, float* second, int vectorSize ); + using vectorEltwiseMax = void (*)( const float* first, const float* second, float* result, int vectorSize ); + using vectorReLU = void (*)( const float* first, float* result, int vectorSize ); + using vectorReLUTreshold = void (*)( const float* first, float* result, int vectorSize, float threshold ); + using alignedVectorMultiplyAndAdd = void (*)( const float* first, const float* second, + float* result, int vectorSize, const float* mult ); + using vectorMultiply = void (*)( const float* first, float multiplier, float* result, int vectorSize ); + using vectorEltwiseMultiply = void (*)( const float* first, const float* second, float* result, int vectorSize ); + using vectorEltwiseMultiplyAdd = void (*)( const float* first, const float* second, float* result, int vectorSize ); + using vectorAddValue = void (*)( const float* first, float value, float* result, int vectorSize ); + using vectorDotProduct = void (*)( const float* first, const float* second, float* result, int vectorSize ); + using vectorMinMax = void (*)( const float* first, float* result, int vectorSize, const float minValue, const float maxValue ); + + virtual vectorAddFunc GetVectorAddFunc() = 0; + virtual alignedVectorAdd GetAlignedVectorAddFunc() = 0; + virtual vectorEltwiseMax GetVectorMaxFunc() = 0; + virtual vectorReLU GetVectorReLUFunc() = 0; + virtual vectorReLUTreshold GetVectorReLUTresholdFunc() = 0; + virtual alignedVectorMultiplyAndAdd GetAlignedVectorMultiplyAndAddFunc() = 0; + virtual vectorMultiply GetVectorMultiplyFunc() = 0; + virtual vectorEltwiseMultiply GetVectorEltwiseMultiplyFunc() = 0; + virtual vectorEltwiseMultiplyAdd GetVectorEltwiseMultiplyAddFunc() = 0; + virtual vectorAddValue GetVectorAddValueFunc() = 0; + virtual vectorDotProduct GetVectorDotProductFunc() = 0; + virtual vectorMinMax GetVectorMinMaxFunc() = 0; }; } diff --git a/NeoMathEngine/src/CPU/CpuMathEngine.cpp b/NeoMathEngine/src/CPU/CpuMathEngine.cpp index 2e3227e8e..09b349169 100644 --- a/NeoMathEngine/src/CPU/CpuMathEngine.cpp +++ b/NeoMathEngine/src/CPU/CpuMathEngine.cpp @@ -24,6 +24,7 @@ limitations under the License. 
#include #include #include +#include #if FINE_PLATFORM( FINE_ANDROID ) || FINE_PLATFORM( FINE_LINUX ) #include @@ -78,6 +79,33 @@ CCpuMathEngine::CCpuMathEngine( int _threadCount, size_t _memoryLimit ) : #ifdef NEOML_USE_MKL vmlSetMode( VML_ERRMODE_NOERR ); #endif + if( simdMathEngine != nullptr ) { + vectorAdd = simdMathEngine->GetVectorAddFunc(); + alignedVectorAdd = simdMathEngine->GetAlignedVectorAddFunc(); + vectorEltwiseMax = simdMathEngine->GetVectorMaxFunc(); + vectorReLU = simdMathEngine->GetVectorReLUFunc(); + vectorReLUTreshold = simdMathEngine->GetVectorReLUTresholdFunc(); + alignedVectorMultiplyAndAdd = simdMathEngine->GetAlignedVectorMultiplyAndAddFunc(); + vectorMultiply = simdMathEngine->GetVectorMultiplyFunc(); + vectorEltwiseMultiply = simdMathEngine->GetVectorEltwiseMultiplyFunc(); + vectorEltwiseMultiplyAdd = simdMathEngine->GetVectorEltwiseMultiplyAddFunc(); + vectorAddValue = simdMathEngine->GetVectorAddValueFunc(); + vectorDotProduct = simdMathEngine->GetVectorDotProductFunc(); + vectorMinMax = simdMathEngine->GetVectorMinMaxFunc(); + } else { + vectorAdd = &NeoML::vectorAdd; + alignedVectorAdd = &NeoML::alignedVectorAdd; + vectorEltwiseMax = &NeoML::vectorEltwiseMax; + vectorReLU = &NeoML::vectorReLU; + vectorReLUTreshold = &NeoML::vectorReLUTreshold; + alignedVectorMultiplyAndAdd = &NeoML::alignedVectorMultiplyAndAdd; + vectorMultiply = &NeoML::vectorMultiply; + vectorEltwiseMultiply = &NeoML::vectorEltwiseMultiply; + vectorEltwiseMultiplyAdd = &NeoML::vectorEltwiseMultiplyAdd; + vectorAddValue = &NeoML::vectorAddValue; + vectorDotProduct = &NeoML::vectorDotProduct; + vectorMinMax = &NeoML::vectorMinMax; + } } CCpuMathEngine::~CCpuMathEngine() diff --git a/NeoMathEngine/src/CPU/CpuMathEngine.h b/NeoMathEngine/src/CPU/CpuMathEngine.h index 7f96278ab..3e4f81a43 100644 --- a/NeoMathEngine/src/CPU/CpuMathEngine.h +++ b/NeoMathEngine/src/CPU/CpuMathEngine.h @@ -558,6 +558,20 @@ class CCpuMathEngine : public IMathEngine, public IRawMemoryManager { std::unique_ptr<ISimdMathEngine> simdMathEngine; // interface for using simd instructions SgemmFunc customSgemmFunction; // Used when it is available and is faster than the default sgemm + void ( *vectorAdd )( const float* first, const float* second, float* result, int vectorSize ); + void ( *alignedVectorAdd )( const float* first, float* second, int vectorSize ); + void ( *vectorEltwiseMax )( const float* first, const float* second, float* result, int vectorSize ); + void ( *vectorReLU )( const float* first, float* result, int vectorSize ); + void ( *vectorReLUTreshold )( const float* first, float* result, int vectorSize, float threshold ); + void ( *alignedVectorMultiplyAndAdd )( const float* first, const float* second, + float* result, int vectorSize, const float* mult ); + void ( *vectorMultiply )( const float* first, float multiplier, float* result, int vectorSize ); + void ( *vectorEltwiseMultiply )( const float* first, const float* second, float* result, int vectorSize ); + void ( *vectorEltwiseMultiplyAdd )( const float* first, const float* second, float* result, int vectorSize ); + void ( *vectorAddValue )( const float* first, float value, float* result, int vectorSize ); + void ( *vectorDotProduct )( const float* first, const float* second, float* result, int vectorSize ); + void ( *vectorMinMax )( const float* first, float* result, int vectorSize, const float minValue, const float maxValue ); + IMathEngine& mathEngine() { IMathEngine* engine = this; return *engine; } void blob3dConvolution1x1x1( const CBlobDesc& source, const CBlobDesc&
filter, const CBlobDesc& result, diff --git a/NeoMathEngine/src/CPU/CpuMathEngineBlas.cpp b/NeoMathEngine/src/CPU/CpuMathEngineBlas.cpp index f05070f41..20711a80f 100644 --- a/NeoMathEngine/src/CPU/CpuMathEngineBlas.cpp +++ b/NeoMathEngine/src/CPU/CpuMathEngineBlas.cpp @@ -219,7 +219,7 @@ void CCpuMathEngine::AddVectorToMatrixColumns(const CConstFloatHandle& matrixHan const float* vector = GetRaw( vectorHandle ); for(int i = 0; i < matrixHeight; ++i) { - vectorAddValue(matrix, result, matrixWidth, *vector); + vectorAddValue(matrix, *vector, result, matrixWidth); matrix += matrixWidth; result += matrixWidth; ++vector; @@ -276,7 +276,7 @@ void CCpuMathEngine::RowMultiplyMatrixByMatrix(const CConstFloatHandle& firstHan float* result = GetRaw( resultHandle ); for(int i = 0; i < height; ++i) { - vectorDotProduct(first, second, width, result); + vectorDotProduct(first, second, result, width); first += width; second += width; ++result; @@ -819,7 +819,7 @@ void CCpuMathEngine::MultiplyDiagMatrixByMatrix( const CConstFloatHandle& firstH NEOML_OMP_FOR_NUM_THREADS( curThreadCount ) for( int i = 0; i < firstSize; i++ ) { const float multiplier = *( first + i ); - vectorMultiply( second + i * secondWidth, result + i * secondWidth, multiplier, secondWidth ); + vectorMultiply( second + i * secondWidth, multiplier, result + i * secondWidth, secondWidth ); } } diff --git a/NeoMathEngine/src/CPU/CpuMathEngineDnn3dConv.cpp b/NeoMathEngine/src/CPU/CpuMathEngineDnn3dConv.cpp index 24f7784b5..247f2d7f8 100644 --- a/NeoMathEngine/src/CPU/CpuMathEngineDnn3dConv.cpp +++ b/NeoMathEngine/src/CPU/CpuMathEngineDnn3dConv.cpp @@ -98,7 +98,7 @@ void CCpuMathEngine::blob3dConvolution1x1x1Backward( const CCommon3dConvolutionD for( int i = 0; i < resultBlob.Width(); ++i ) { float* inputDiffPixel = inputDiffCol; for( int k = 0; k < resultBlob.Depth(); ++k ) { - NeoML::vectorAdd( inputDiffPixel, resultData, inputDiffPixel, inputDiff.Channels() ); + vectorAdd( inputDiffPixel, resultData, inputDiffPixel, inputDiff.Channels() ); inputDiffPixel += inputDiff.Channels() * desc.StrideDepth; resultData += inputDiff.Channels(); } @@ -437,8 +437,8 @@ void CCpuMathEngine::blob3dConvolutionBackward( const CCommon3dConvolutionDesc& int outputLineCount; if( OmpGetTaskIndexAndCount( outputLineY, outputLineStart, outputLineCount ) ) { if( freeTermData == 0 ) { - vectorFill( resultData + outputLineStart * outputRowSize, - 0, outputLineCount * outputRowSize ); + vectorFill0( resultData + outputLineStart * outputRowSize, + outputLineCount * outputRowSize ); } int outputLineEnd = outputLineStart + outputLineCount; diff --git a/NeoMathEngine/src/CPU/CpuMathEngineDnnChannelwiseConv.cpp b/NeoMathEngine/src/CPU/CpuMathEngineDnnChannelwiseConv.cpp index 469d14efe..77cc2c7a9 100644 --- a/NeoMathEngine/src/CPU/CpuMathEngineDnnChannelwiseConv.cpp +++ b/NeoMathEngine/src/CPU/CpuMathEngineDnnChannelwiseConv.cpp @@ -166,7 +166,7 @@ void CCpuMathEngine::blobChannelwiseConvolutionFilter3x3Padding1Stride2( const C if( freeTerm != 0 ) { fillResultRow( desc, freeTerm, resultFirstRow ); } else { - NeoML::vectorFill( resultFirstRow, 0, resultDesc.Width() * channels ); + NeoML::vectorFill0( resultFirstRow, resultDesc.Width() * channels ); } processFilterRowStride2( desc, filter + filterRowSize, sourceFirstRow, resultFirstRow ); @@ -181,7 +181,7 @@ void CCpuMathEngine::blobChannelwiseConvolutionFilter3x3Padding1Stride2( const C if( freeTerm != 0 ) { fillResultRow( desc, freeTerm, resRow ); } else { - NeoML::vectorFill( resRow, 0, resultDesc.Width() * channels ); + 
NeoML::vectorFill0( resRow, resultDesc.Width() * channels ); } processFilterRowStride2( desc, filter, srcRow, resRow ); @@ -193,7 +193,7 @@ void CCpuMathEngine::blobChannelwiseConvolutionFilter3x3Padding1Stride2( const C if( freeTerm != 0 ) { fillResultRow( desc, freeTerm, resultLastRow ); } else { - NeoML::vectorFill( resultLastRow, 0, resultDesc.Width() * channels ); + NeoML::vectorFill0( resultLastRow, resultDesc.Width() * channels ); } processFilterRowStride2( desc, filter, sourceLastRow, resultLastRow ); @@ -263,7 +263,7 @@ void CCpuMathEngine::blobChannelwiseConvolutionFilter3x3Padding1Stride1( const C if( freeTerm != 0 ) { fillResultRow( desc, freeTerm, resultFirstRow ); } else { - NeoML::vectorFill( resultFirstRow, 0, resultDesc.Width() * channels ); + NeoML::vectorFill0( resultFirstRow, resultDesc.Width() * channels ); } processFilterRowStride1( desc, filter + filterRowSize, sourceFirstRow, resultFirstRow ); if( resultCount >= 0 ) { @@ -277,7 +277,7 @@ void CCpuMathEngine::blobChannelwiseConvolutionFilter3x3Padding1Stride1( const C if( freeTerm != 0 ) { fillResultRow( desc, freeTerm, resRow ); } else { - NeoML::vectorFill( resRow, 0, resultDesc.Width() * channels ); + NeoML::vectorFill0( resRow, resultDesc.Width() * channels ); } processFilterRowStride1( desc, filter, srcRow, resRow ); @@ -289,7 +289,7 @@ void CCpuMathEngine::blobChannelwiseConvolutionFilter3x3Padding1Stride1( const C if( freeTerm != 0 ) { fillResultRow( desc, freeTerm, resultLastRow ); } else { - NeoML::vectorFill( resultLastRow, 0, resultDesc.Width() * channels ); + NeoML::vectorFill0( resultLastRow, resultDesc.Width() * channels ); } processFilterRowStride1( desc, filter, sourceLastRow, resultLastRow ); @@ -366,7 +366,7 @@ void CCpuMathEngine::BlobChannelwiseConvolution( const CChannelwiseConvolutionDe NeoML::dataCopy(rowStart, freeTerm, channels); } } else { - NeoML::vectorFill(resultRow, 0, resultDesc.Width() * channels); + NeoML::vectorFill0(resultRow, resultDesc.Width() * channels); } const int filterFirstRow = max( 0, -firstFilteredRow ); diff --git a/NeoMathEngine/src/CPU/CpuMathEngineDnnConv.cpp b/NeoMathEngine/src/CPU/CpuMathEngineDnnConv.cpp index 905f05f19..4d1f6036f 100644 --- a/NeoMathEngine/src/CPU/CpuMathEngineDnnConv.cpp +++ b/NeoMathEngine/src/CPU/CpuMathEngineDnnConv.cpp @@ -355,7 +355,7 @@ void CCpuMathEngine::fillTempData( const float* sourceData, float* tempData, con for( int h = 0; h < desc.Filter.Height(); h++ ) { if( 0 <= sourceHeight + h * desc.DilationHeight && sourceHeight + h * desc.DilationHeight < desc.Source.Height() ) { if( startPaddingSize > 0 ) { - NeoML::vectorFill( tempStartPaddingPtr, 0.0, startPaddingSize * channelsCount ); + NeoML::vectorFill0( tempStartPaddingPtr, startPaddingSize * channelsCount ); } if( desc.DilationWidth == 1 ) { @@ -369,10 +369,10 @@ void CCpuMathEngine::fillTempData( const float* sourceData, float* tempData, con } if( endPaddingSize > 0 ) { - NeoML::vectorFill( tempEndPaddingPtr, 0.0, endPaddingSize * channelsCount ); + NeoML::vectorFill0( tempEndPaddingPtr, endPaddingSize * channelsCount ); } } else { - NeoML::vectorFill( tempStartPaddingPtr, 0.0, filterLineSize ); + NeoML::vectorFill0( tempStartPaddingPtr, filterLineSize ); } tempStartPaddingPtr += filterLineSize; @@ -577,7 +577,7 @@ void CCpuMathEngine::backwardConvolutionAddFilterToOutput( const CCpuConvolution // Set the free term setVectorToMatrixRows( outputDataPtr, output.Width(), output.Depth() * output.Channels(), freeTermDataRaw ); } else { - vectorFill( outputDataPtr, 0, output.Width() * 
output.Depth() * output.Channels() ); + vectorFill0( outputDataPtr, output.Width() * output.Depth() * output.Channels() ); } int batch = step / output.Height(); diff --git a/NeoMathEngine/src/CPU/CpuMathEngineDnnLstm.h b/NeoMathEngine/src/CPU/CpuMathEngineDnnLstm.h index 65c427cd5..67a89700b 100644 --- a/NeoMathEngine/src/CPU/CpuMathEngineDnnLstm.h +++ b/NeoMathEngine/src/CPU/CpuMathEngineDnnLstm.h @@ -105,19 +105,19 @@ inline void CMathEngineLstmDesc::RunOnceRestOfLstm( const CConstFloatHandle& inp NeoML::vectorSigmoid( GetRaw( outputData ), GetRaw( outputData ), CurDataSize ); // Multiply input gates - NeoML::vectorEltwiseMultiply( GetRaw( inputData ), GetRaw( inputTanhData ), GetRaw( inputData ), CurDataSize ); + vectorEltwiseMultiply( GetRaw( inputData ), GetRaw( inputTanhData ), GetRaw( inputData ), CurDataSize ); // Multiply state backlink with forget gate - NeoML::vectorEltwiseMultiply( GetRaw( forgetData ), GetRaw( inputStateBackLink + OffsetBackLink ), GetRaw( forgetData ), CurDataSize ); + vectorEltwiseMultiply( GetRaw( forgetData ), GetRaw( inputStateBackLink + OffsetBackLink ), GetRaw( forgetData ), CurDataSize ); // Append input gate to state backlink - NeoML::vectorAdd( GetRaw( forgetData ), GetRaw( inputData ), GetRaw( outputStateBackLink + OffsetBackLink ), CurDataSize ); + vectorAdd( GetRaw( forgetData ), GetRaw( inputData ), GetRaw( outputStateBackLink + OffsetBackLink ), CurDataSize ); // Apply tanh to state backlink - NeoML::vectorTanh( GetRaw( outputStateBackLink + OffsetBackLink ), GetRaw( inputData ), CurDataSize ); + vectorTanh( GetRaw( outputStateBackLink + OffsetBackLink ), GetRaw( inputData ), CurDataSize ); // Multiply output gate with result of previous operation - NeoML::vectorEltwiseMultiply( GetRaw( outputData ), GetRaw( inputData ), GetRaw( outputMainBackLink + OffsetBackLink ), CurDataSize ); + vectorEltwiseMultiply( GetRaw( outputData ), GetRaw( inputData ), GetRaw( outputMainBackLink + OffsetBackLink ), CurDataSize ); } } } diff --git a/NeoMathEngine/src/CPU/CpuMathEngineDnnRleConv.cpp b/NeoMathEngine/src/CPU/CpuMathEngineDnnRleConv.cpp index 2907b56f2..268a4947c 100644 --- a/NeoMathEngine/src/CPU/CpuMathEngineDnnRleConv.cpp +++ b/NeoMathEngine/src/CPU/CpuMathEngineDnnRleConv.cpp @@ -115,7 +115,7 @@ static inline void updateFilterConv( IMathEngine& mathEngine, CCpuRleConvolution const float* convertFilterDataPtr = GetRaw( desc.ConvertedFilter.GetHandle() ); for( int j = 0; j < filterHeight; ++j ) { for( int i = 0; i < filterWidth; ++i ) { - alignedVectorAdd( zeroFilterConvPtr, convertFilterDataPtr, filterCount ); + alignedVectorAdd( convertFilterDataPtr, zeroFilterConvPtr, filterCount ); convertFilterDataPtr += filterCount; } zeroFilterConvPtr += filterCount; @@ -268,7 +268,7 @@ void CCpuMathEngine::BlobRleConvolution( const CRleConvolutionDesc& convDesc, co const float* curFilterConvData = filterConvData + index * filterConvStep; float* curOutput = output; for( int j = 0; j < jCount; ++j ) { - alignedVectorAdd( curOutput, curFilterConvData, filterCount ); + alignedVectorAdd( curFilterConvData, curOutput, filterCount ); curFilterConvData += strideHeight * filterCount; curOutput += outputRowSize; } @@ -382,7 +382,7 @@ void CCpuMathEngine::BlobRleConvolutionLearnAdd( const CRleConvolutionDesc& conv // Calculate diff separately for the free terms for( int j = 0; j < outputDiff.Height(); ++j ) { for( int k = 0; k < outputDiff.Width(); ++k ) { - alignedVectorAdd( freeTermDiffReductionPrivatePtr, outputDiffDataPtr, filterCount ); + alignedVectorAdd( 
outputDiffDataPtr, freeTermDiffReductionPrivatePtr, filterCount ); outputDiffDataPtr += filterCount; } } diff --git a/NeoMathEngine/src/CPU/CpuMathEngineVectorMath.cpp b/NeoMathEngine/src/CPU/CpuMathEngineVectorMath.cpp index 3bdff904e..c43e3f7a9 100644 --- a/NeoMathEngine/src/CPU/CpuMathEngineVectorMath.cpp +++ b/NeoMathEngine/src/CPU/CpuMathEngineVectorMath.cpp @@ -128,11 +128,11 @@ void CCpuMathEngine::VectorAdd(const CConstFloatHandle& firstHandle, const CCons NEOML_OMP_NUM_THREADS( curThreadCount ) { int index, count; if( OmpGetTaskIndexAndCount( vectorSize, 16, index, count ) ) { - NeoML::vectorAdd( GetRaw(firstHandle + index), GetRaw(secondHandle + index), GetRaw(resultHandle + index), count ); + vectorAdd( GetRaw(firstHandle + index), GetRaw(secondHandle + index), GetRaw(resultHandle + index), count ); } } } else { - NeoML::vectorAdd( GetRaw(firstHandle), GetRaw(secondHandle), GetRaw(resultHandle), vectorSize ); + vectorAdd( GetRaw(firstHandle), GetRaw(secondHandle), GetRaw(resultHandle), vectorSize ); } } @@ -352,7 +352,7 @@ void CCpuMathEngine::VectorAddValue(const CConstFloatHandle& firstHandle, const float* result = GetRaw( resultHandle ); float value = *GetRaw( addition ); - vectorAddValue( first, result, vectorSize, value ); + vectorAddValue( first, value, result, vectorSize ); } void CCpuMathEngine::VectorDotProduct(const CConstFloatHandle& firstHandle, const CConstFloatHandle& secondHandle, @@ -367,7 +367,7 @@ void CCpuMathEngine::VectorDotProduct(const CConstFloatHandle& firstHandle, cons const float* second = GetRaw( secondHandle ); float* result = GetRaw( resultHandle ); - vectorDotProduct( first, second, vectorSize, result ); + vectorDotProduct( first, second, result, vectorSize ); } void CCpuMathEngine::VectorTopK(const CConstFloatHandle& firstHandle, int firstSize, int k, const CFloatHandle& resultHandle, @@ -470,11 +470,11 @@ void CCpuMathEngine::VectorMultiply(const CConstFloatHandle& firstHandle, NEOML_OMP_NUM_THREADS( curThreadCount ) { int index, count; if( OmpGetTaskIndexAndCount( vectorSize, 16, index, count ) ) { - vectorMultiply( GetRaw( firstHandle + index ), GetRaw( resultHandle + index ), multiplier, count ); + vectorMultiply( GetRaw( firstHandle + index ), multiplier, GetRaw( resultHandle + index ), count ); } } } else { - vectorMultiply( GetRaw( firstHandle ), GetRaw( resultHandle ), multiplier, vectorSize ); + vectorMultiply( GetRaw( firstHandle ), multiplier, GetRaw( resultHandle ), vectorSize ); } } @@ -492,11 +492,11 @@ void CCpuMathEngine::VectorEltwiseMultiply(const CConstFloatHandle& firstHandle, NEOML_OMP_NUM_THREADS( curThreadCount ) { int index, count; if( OmpGetTaskIndexAndCount( vectorSize, 16, index, count ) ) { - NeoML::vectorEltwiseMultiply( GetRaw( firstHandle + index ), GetRaw( secondHandle + index ), GetRaw( resultHandle + index ), count ); + vectorEltwiseMultiply( GetRaw( firstHandle + index ), GetRaw( secondHandle + index ), GetRaw( resultHandle + index ), count ); } } } else { - NeoML::vectorEltwiseMultiply( GetRaw( firstHandle ), GetRaw( secondHandle ), GetRaw( resultHandle ), vectorSize ); + vectorEltwiseMultiply( GetRaw( firstHandle ), GetRaw( secondHandle ), GetRaw( resultHandle ), vectorSize ); } } @@ -656,11 +656,11 @@ void CCpuMathEngine::VectorMinMax(const CConstFloatHandle& firstHandle, const CF NEOML_OMP_NUM_THREADS( curThreadCount ) { int index, count; if( OmpGetTaskIndexAndCount( vectorSize, 16, index, count ) ) { - vectorMinMax( GetRaw(firstHandle + index), GetRaw(resultHandle + index), minValue, maxValue, count ); + 
vectorMinMax( GetRaw(firstHandle + index), GetRaw(resultHandle + index), count, minValue, maxValue ); } } } else { - vectorMinMax( GetRaw(firstHandle ), GetRaw(resultHandle ), minValue, maxValue, vectorSize ); + vectorMinMax( GetRaw(firstHandle ), GetRaw(resultHandle ), vectorSize, minValue, maxValue ); } } diff --git a/NeoMathEngine/src/CPU/arm/CpuArmMathEngineVectorMath.cpp b/NeoMathEngine/src/CPU/arm/CpuArmMathEngineVectorMath.cpp index a7ca49970..371b5d94d 100644 --- a/NeoMathEngine/src/CPU/arm/CpuArmMathEngineVectorMath.cpp +++ b/NeoMathEngine/src/CPU/arm/CpuArmMathEngineVectorMath.cpp @@ -138,7 +138,7 @@ void CCpuMathEngine::VectorReLU(const CConstFloatHandle& firstHandle, int count; if( OmpGetTaskIndexAndCount( vectorSize, 16, index, count ) ) { if( threshold > 0 ) { - vectorReLU( first + index, result + index, count, threshold ); + vectorReLUTreshold( first + index, result + index, count, threshold ); } else { vectorReLU( first + index, result + index, count ); } diff --git a/NeoMathEngine/src/CPU/arm/CpuArmMathEngineVectorMathPrivate.h b/NeoMathEngine/src/CPU/arm/CpuArmMathEngineVectorMathPrivate.h index 19343fdb9..287a11c05 100644 --- a/NeoMathEngine/src/CPU/arm/CpuArmMathEngineVectorMathPrivate.h +++ b/NeoMathEngine/src/CPU/arm/CpuArmMathEngineVectorMathPrivate.h @@ -209,29 +209,29 @@ inline void vectorAdd( const float* first, const float* second, float* result, i //------------------------------------------------------------------------------------------------------------ -inline void alignedVectorAdd( float* first, const float* second, int vectorSize ) +inline void alignedVectorAdd( const float* first, float* second, int vectorSize ) { int coord = 0; - for( ; coord <= vectorSize - 16; coord += 16, first += 16, second += 16 ) { - NEON_LOAD_16_FLOATS(first, first); + for( ; coord <= vectorSize - 16; coord += 16, second += 16, first += 16 ) { NEON_LOAD_16_FLOATS(second, second); + NEON_LOAD_16_FLOATS(first, first); - float32x4_t result0 = vaddq_f32(first0, second0); - float32x4_t result1 = vaddq_f32(first1, second1); - float32x4_t result2 = vaddq_f32(first2, second2); - float32x4_t result3 = vaddq_f32(first3, second3); + float32x4_t result0 = vaddq_f32(second0, first0); + float32x4_t result1 = vaddq_f32(second1, first1); + float32x4_t result2 = vaddq_f32(second2, first2); + float32x4_t result3 = vaddq_f32(second3, first3); - NEON_STORE_16_FLOATS(result, first); + NEON_STORE_16_FLOATS(result, second); } - for( ; coord <= vectorSize - 4; coord += 4, first += 4, second += 4 ) { - float32x4_t first0 = LoadNeon4(first); + for( ; coord <= vectorSize - 4; coord += 4, second += 4, first += 4 ) { float32x4_t second0 = LoadNeon4(second); + float32x4_t first0 = LoadNeon4(first); - float32x4_t result0 = vaddq_f32(first0, second0); + float32x4_t result0 = vaddq_f32(second0, first0); - StoreNeon4(result0, first); + StoreNeon4(result0, second); } } @@ -272,7 +272,7 @@ inline void alignedVectorMultiplyAndAdd( const float* first, const float* second //------------------------------------------------------------------------------------------------------------ -inline void vectorMultiply( const float* first, float* result, float multiplier, int vectorSize ) +inline void vectorMultiply( const float* first, float multiplier, float* result, int vectorSize ) { int count = GetCount4(vectorSize); float32x4_t mult = vdupq_n_f32(multiplier); @@ -467,7 +467,7 @@ inline void vectorReLU( const float* first, float* result, int vectorSize ) } } -inline void vectorReLU( const float* first, float* result, int 
vectorSize, float threshold ) +inline void vectorReLUTreshold( const float* first, float* result, int vectorSize, float threshold ) { int coord = 0; @@ -504,7 +504,7 @@ inline void vectorReLU( const float* first, float* result, int vectorSize, float //------------------------------------------------------------------------------------------------------------ -inline void vectorAddValue( const float* first, float* result, int vectorSize, float value ) +inline void vectorAddValue( const float* first, float value, float* result, int vectorSize ) { float32x4_t addition = vdupq_n_f32(value); @@ -526,7 +526,7 @@ inline void vectorAddValue( const float* first, float* result, int vectorSize, f //------------------------------------------------------------------------------------------------------------ -inline void vectorDotProduct( const float* first, const float* second, int vectorSize, float* result ) +inline void vectorDotProduct( const float* first, const float* second, float* result, int vectorSize ) { float32x4_t acc = vdupq_n_f32(0); @@ -709,7 +709,7 @@ static inline void qrnnIfPoolingStep( const float* z, const float* f, const floa } } -inline void vectorMinMax( const float* first, float* result, const float minValue, const float maxValue, int vectorSize ) +inline void vectorMinMax( const float* first, float* result, int vectorSize, const float minValue, const float maxValue ) { int count = GetCount4(vectorSize); diff --git a/NeoMathEngine/src/CPU/x86/CpuX86.h b/NeoMathEngine/src/CPU/x86/CpuX86.h index 387511d4b..873c5b17a 100644 --- a/NeoMathEngine/src/CPU/x86/CpuX86.h +++ b/NeoMathEngine/src/CPU/x86/CpuX86.h @@ -324,90 +324,13 @@ inline void checkSse2(int size, int& sseSize, int& nonSseSize) return checkSse(size, sseSize, nonSseSize); } -inline void dataCopy(float* dst, const float* src, int vectorSize) +template<typename T> +inline void dataCopy(T* dst, const T* src, int vectorSize) { - static_assert( sizeof(float) == sizeof(unsigned int), "Size of float isn't equal to size of unsigned int." ); - - int sseSize; - int nonSseSize; - checkSse(vectorSize, sseSize, nonSseSize); - - while( sseSize >= 4 ) { - _mm_storeu_ps(dst, _mm_loadu_ps(src)); - dst += 4; - src += 4; - _mm_storeu_ps(dst, _mm_loadu_ps(src)); - dst += 4; - src += 4; - _mm_storeu_ps(dst, _mm_loadu_ps(src)); - dst += 4; - src += 4; - _mm_storeu_ps(dst, _mm_loadu_ps(src)); - dst += 4; - src += 4; - - sseSize -= 4; + if( vectorSize > 0 ) { + std::copy_n( src, vectorSize, dst ); } - - for(int i = 0; i < sseSize; ++i) { - _mm_storeu_ps(dst, _mm_loadu_ps(src)); - dst += 4; - src += 4; - } - -#if FINE_PLATFORM(FINE_WINDOWS) - if( nonSseSize > 0 ) { - __movsd((DWORD*)dst, (DWORD*)src, nonSseSize); - } -#elif FINE_PLATFORM(FINE_LINUX) || FINE_PLATFORM(FINE_DARWIN) || FINE_PLATFORM(FINE_ANDROID) || FINE_PLATFORM(FINE_IOS) - for(int i = 0; i < nonSseSize; ++i) { - *dst++ = *src++; - } -#else - #error "Platform isn't supported!" 
-#endif -} - -inline void dataCopy(int* dst, const int* src, int vectorSize) -{ - int sseSize; - int nonSseSize; - checkSse2(vectorSize, sseSize, nonSseSize); - - while( sseSize >= 4 ) { - _mm_storeu_si128((__m128i*)dst, _mm_loadu_si128((__m128i*)src)); - dst += 4; - src += 4; - _mm_storeu_si128((__m128i*)dst, _mm_loadu_si128((__m128i*)src)); - dst += 4; - src += 4; - _mm_storeu_si128((__m128i*)dst, _mm_loadu_si128((__m128i*)src)); - dst += 4; - src += 4; - _mm_storeu_si128((__m128i*)dst, _mm_loadu_si128((__m128i*)src)); - dst += 4; - src += 4; - - sseSize -= 4; - } - - for(int i = 0; i < sseSize; ++i) { - _mm_storeu_si128((__m128i*)dst, _mm_loadu_si128((const __m128i*)src)); - dst += 4; - src += 4; - } - -#if FINE_PLATFORM(FINE_WINDOWS) - if(nonSseSize > 0) { - __movsd((unsigned long*)dst, (const unsigned long*)src, nonSseSize); - } -#elif FINE_PLATFORM(FINE_LINUX) || FINE_PLATFORM(FINE_DARWIN) || FINE_PLATFORM(FINE_ANDROID) || FINE_PLATFORM(FINE_IOS) - for(int i = 0; i < nonSseSize; ++i) { - *dst++ = *src++; - } -#else - #error "Platform isn't supported!" -#endif + } inline float euclidianNoSSE( const float* x, const float* y, const int size ) diff --git a/NeoMathEngine/src/CPU/x86/CpuX86MathEngineDnn3dConv.cpp b/NeoMathEngine/src/CPU/x86/CpuX86MathEngineDnn3dConv.cpp index 99c492536..db31bf22b 100644 --- a/NeoMathEngine/src/CPU/x86/CpuX86MathEngineDnn3dConv.cpp +++ b/NeoMathEngine/src/CPU/x86/CpuX86MathEngineDnn3dConv.cpp @@ -57,7 +57,7 @@ void CCpuMathEngine::blob3dConvolution1x1x1( const CBlobDesc& source, const CBl if( freeTermData != 0 ) { NeoML::setVectorToMatrixRows(outputDataPtr, geomCount, newChannels, freeTermData); } else { - NeoML::vectorFill(outputDataPtr, 0, geomCount * newChannels); + NeoML::vectorFill0(outputDataPtr, geomCount * newChannels); } multiplyMatrixByTransposedMatrixAndAdd(sourceData + geomStart * channels, geomCount, channels, channels, @@ -81,7 +81,7 @@ void CCpuMathEngine::blob3dConvolution1x1x1( const CBlobDesc& source, const CBl } } else { for( float* res = resultPtr; res < resultEnd; res += newChannels ) { - NeoML::vectorFill(res, 0, channelCount); + NeoML::vectorFill0(res, channelCount); } } multiplyMatrixByTransposedMatrixAndAdd(sourceData, @@ -120,7 +120,7 @@ void CCpuMathEngine::blob3dConvolution1x1x1( const CBlobDesc& source, const CBl if( freeTermData != 0 ) { NeoML::setVectorToMatrixRows(outputDataPtr, geomCount, newChannels, freeTermData); } else { - NeoML::vectorFill(outputDataPtr, 0, geomCount * newChannels); + NeoML::vectorFill0(outputDataPtr, geomCount * newChannels); } multiplyMatrixByTransposedMatrixAndAdd(repackedData + geomStart * channels, geomCount, channels, channels, diff --git a/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMath.cpp b/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMath.cpp index b2d424bcd..c92eb331c 100644 --- a/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMath.cpp +++ b/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMath.cpp @@ -377,7 +377,7 @@ void CCpuMathEngine::VectorReLU( const CConstFloatHandle& firstHandle, const CFl int index, count; if( OmpGetTaskIndexAndCount( vectorSize, 16, index, count ) ) { if( threshold > 0 ) { - vectorReLU( first + index, result + index, count, threshold ); + vectorReLUTreshold( first + index, result + index, count, threshold ); } else { vectorReLU( first + index, result + index, count ); } @@ -385,7 +385,7 @@ void CCpuMathEngine::VectorReLU( const CConstFloatHandle& firstHandle, const CFl } } else { if( threshold > 0 ) { - vectorReLU( first, result, vectorSize, threshold ); + 
vectorReLUTreshold( first, result, vectorSize, threshold ); } else { vectorReLU( first, result, vectorSize ); } diff --git a/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMathPrivate.h b/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMathPrivate.h index 875883f99..e50f7cca9 100644 --- a/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMathPrivate.h +++ b/NeoMathEngine/src/CPU/x86/CpuX86MathEngineVectorMathPrivate.h @@ -82,107 +82,18 @@ inline void channelwise1x3( const float* source, const float* filter0, const flo //------------------------------------------------------------------------------------------------------------ -inline void vectorFill( float* result, float value, int vectorSize ) +template<typename T> +inline void vectorFill( T* result, T value, int vectorSize ) { - int sseSize; - int nonSseSize; - checkSse(vectorSize, sseSize, nonSseSize); - - __m128 valueSse = _mm_set_ps1(value); - - while( sseSize >= 4 ) { - _mm_storeu_ps(result, valueSse); - result += 4; - _mm_storeu_ps(result, valueSse); - result += 4; - _mm_storeu_ps(result, valueSse); - result += 4; - _mm_storeu_ps(result, valueSse); - result += 4; - - sseSize -= 4; - } - - while( sseSize > 0 ) { - _mm_storeu_ps(result, valueSse); - result += 4; - sseSize--; - } - - for(int i = 0; i < nonSseSize; ++i) { - *result++ = value; - } -} - -inline void vectorFill( int* result, int value, int vectorSize ) -{ - int sseSize; - int nonSseSize; - checkSse( vectorSize, sseSize, nonSseSize ); - - __m128i valueSse = _mm_set1_epi32( value ); - - while( sseSize >= 4 ) { - _mm_storeu_si128( ( __m128i* )result, valueSse ); - result += 4; - _mm_storeu_si128( ( __m128i* )result, valueSse ); - result += 4; - _mm_storeu_si128( ( __m128i* )result, valueSse ); - result += 4; - _mm_storeu_si128( ( __m128i* )result, valueSse ); - result += 4; - - sseSize -= 4; + if( vectorSize > 0 ) { + std::fill_n( result, vectorSize, value ); } - - while( sseSize > 0 ) { - _mm_storeu_si128( ( __m128i* )result, valueSse ); - result += 4; - sseSize--; - } - -#if FINE_PLATFORM(FINE_WINDOWS) - if( nonSseSize > 0 ) { - __stosd( (DWORD*) result, value, nonSseSize ); - } -#elif FINE_PLATFORM(FINE_LINUX) || FINE_PLATFORM(FINE_DARWIN) || FINE_PLATFORM(FINE_ANDROID) || FINE_PLATFORM(FINE_IOS) - for( int i = 0; i < nonSseSize; ++i ) { - *result++ = value; - } -#else -#error "Platform isn't supported!" 
-#endif } inline void vectorFill0( float* result, int vectorSize ) { - int sseSize; - int nonSseSize; - checkSse(vectorSize, sseSize, nonSseSize); - - __m128 valueSse = _mm_setzero_ps(); - - while( sseSize >= 4 ) { - _mm_storeu_ps(result, valueSse); - result += 4; - _mm_storeu_ps(result, valueSse); - result += 4; - _mm_storeu_ps(result, valueSse); - result += 4; - _mm_storeu_ps(result, valueSse); - result += 4; - - sseSize -= 4; - } - - while( sseSize > 0 ) { - _mm_storeu_ps(result, valueSse); - result += 4; - sseSize--; - } - - for( int i = 0; i < nonSseSize; ++i ) { - *result++ = 0; + if( vectorSize > 0 ) { + std::fill_n( result, vectorSize, 0.0f ); } } @@ -270,20 +181,20 @@ inline void vectorAdd(const float* first, const float* second, float* result, in //------------------------------------------------------------------------------------------------------------ -inline void alignedVectorAdd( float* first, const float* second, int vectorSize ) +inline void alignedVectorAdd( const float* first, float* second, int vectorSize ) { int sseSize = vectorSize / 4; while( sseSize >= 4 ) { - _mm_store_ps( first, _mm_add_ps( _mm_load_ps( first ), _mm_load_ps( second ) ) ); + _mm_store_ps( second, _mm_add_ps( _mm_load_ps( first ), _mm_load_ps( second ) ) ); first += 4; second += 4; - _mm_store_ps( first, _mm_add_ps( _mm_load_ps( first ), _mm_load_ps( second ) ) ); + _mm_store_ps( second, _mm_add_ps( _mm_load_ps( first ), _mm_load_ps( second ) ) ); first += 4; second += 4; - _mm_store_ps( first, _mm_add_ps( _mm_load_ps( first ), _mm_load_ps( second ) ) ); + _mm_store_ps( second, _mm_add_ps( _mm_load_ps( first ), _mm_load_ps( second ) ) ); first += 4; second += 4; - _mm_store_ps( first, _mm_add_ps( _mm_load_ps( first ), _mm_load_ps( second ) ) ); + _mm_store_ps( second, _mm_add_ps( _mm_load_ps( first ), _mm_load_ps( second ) ) ); first += 4; second += 4; @@ -291,7 +202,7 @@ inline void alignedVectorAdd( float* first, const float* second, int vectorSize } while( sseSize > 0 ) { - _mm_store_ps( first, _mm_add_ps( _mm_load_ps( first ), _mm_load_ps( second ) ) ); + _mm_store_ps( second, _mm_add_ps( _mm_load_ps( first ), _mm_load_ps( second ) ) ); first += 4; second += 4; sseSize--; @@ -334,7 +245,7 @@ inline void alignedVectorMultiplyAndAdd( const float* first, const float* second } //------------------------------------------------------------------------------------------------------------ -inline void vectorMultiply( const float* first, float* result, float multiplier, int vectorSize ) +inline void vectorMultiply( const float* first, float multiplier, float* result, int vectorSize ) { int sseSize; int nonSseSize; @@ -569,7 +480,7 @@ inline void vectorReLU( const float* first, float* result, int vectorSize ) } } -inline void vectorReLU( const float* first, float* result, int vectorSize, float threshold ) +inline void vectorReLUTreshold( const float* first, float* result, int vectorSize, float threshold ) { int sseSize; int nonSseSize; @@ -610,7 +521,7 @@ inline void vectorReLU( const float* first, float* result, int vectorSize, float //------------------------------------------------------------------------------------------------------------ -inline void vectorAddValue( const float* first, float* result, int vectorSize, float value ) +inline void vectorAddValue( const float* first, float value, float* result, int vectorSize ) { int sseSize; int nonSseSize; @@ -632,7 +543,7 @@ inline void vectorAddValue( const float* first, float* result, int vectorSize, f 
//------------------------------------------------------------------------------------------------------------ -inline void vectorDotProduct( const float* first, const float* second, int vectorSize, float* result ) +inline void vectorDotProduct( const float* first, const float* second, float* result, int vectorSize ) { int sseSize; int nonSseSize; @@ -861,7 +772,7 @@ static inline void qrnnIfPoolingStep( const float* z, const float* f, const floa } } -inline void vectorMinMax( const float* first, float* result, const float minValue, const float maxValue, int vectorSize ) +inline void vectorMinMax( const float* first, float* result, int vectorSize, const float minValue, const float maxValue ) { int sseSize; int nonSseSize; @@ -905,7 +816,7 @@ inline void vectorTanh( const float* first, float* result, int vectorSize ) inline void vectorExp( const float* first, float* result, int vectorSize ) { #ifdef NEOML_USE_MKL - vectorMinMax( first, result, FLT_MIN_LOG, FLT_MAX_LOG, vectorSize ); + vectorMinMax( first, result, vectorSize, FLT_MIN_LOG, FLT_MAX_LOG ); vsExp( vectorSize, result, result ); #else for( int i = 0; i < vectorSize; ++i ) { diff --git a/NeoMathEngine/src/CPU/x86/avx/src/AvxMathEngine.cpp b/NeoMathEngine/src/CPU/x86/avx/src/AvxMathEngine.cpp index 89aa772c2..601000136 100644 --- a/NeoMathEngine/src/CPU/x86/avx/src/AvxMathEngine.cpp +++ b/NeoMathEngine/src/CPU/x86/avx/src/AvxMathEngine.cpp @@ -67,6 +67,19 @@ class CAvxMathEngine : public ISimdMathEngine { void RunOnceRestOfLstm( CMathEngineLstmDesc* desc, const CConstFloatHandle& inputStateBackLink, const CFloatHandle& outputStateBackLink, const CFloatHandle& outputMainBackLink, bool isMultithread ) override; + vectorAddFunc GetVectorAddFunc() override { return reinterpret_cast<vectorAddFunc>( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorAdd>() ); } + alignedVectorAdd GetAlignedVectorAddFunc() override { return reinterpret_cast<alignedVectorAdd>( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorAlignedAdd>() ); } + vectorEltwiseMax GetVectorMaxFunc() override { return reinterpret_cast<vectorEltwiseMax>( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorMax>() ); } + vectorReLU GetVectorReLUFunc() override { return reinterpret_cast<vectorReLU>( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorReLU>() ); } + vectorReLUTreshold GetVectorReLUTresholdFunc() override { return reinterpret_cast<vectorReLUTreshold>( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorReLUTreshold>() ); } + alignedVectorMultiplyAndAdd GetAlignedVectorMultiplyAndAddFunc() override { return reinterpret_cast< alignedVectorMultiplyAndAdd >( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorAlignedMultiplyAndAdd>() ); } + vectorMultiply GetVectorMultiplyFunc() override { return reinterpret_cast< vectorMultiply >( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorMultiply>() ); } + vectorEltwiseMultiply GetVectorEltwiseMultiplyFunc() override { return reinterpret_cast< vectorEltwiseMultiply >( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorEltwiseMultiply>() ); } + vectorEltwiseMultiplyAdd GetVectorEltwiseMultiplyAddFunc() override { return reinterpret_cast< vectorEltwiseMultiplyAdd >( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorEltwiseMultiplyAdd>() ); } + vectorAddValue GetVectorAddValueFunc() override { return reinterpret_cast< vectorAddValue >( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorAddValue>() ); } + vectorDotProduct GetVectorDotProductFunc() override { return reinterpret_cast< vectorDotProduct >( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorDotProduct>() ); } + vectorMinMax GetVectorMinMaxFunc() override { return reinterpret_cast< vectorMinMax >( primitives.GetFunctionRawPtr<CPrimitivesJit::TPrimitive::VectorMinMax>() ); } + private: IMathEngine* mathEngine; int threadCount; diff --git a/NeoMathEngine/src/CPU/x86/avx/src/JitCommon.cpp b/NeoMathEngine/src/CPU/x86/avx/src/JitCommon.cpp index 330896622..98b0ea3ae 100644 --- a/NeoMathEngine/src/CPU/x86/avx/src/JitCommon.cpp +++ 
b/NeoMathEngine/src/CPU/x86/avx/src/JitCommon.cpp @@ -34,7 +34,9 @@ Xbyak::Address CJitCommon::Prologue( const reg64Vec_t& preservedGPR, push( rbp ); mov( rbp, rsp ); - sub( rsp, static_cast< uint32_t >( preservedYmm.size() * SizeOfYmm ) ); + if( preservedYmm.size() != 0 ) { + sub( rsp, static_cast< uint32_t >( preservedYmm.size() * SizeOfYmm ) ); + } for( int i = 0; i < preservedYmm.size(); i++ ) { vmovdqu( ptr[rsp + i * SizeOfYmm], preservedYmm[i] ); } @@ -50,6 +52,8 @@ void CJitCommon::Epilogue( const reg64Vec_t& preservedGPR, const ymmVec_t& preservedYmm ) { + vzeroupper(); + for( int i = static_cast<int>( preservedGPR.size() - 1 ); i >= 0; i-- ) { pop( preservedGPR[i] ); } @@ -80,4 +84,9 @@ void CJitCommon::StopDownCountLoop() loopDescs.pop(); } +void CJitCommon::JmpIfZero( reg64_t counter, const char* label ) { + test( counter, counter ); + jz( label, T_NEAR ); +} + } diff --git a/NeoMathEngine/src/CPU/x86/avx/src/JitCommon.h b/NeoMathEngine/src/CPU/x86/avx/src/JitCommon.h index 7ef70a605..9c6a22c4a 100644 --- a/NeoMathEngine/src/CPU/x86/avx/src/JitCommon.h +++ b/NeoMathEngine/src/CPU/x86/avx/src/JitCommon.h @@ -43,6 +43,11 @@ constexpr reg64_t Param1{Xbyak::Operand::RCX}; constexpr reg64_t Param2{Xbyak::Operand::RDX}; constexpr reg64_t Param3{Xbyak::Operand::R8}; constexpr reg64_t Param4{Xbyak::Operand::R9}; + +constexpr reg64_t Params[4] = { Param1, Param2, Param3, Param4 }; + +const int LowerPreservedYmm = 6; + #else constexpr reg64_t Param1{Xbyak::Operand::RDI}; constexpr reg64_t Param2{Xbyak::Operand::RSI}; constexpr reg64_t Param3{Xbyak::Operand::RDX}; constexpr reg64_t Param4{Xbyak::Operand::RCX}; constexpr reg64_t Param5{Xbyak::Operand::R8}; constexpr reg64_t Param6{Xbyak::Operand::R9}; + +constexpr reg64_t Params[6] = { Param1, Param2, Param3, Param4, Param5, Param6 }; + +// 16 means 'Don't preserve' +const int LowerPreservedYmm = 16; #endif constexpr unsigned int NumFloatInYmm = 8; @@ -57,6 +67,22 @@ constexpr unsigned int SizeOfYmm = NumFloatInYmm * sizeof( float ); constexpr unsigned int SizeofReg64 = 8; constexpr unsigned int MaxYmmCount = 16; +// Windows and Linux calling conventions treat floating point arguments in a different manner: +// Windows passes only 4 arguments through registers (four in total for integer/pointer and floating point combined). +// On Windows the registers for passing arguments (rcx, rdx, r8, r9 and xmm0-xmm3) are strictly fixed +// by the sequence number of the argument. +// Example (win): ( int, void*, float, int ) will be passed through ( rcx, rdx, xmm2, r9 ) +// On Linux, GPR and XMM registers are indexed in a continuous manner, even when GPR and XMM arguments interleave. +// Example (linux): ( int, void*, float, int ) will be passed through ( rdi, rsi, xmm0, rdx ) constexpr int GetFirstXmmArgIdx( int argNum ) { #ifdef _WIN32 + assert( argNum < 4 ); + return argNum; #else + return 0; #endif } + class CJitCommon : public Xbyak::CodeGenerator { public: using Base = Xbyak::CodeGenerator; @@ -80,6 +106,7 @@ class CJitCommon : public Xbyak::CodeGenerator { void StartDownCountLoop( reg64_t counter, size_t step ); void StopDownCountLoop(); + void JmpIfZero( reg64_t counter, const char* label ); template<class LastVec> bool HasSameSize( const LastVec& ) { @@ -183,7 +210,11 @@ class CJitCommon : public Xbyak::CodeGenerator { XBYAK_FORWARD_CAST_2( vmovups, Address, Xmm ) XBYAK_FORWARD_CAST_2( vmovups, Xmm, Operand ) XBYAK_FORWARD_CAST_3( vaddps, Xmm, Operand, Operand ) + XBYAK_FORWARD_CAST_3( vmaxps, Xmm, Operand, Operand ) + XBYAK_FORWARD_CAST_3( vminps, Xmm, Operand, Operand ) XBYAK_FORWARD_CAST_3( vmulps, Xmm, Operand, Operand ) + XBYAK_FORWARD_CAST_3( vxorps, Xmm, Operand, Operand ) + XBYAK_FORWARD_CAST_3( vfmadd231ps, Xmm, Xmm, Operand ) private: struct CLoopDesc { diff --git a/NeoMathEngine/src/CPU/x86/avx/src/PrimitivesJit.cpp b/NeoMathEngine/src/CPU/x86/avx/src/PrimitivesJit.cpp index 9bd809304..3a0aa5bb6 100644 --- a/NeoMathEngine/src/CPU/x86/avx/src/PrimitivesJit.cpp +++ b/NeoMathEngine/src/CPU/x86/avx/src/PrimitivesJit.cpp @@ -186,6 +186,373 @@ void CPrimitivesJit::addVal( TTableKey key, uint32_t val, size_t repeatNum ) fill_n( pTable, repeatNum, val ); } +void CPrimitivesJit::initEltwisePrimitive( CPrimitivesJit::TPrimitive P, bool hasOp2, bool op2IsScalar ) +{ + assert( ( op2IsScalar && hasOp2 ) || !op2IsScalar ); + + using namespace Xbyak; + using namespace Xbyak::util; + // get the code generator for this primitive + auto& gen = gens[static_cast< size_t >( P )].gen; + + Address stackArgsPtr = gen.Prologue( {}, {} ); + + constexpr int xmmArgIdx = GetFirstXmmArgIdx( 1 ); + int gprArgIdx = 0; + // *** Define registers *** + const reg64_t regOp1Ptr = Params[gprArgIdx++]; + + // Define both and choose which of them we need (ymmScalar or regOp2Ptr) + const reg64_t regOp2Ptr = Params[gprArgIdx]; + const ymm_t ymmScalar = ymm5; + if( op2IsScalar ) { +#ifdef _WIN32 + // As mentioned in the description of GetFirstXmmArgIdx(), on Windows the GPR used for + // integer/pointer passing depends only on the sequence number of the argument. 
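// For example, with the scalar-op2 signature vectorMultiply( const float* first, float multiplier, float* result, int vectorSize ),
// the mapping works out as follows (an illustrative note derived from the Params / GetFirstXmmArgIdx definitions above):
//   Win64:      first -> rcx, multiplier -> xmm1, result -> r8,  vectorSize -> r9  (rdx stays unused, hence the extra gprArgIdx++ here)
//   SysV/Linux: first -> rdi, multiplier -> xmm0, result -> rsi, vectorSize -> rdx (the GPR numbering does not skip a register)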
gprArgIdx++; +#endif + gen.vbroadcastss( ymmScalar, Xmm( xmmArgIdx ) ); + } else if( hasOp2 ) { + gprArgIdx++; + } + + const reg64_t regResPtr = Params[gprArgIdx++]; + const reg64_t regCount = Params[gprArgIdx]; + + EltwiseGenFunc eltwiseFunc = GetEltwiseFuncPtr( P ); + + auto insertKernel = [&]( unsigned int stepCount ) { + if( stepCount > 0 ) { + for( unsigned int i = 0; i < stepCount; i++ ) { + if( op2IsScalar ) { + ( gen.*eltwiseFunc )( Ymm( i ), ymmScalar, ptr[regOp1Ptr + i * SizeOfYmm] ); + } else { + gen.vmovups( Ymm( i ), ptr[regOp1Ptr + i * SizeOfYmm] ); + ( gen.*eltwiseFunc )( Ymm( i ), Ymm( i ), ptr[regOp2Ptr + i * SizeOfYmm] ); + } + gen.vmovups( ptr[regResPtr + i * SizeOfYmm], Ymm( i ) ); + } + gen.add( regOp1Ptr, stepCount * SizeOfYmm ); + gen.add( regResPtr, stepCount * SizeOfYmm ); + if( hasOp2 && !op2IsScalar ) { + gen.add( regOp2Ptr, stepCount * SizeOfYmm ); + } + } else { + // Tail processing (ymm0 is always the mask) + ymm_t ymmMask = ymm0; + ymm_t ymmLastOp1 = ymm1; + ymm_t ymmLastOp2 = op2IsScalar ? ymmScalar : ymm2; + ymm_t ymmLastRes = ymm2; + gen.vmaskmovps( ymmLastOp1, ymmMask, gen.ptr[regOp1Ptr] ); + if( !op2IsScalar ) { + gen.vmaskmovps( ymmLastOp2, ymmMask, gen.ptr[regOp2Ptr] ); + } + ( gen.*eltwiseFunc )( ymmLastRes, ymmLastOp1, ymmLastOp2 ); + gen.vmaskmovps( gen.ptr[regResPtr], ymmMask, ymmLastRes ); + } + }; + + insertSimpleMathFunction( {}, {}, gen, regCount, insertKernel, { 4, 1, 0 } ); +} + +void CPrimitivesJit::initMinMaxFunction( CPrimitivesJit::TPrimitive P, bool useLowerBound, bool useUpperBound ) +{ + using namespace Xbyak; + using namespace Xbyak::util; + // get the code generator for this primitive + auto& gen = gens[static_cast< size_t >( P )].gen; + const ymmVec_t preservedYmm = initVecRange<Xbyak::Ymm>( LowerPreservedYmm, 15 ); + Address stackArgsPtr = gen.Prologue( {}, preservedYmm ); + + Ymm ymmLowerBound = ymm14; + Ymm ymmUpperBound = ymm15; + + constexpr int xmmArgIdx = GetFirstXmmArgIdx( 3 ); + // *** Define registers *** + const reg64_t regOp1Ptr = Param1; + const reg64_t regResPtr = Param2; + const reg64_t regCount = Param3; + // Set the lower and upper bounds + if( useLowerBound ) { + gen.vbroadcastss( ymmLowerBound, Xmm( xmmArgIdx ) ); + if( useUpperBound ) { +#ifdef _WIN32 + gen.vbroadcastss( ymmUpperBound, stackArgsPtr ); +#else + gen.vbroadcastss( ymmUpperBound, Xmm( xmmArgIdx + 1 ) ); +#endif + } + } else { + gen.vxorps( ymmLowerBound, ymmLowerBound, ymmLowerBound ); + if( useUpperBound ) { + gen.vbroadcastss( ymmUpperBound, Xmm( xmmArgIdx ) ); + } + } + + auto insertKernel = [&]( unsigned int stepCount ) { + if( stepCount > 0 ) { + for( unsigned int i = 0; i < stepCount; i++ ) { + gen.vmaxps( Ymm( i ), ymmLowerBound, ptr[regOp1Ptr + i * SizeOfYmm] ); + if( useUpperBound ) { + gen.vminps( Ymm( i ), Ymm( i ), ymmUpperBound ); + } + gen.vmovups( ptr[regResPtr + i * SizeOfYmm], Ymm( i ) ); + } + gen.add( regOp1Ptr, stepCount * SizeOfYmm ); + gen.add( regResPtr, stepCount * SizeOfYmm ); + } else { + // Tail processing (ymm0 is always the mask) + ymm_t ymmMask = ymm0; + ymm_t ymmLast = ymm1; + gen.vmaskmovps( ymmLast, ymmMask, gen.ptr[regOp1Ptr] ); + gen.vmaxps( ymmLast, ymmLast, ymmLowerBound ); + if( useUpperBound ) { + gen.vminps( ymmLast, ymmLast, ymmUpperBound ); + } + gen.vmaskmovps( gen.ptr[regResPtr], ymmMask, ymmLast ); + } + }; + + insertSimpleMathFunction( {}, preservedYmm, gen, regCount, insertKernel, { 14, 4, 1, 0 } ); +} + +void CPrimitivesJit::insertSimpleMathFunction( const reg64Vec_t& preservedGPR, const ymmVec_t& preservedYmm, + CJitCommon& gen, const 
reg64_t& regCount, + const std::function<void( unsigned int )>& insertKernel, const std::vector<unsigned int>& loopUnrollingSteps, + const std::function<void()>& callBeforeRet ) +{ + using namespace Xbyak; + using namespace Xbyak::util; + + for( auto step : loopUnrollingSteps ) { + if( step > 0 ) { + gen.StartDownCountLoop( regCount, step * NumFloatInYmm ); + insertKernel( step ); + gen.StopDownCountLoop(); + gen.JmpIfZero( regCount, "end" ); + } else { + // Process tail + ymm_t ymmMask = ymm0; + // Multiply by 8 to calculate the right offset into the mask table + gen.mov( regTablePtr, ( uint64_t )table.data() ); + gen.shl( regCount, 3 ); + gen.vmovups( ymmMask, gen.ptr[regTablePtr + regCount * sizeof( float ) + getOfft( TTableKey::LoadMask )] ); + insertKernel( step ); + } + } + gen.L( "end" ); + + if( callBeforeRet ) { + callBeforeRet(); + } + + gen.Epilogue( preservedGPR, preservedYmm ); + gen.ret(); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorAdd>() +{ + initEltwisePrimitive( TPrimitive::VectorAdd, true ); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorAlignedAdd>() +{ + initEltwisePrimitive( TPrimitive::VectorAlignedAdd, false ); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorMax>() +{ + initEltwisePrimitive( TPrimitive::VectorMax, true ); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorReLU>() +{ + initMinMaxFunction( TPrimitive::VectorReLU, false, false ); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorReLUTreshold>() +{ + initMinMaxFunction( TPrimitive::VectorReLUTreshold, false, true ); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorAlignedMultiplyAndAdd>() +{ + using namespace Xbyak; + using namespace Xbyak::util; + // get the code generator for this primitive + auto& gen = gens[static_cast< size_t >( TPrimitive::VectorAlignedMultiplyAndAdd )].gen; + + Address stackArgsPtr = gen.Prologue( {}, {} ); + + // *** Define registers *** + const reg64_t regOp1Ptr = Param1; + const reg64_t regOp2Ptr = Param2; + const reg64_t regResPtr = Param3; + const reg64_t regCount = Param4; +#ifdef _WIN32 + const reg64_t regMul = rax; // param5 + gen.mov( regMul, stackArgsPtr ); +#else + const reg64_t regMul = Param5; +#endif + + ymm_t ymmMul = ymm5; + gen.vbroadcastss( ymmMul, gen.ptr[regMul] ); + + auto insertKernel = [&]( unsigned int stepCount ) { + if( stepCount > 0 ) { + for( unsigned int i = 0; i < stepCount; i++ ) { + gen.vmovups( Ymm( i ), ptr[regOp1Ptr + i * SizeOfYmm] ); + gen.vfmadd231ps( Ymm( i ), ymmMul, ptr[regOp2Ptr + i * SizeOfYmm] ); + gen.vmovups( ptr[regResPtr + i * SizeOfYmm], Ymm( i ) ); + } + gen.add( regOp1Ptr, stepCount * SizeOfYmm ); + gen.add( regOp2Ptr, stepCount * SizeOfYmm ); + gen.add( regResPtr, stepCount * SizeOfYmm ); + } else { + // Tail processing (ymm0 is always the mask) + ymm_t ymmMask = ymm0; + ymm_t ymmLastOp1 = ymm1; + ymm_t ymmLastOp2 = ymm2; + ymm_t ymmLastRes = ymmLastOp1; + gen.vmaskmovps( ymmLastOp1, ymmMask, gen.ptr[regOp1Ptr] ); + gen.vmaskmovps( ymmLastOp2, ymmMask, gen.ptr[regOp2Ptr] ); + gen.vfmadd231ps( ymmLastOp1, ymmMul, ymmLastOp2 ); + gen.vmaskmovps( gen.ptr[regResPtr], ymmMask, ymmLastRes ); + } + }; + + insertSimpleMathFunction( {}, {}, gen, regCount, insertKernel, { 4, 1, 0 } ); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorMultiply>() +{ + initEltwisePrimitive( TPrimitive::VectorMultiply, true, true ); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorEltwiseMultiply>() +{ + initEltwisePrimitive( TPrimitive::VectorEltwiseMultiply, true ); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorEltwiseMultiplyAdd>() +{ + using namespace Xbyak; + using namespace Xbyak::util; + // get the code generator for this primitive + auto& gen = gens[static_cast< size_t >( TPrimitive::VectorEltwiseMultiplyAdd )].gen; + 
const ymmVec_t preservedYmm = initVecRange<Xbyak::Ymm>( LowerPreservedYmm, 7 ); + + gen.Prologue( {}, preservedYmm ); + + // *** Define registers *** + const reg64_t regOp1Ptr = Param1; + const reg64_t regOp2Ptr = Param2; + const reg64_t regResPtr = Param3; + const reg64_t regCount = Param4; + + auto insertKernel = [&]( unsigned int stepCount ) { + assert( stepCount <= 4 ); + if( stepCount > 0 ) { + for( unsigned int i = 0; i < stepCount; i++ ) { + gen.vmovups( Ymm( i ), ptr[regResPtr + i * SizeOfYmm] ); + gen.vmovups( Ymm( i + 4 ), ptr[regOp1Ptr + i * SizeOfYmm] ); + gen.vfmadd231ps( Ymm( i ), Ymm( i + 4 ), ptr[regOp2Ptr + i * SizeOfYmm] ); + gen.vmovups( ptr[regResPtr + i * SizeOfYmm], Ymm( i ) ); + } + gen.add( regOp1Ptr, stepCount * SizeOfYmm ); + gen.add( regOp2Ptr, stepCount * SizeOfYmm ); + gen.add( regResPtr, stepCount * SizeOfYmm ); + } else { + // Tail processing (ymm0 is always the mask) + ymm_t ymmMask = ymm0; + ymm_t ymmLastOp1 = ymm1; + ymm_t ymmLastOp2 = ymm2; + ymm_t ymmLastRes = ymm3; + gen.vmaskmovps( ymmLastOp1, ymmMask, gen.ptr[regOp1Ptr] ); + gen.vmaskmovps( ymmLastOp2, ymmMask, gen.ptr[regOp2Ptr] ); + gen.vmaskmovps( ymmLastRes, ymmMask, gen.ptr[regResPtr] ); + gen.vfmadd231ps( ymmLastRes, ymmLastOp1, ymmLastOp2 ); + gen.vmaskmovps( gen.ptr[regResPtr], ymmMask, ymmLastRes ); + } + }; + + insertSimpleMathFunction( {}, preservedYmm, gen, regCount, insertKernel, { 4, 1, 0 } ); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorAddValue>() +{ + initEltwisePrimitive( TPrimitive::VectorAddValue, true, true ); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorDotProduct>() +{ + using namespace Xbyak; + using namespace Xbyak::util; + // get the code generator for this primitive + auto& gen = gens[static_cast< size_t >( TPrimitive::VectorDotProduct )].gen; + + gen.Prologue( {}, {} ); + + // *** Define registers *** + const reg64_t regOp1Ptr = Param1; + const reg64_t regOp2Ptr = Param2; + const reg64_t regResPtr = Param3; + const reg64_t regCount = Param4; + + ymm_t ymmRes = ymm5; + gen.vxorps( ymmRes, ymmRes, ymmRes ); + + auto insertKernel = [&]( unsigned int stepCount ) { + assert( stepCount <= 4 ); + if( stepCount > 0 ) { + for( unsigned int i = 0; i < stepCount; i++ ) { + gen.vmovups( Ymm( i ), ptr[regOp1Ptr + i * SizeOfYmm] ); + gen.vfmadd231ps( ymmRes, Ymm( i ), ptr[regOp2Ptr + i * SizeOfYmm] ); + } + gen.add( regOp1Ptr, stepCount * SizeOfYmm ); + gen.add( regOp2Ptr, stepCount * SizeOfYmm ); + } else { + // Tail processing (ymm0 is always the mask) + ymm_t ymmMask = ymm0; + ymm_t ymmLastOp1 = ymm1; + ymm_t ymmLastOp2 = ymm2; + gen.vmaskmovps( ymmLastOp1, ymmMask, gen.ptr[regOp1Ptr] ); + gen.vmaskmovps( ymmLastOp2, ymmMask, gen.ptr[regOp2Ptr] ); + gen.vfmadd231ps( ymmRes, ymmLastOp1, ymmLastOp2 ); + } + }; + + auto flushResult = [&]() { + // Horizontally add the result and store it + gen.vextractf128( xmm0, ymmRes, 1 ); + gen.vhaddps( ymmRes, ymm0, ymmRes ); + gen.vhaddps( ymmRes, ymmRes, ymmRes ); + gen.vhaddps( ymmRes, ymmRes, ymmRes ); + gen.vmovss( gen.ptr[regResPtr], ymmRes.copyAndSetKind( Operand::XMM ) ); + }; + + insertSimpleMathFunction( {}, {}, gen, regCount, insertKernel, { 4, 1, 0 }, flushResult ); +} + +template<> +void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::VectorMinMax>() +{ + initMinMaxFunction( TPrimitive::VectorMinMax, true, true ); +} + template<> void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::Tanh>() { @@ -193,7 +560,7 @@ void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::Tanh>() auto& gen = gens[static_cast< size_t >( TPrimitive::Tanh )].gen; const reg64Vec_t preservedReg64; - const ymmVec_t preservedYmm = initVecRange<Xbyak::Ymm>( 6, 11 ); + const ymmVec_t preservedYmm = 
initVecRange<Xbyak::Ymm>( LowerPreservedYmm, 11 ); const ymmVec_t ymmSrc = initVecRange<Xbyak::Ymm>( 10, 11 ); const ymmVec_t ymmAux = initVecRange<Xbyak::Ymm>( 0, 9 ); @@ -208,7 +575,7 @@ void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::Sigmoid>() auto& gen = gens[static_cast< size_t >( TPrimitive::Sigmoid )].gen; const reg64Vec_t preservedReg64; - const ymmVec_t preservedYmm = initVecRange<Xbyak::Ymm>( 6, 12 ); + const ymmVec_t preservedYmm = initVecRange<Xbyak::Ymm>( LowerPreservedYmm, 12 ); const ymmVec_t ymmSrc = initVecRange<Xbyak::Ymm>( 0, 2 ); const ymmVec_t ymmAux = initVecRange<Xbyak::Ymm>( 3, 12 ); @@ -229,7 +596,7 @@ void CPrimitivesJit::initPrimitive<CPrimitivesJit::TPrimitive::Exp>() auto& gen = gens[static_cast< size_t >( TPrimitive::Exp )].gen; const reg64Vec_t preservedReg64; - const ymmVec_t preservedYmm = initVecRange<Xbyak::Ymm>( 6, 15 ); + const ymmVec_t preservedYmm = initVecRange<Xbyak::Ymm>( LowerPreservedYmm, 15 ); const ymmVec_t ymmSrc = initVecRange<Xbyak::Ymm>( 12, 15 ); const ymmVec_t ymmAux = initVecRange<Xbyak::Ymm>( 0, 11 ); @@ -507,13 +874,13 @@ void CPrimitivesJit::initActivationFunction( const std::function& afterP const reg64_t regCount = Param4; auto insertCode = [&]( const ymmVec_t& ymmSrc, const ymmVec_t& ymmAux ) { - size_t stepCount = ymmSrc.size(); + uint32_t stepCount = static_cast<uint32_t>( ymmSrc.size() ); gen.StartDownCountLoop( regCount, stepCount * NumFloatInYmm ); - for( int i = 0; i < stepCount; i++ ) { gen.vmovups( ymmSrc[i], ptr[regSrcPtr + i * SizeOfYmm] ); } + for( uint32_t i = 0; i < stepCount; i++ ) { gen.vmovups( ymmSrc[i], ptr[regSrcPtr + i * SizeOfYmm] ); } insertPrimitive<P>( gen, ymmSrc, ymmAux ); - for( int i = 0; i < stepCount; i++ ) { gen.vmovups( ptr[regDstPtr + i * SizeOfYmm], ymmSrc[i] ); } - gen.lea( regSrcPtr, gen.ptr[regSrcPtr + stepCount * SizeOfYmm] ); - gen.lea( regDstPtr, gen.ptr[regDstPtr + stepCount * SizeOfYmm] ); + for( uint32_t i = 0; i < stepCount; i++ ) { gen.vmovups( ptr[regDstPtr + i * SizeOfYmm], ymmSrc[i] ); } + gen.add( regSrcPtr, stepCount * SizeOfYmm ); + gen.add( regDstPtr, stepCount * SizeOfYmm ); gen.StopDownCountLoop(); }; @@ -716,18 +1083,7 @@ void CPrimitivesJit::insertPrimitive( CJitC template<CPrimitivesJit::TPrimitive P, class PrimitiveFuncType, class... Args> inline void CPrimitivesJit::callPrimitive( size_t dataSize, bool isMultithread, Args... args ) { - // args - usually are different kind of pointers - using namespace Xbyak::util; - - CGenerator& genInst = gens[static_cast< size_t >( P )]; - PrimitiveFuncType func; - - genInst.lock.lock(); - if( genInst.gen.getSize() == 0 ) { - initPrimitive<P>(); - } - genInst.lock.unlock(); - func = genInst.gen.getCode<PrimitiveFuncType>(); + PrimitiveFuncType func = reinterpret_cast<PrimitiveFuncType>( GetFunctionRawPtr<P>() ); const int curThreadCount = isMultithread && IsOmpRelevant( static_cast< int >( dataSize ) ) ? threadCount : 1; if( curThreadCount != 1 ) { diff --git a/NeoMathEngine/src/CPU/x86/avx/src/PrimitivesJit.h b/NeoMathEngine/src/CPU/x86/avx/src/PrimitivesJit.h index 6decb392e..eb6c7a75e 100644 --- a/NeoMathEngine/src/CPU/x86/avx/src/PrimitivesJit.h +++ b/NeoMathEngine/src/CPU/x86/avx/src/PrimitivesJit.h @@ -30,19 +30,20 @@ class CPrimitivesJit { public: - CPrimitivesJit( IMathEngine* _mathEngine, int _threadCount ); - - void Tanh( float* dst, const float* src, size_t dataSize, bool isMultithread = true ); - void Sigmoid( float* dst, const float* src, size_t dataSize, bool isMultithread = true ); - void Exp( float* dst, const float* src, size_t dataSize, bool isMultithread = true ); - - // Process part of lstm layer which follow after fullyconnected layers. - void RestOfLstm( CMathEngineLstmDesc* desc, const CConstFloatHandle& inputStateBackLink, - const CFloatHandle& outputStateBackLink, const CFloatHandle& outputMainBackLink, - bool isMultithread ); - -private: enum class TPrimitive { + VectorAdd, + VectorAlignedAdd, + VectorMax, + VectorReLU, + VectorReLUTreshold, + VectorAlignedMultiplyAndAdd, + VectorMultiply, + VectorEltwiseMultiply, + VectorEltwiseMultiplyAdd, + VectorAddValue, + VectorDotProduct, + VectorMinMax, + Tanh, Sigmoid, Exp, @@ -51,6 +52,29 @@ class CPrimitivesJit { Count }; + CPrimitivesJit( IMathEngine* _mathEngine, int _threadCount ); + + void Tanh( float* dst, const float* src, size_t dataSize, bool isMultithread ); + void Sigmoid( float* dst, const float* src, size_t dataSize, bool isMultithread ); + void Exp( float* dst, const float* src, size_t dataSize, bool isMultithread ); + + // Process the part of the lstm layer which follows the fully connected layers. + void RestOfLstm( CMathEngineLstmDesc* desc, const CConstFloatHandle& inputStateBackLink, + const CFloatHandle& outputStateBackLink, const CFloatHandle& outputMainBackLink, + bool isMultithread ); + + template<TPrimitive P> + void* GetFunctionRawPtr() { + CGenerator& genInst = gens[static_cast< size_t >( P )]; + genInst.lock.lock(); + if( genInst.gen.getSize() == 0 ) { + initPrimitive<P>(); + } + genInst.lock.unlock(); + return static_cast<void*>( const_cast<uint8_t*>( genInst.gen.getCode() ) ); + } + +private: enum class TTableKey { // Tanh specific items TanhPolyCoeff, // Coefficients of tanh polynomial @@ -83,7 +107,8 @@ }; static constexpr int MantissaNumBits = 23; - + // Last pointer is always result + using EltwiseFunc = void( * )( const float* op1, const float* op2, float* res, size_t count ); using ActivationFunc = void( * )( float* dst, const float* src, size_t offset, size_t count ); using RestOfLstmFunc = void( * )( size_t hiddenSize, const float* inputStateBackLinkPtr, float* outputStateBackLinkPtr, float* outputMainBackLinkPtr, float* inputFullyConnectedResultPtr, float* reccurentFullyConnectedResultPtr, size_t offset, size_t count ); @@ -109,7 +134,32 @@ void addVector( TTableKey key, std::initializer_list&& data, size_t repeatNum = 1 ); // repeatNum specifies how many times value will be repeated in the table void addVal( TTableKey key, uint32_t val, size_t repeatNum = NumFloatInYmm ); - + + using EltwiseGenFunc = void( CJitCommon::* )( const Xbyak::Xmm&, const Xbyak::Operand&, const Xbyak::Operand& ); + + static EltwiseGenFunc GetEltwiseFuncPtr( TPrimitive p ) { + switch( p ) { + case TPrimitive::VectorAdd: + case TPrimitive::VectorAlignedAdd: + case TPrimitive::VectorAddValue: + return static_cast<EltwiseGenFunc>( &CJitCommon::vaddps ); + case TPrimitive::VectorMax: + return static_cast<EltwiseGenFunc>( &CJitCommon::vmaxps ); + case TPrimitive::VectorMultiply: + case TPrimitive::VectorEltwiseMultiply: + return static_cast<EltwiseGenFunc>( &CJitCommon::vmulps ); + default: + assert( false ); + return nullptr; + } + } + + void initEltwisePrimitive( TPrimitive P, bool hasOp2, bool op2IsScalar = false ); + void initMinMaxFunction( TPrimitive P, bool useLowerBound, bool useUpperBound ); + void insertSimpleMathFunction( const reg64Vec_t& preservedGPR, const ymmVec_t& preservedYmm, + CJitCommon& gen, const reg64_t& regCount, + const std::function<void( unsigned int )>& insertKernel, const std::vector<unsigned int>& loopUnrollingSteps, + const std::function<void()>& callBeforeRet = std::function<void()>() ); template<TPrimitive P> void initPrimitive(); template @@ -141,8 +191,10 @@ template<class T> std::vector<T> initVecRange( int firstIdx, int lastIdx ) { const int VecSize = lastIdx - firstIdx + 1; - assert( VecSize > 0 ); - assert( firstIdx >= 0 && lastIdx < 16 ); + if( VecSize <= 0 ) { + return {}; + } + assert( lastIdx < 16 ); std::vector<T> ret( VecSize ); int idx = firstIdx; for( auto& v : ret ) {
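For readers following the dispatch change above: the core of this patch is resolving the vector kernels to plain function pointers once, at engine construction, so every call site pays only a single indirect call — a JIT-generated AVX kernel when a SIMD engine is available, a scalar fallback otherwise. Below is a minimal, self-contained C++ sketch of that pattern; all names here (ISimdProvider, CEngine, scalarVectorAdd) are illustrative stand-ins, not the NeoML API.

#include <cstddef>
#include <iostream>

// Signature shared by the scalar fallback and the SIMD-provided kernel.
using VectorAddFunc = void (*)( const float* first, const float* second, float* result, int vectorSize );

// Scalar fallback, playing the role of NeoML::vectorAdd.
static void scalarVectorAdd( const float* first, const float* second, float* result, int vectorSize )
{
    for( int i = 0; i < vectorSize; ++i ) {
        result[i] = first[i] + second[i];
    }
}

// Stand-in for ISimdMathEngine: a provider that may hand out an optimized kernel.
struct ISimdProvider {
    virtual ~ISimdProvider() = default;
    virtual VectorAddFunc GetVectorAddFunc() = 0;
};

// The engine resolves the pointer once in its constructor, mirroring CCpuMathEngine.
class CEngine {
public:
    explicit CEngine( ISimdProvider* simd )
    {
        vectorAdd = ( simd != nullptr ) ? simd->GetVectorAddFunc() : &scalarVectorAdd;
    }

    void VectorAdd( const float* a, const float* b, float* r, int n ) const
    {
        vectorAdd( a, b, r, n ); // one indirect call; no per-call capability checks
    }

private:
    VectorAddFunc vectorAdd;
};

int main()
{
    float a[4] = { 1, 2, 3, 4 }, b[4] = { 10, 20, 30, 40 }, r[4];
    CEngine engine( nullptr ); // no SIMD provider -> scalar fallback
    engine.VectorAdd( a, b, r, 4 );
    std::cout << r[0] << ' ' << r[3] << '\n'; // prints: 11 44
}

The design choice this sketch captures: selecting the implementation once at construction keeps the hot loops free of branching and lets the same call sites transparently pick up the JIT-generated AVX primitives when CAvxMathEngine is loaded.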