diff --git a/NeoML/include/NeoML/Dnn/Layers/FavorAttentionPerformerLayer.h b/NeoML/include/NeoML/Dnn/Layers/FavorAttentionPerformerLayer.h new file mode 100644 index 0000000000..8cfaf55fe0 --- /dev/null +++ b/NeoML/include/NeoML/Dnn/Layers/FavorAttentionPerformerLayer.h @@ -0,0 +1,97 @@ +/* Copyright © 2023-2024 ABBYY + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--------------------------------------------------------------------------------------------------------------*/ + +#pragma once + +#include + +namespace NeoML { + +struct CFavorAttentionDesc; + +// Computes FAVOR normalized self-attention. +// https://arxiv.org/pdf/2009.14794.pdf. +// +// Inputs: query, key, value +// Emulates equation: Output ~~ softmax( query * ( key )^T / normalizer ) * value +// +// output +// ^ +// | +// +---------------+ +// | F A V O R | <-- projection matrix +// | Attention | (random features) +// +---------------+ +// ^ ^ ^ +// | | | +// query key value +// +class NEOML_API CFavorAttentionPerformerLayer : public CBaseLayer { + NEOML_DNN_LAYER( CFavorAttentionPerformerLayer ) +public: + // Possible activation kernel transformations + enum class TAKernel { SoftMax = 0, ReLU = 1 }; + // Layer inputs numeration + enum TInput { TI_Q = 0, TI_K = 1, TI_V = 2 }; + // Constructs a random matrix Q using + enum class TRandomMaxrixStructMode { + QMatrix, // QR-factorization of a random 2D-tensor + GivensRotations // Givens random rotations + }; + static constexpr TRandomMaxrixStructMode StructMode = TRandomMaxrixStructMode::GivensRotations; + // For normalization of a random matrix Q use sum of rows' norms of a random matrix, or just =sqrt(dim) + static constexpr bool Scaling = false; + + // Constructor + CFavorAttentionPerformerLayer( IMathEngine& mathEngine, const char* name = nullptr ); + + // The projection matrix columns size if it is used, or 0 if not + // Set to 0, if the projection matrix should not be used + int GetRandomFeaturesCount() const { return randomFeaturesCount; } + void SetRandomFeaturesCount( int randomFeaturesCount ); + // The activation kernel transformations is used + int GetActivationKernel() const { return static_cast( activation ); } + void SetActivationKernel( int activation ); + // The auto-regressive attention is used or not + bool GetCausal() const { return causal; } + void SetCausal( bool causal ); + + void Serialize( CArchive& archive ) override; + +protected: + ~CFavorAttentionPerformerLayer(); + + // Create output blobs using the input blobs + void Reshape() override; + // One step of a forward pass + void RunOnce() override; + // One step of a backward pass + void BackwardOnce() override; + +private: + // Number of random features to be used + // For SoftMax should be > 0, the random projection matrix should be applied + int randomFeaturesCount = 0; + TAKernel activation = TAKernel::SoftMax; // Activation Kernel type + bool causal = false; // Auto-regressive attention or not + CFavorAttentionDesc* desc = nullptr; // Favor Attention desctiption + + void destroyFavorAttentionDesc(); +}; + +NEOML_API CLayerWrapper FavorAttentionPerformer( + int randomFeaturesCount, int activation, bool causal ); + +} // namespace NeoML diff --git a/NeoML/include/NeoML/Dnn/Layers/MultiheadAttentionPerformerLayer.h b/NeoML/include/NeoML/Dnn/Layers/MultiheadAttentionPerformerLayer.h new file mode 100644 index 0000000000..dd5a87ccb4 --- /dev/null +++ b/NeoML/include/NeoML/Dnn/Layers/MultiheadAttentionPerformerLayer.h @@ -0,0 +1,101 @@ +/* Copyright © 2023-2024 ABBYY + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--------------------------------------------------------------------------------------------------------------*/ + +#pragma once + +#include +#include + +namespace NeoML { + +// Multihead Self Attention Performer +// https://arxiv.org/pdf/2009.14794.pdf +// Implementation of multiheaded FAVOR-attention & FAVOR-self-attention layers. +// +// +----------------------+--------+------------------------------------------------------- +// | Parameter | Type | Description +// +----------------------+--------+------------------------------------------------------- +// | HiddenSize | int | size of trainable matrices, output dim of hidden layer +// | HeadCount | int | number of heads to repeat the same attention structure +// | OutputSize | int | size of the output +// | ActivationKernel | int | activation (ReLU or SoftMax) kernel transformation +// | RandomFeaturesCount | int | projection matrix columns number, or 0 if isn't used +// | Casual | bool | auto-regressive attention is used or not +// +----------------------+--------+------------------------------------------------------- +class NEOML_API CMultiheadAttentionPerformerLayer : public CCompositeLayer { + NEOML_DNN_LAYER( CMultiheadAttentionPerformerLayer ) +public: + explicit CMultiheadAttentionPerformerLayer( IMathEngine& mathEngine ); + + // Activation kernel type: SoftMax(=0), ReLU(=1) + // By default is SoftMax + int GetActivationKernel() const { return activationKernel; } + void SetActivationKernel( int activationKernel, int randomFeaturesCount, bool casual ); + int GetRandomFeaturesCount() const { return randomFeaturesCount; } + bool GetCasual() const { return casual; } + + // The number of heads in attention + // The GetHiddenSize() must be a multiple of this value + // By default attention consist of 1 head + int GetHeadCount() const { return headCount; } + void SetHeadCount( int headCount ); + + // The size of trainable matrices + // Must be a multiple of GetHeadCount() + int GetHiddenSize() const { return hiddenSize; } + void SetHiddenSize( int hiddenSize ); + + // The size of output + int GetOutputSize() const { return outputSize; } + void SetOutputSize( int outputSize ); + + void Serialize( CArchive& archive ) override; + + // Recreates the layer if forceRebuild is true or it doesn't contain sublayers + void Rebuild( bool forceRebuild ); + +protected: + void Reshape() override; + +private: + // FAVOR+ attention settings + int activationKernel; // Activation kernel transformation + int randomFeaturesCount; // Projection matrix size, if > 0 + bool casual; // Auto-regression or not + + // The amount of heads + int headCount; + // The size of the trainable matrix + int hiddenSize; + // Output size + int outputSize; + + // Layer inputs numeration + enum TInputs { I_Q = 0, I_K = 1, I_V = 2 }; + + bool isCreated() const { return HasLayer( "Q" ); } + void create(); + + CBaseLayer* multiplyInputByMatrixWeights( int size, const char* name, TInputs input ); + CBaseLayer* multiplyByMatrixWeights( CBaseLayer* input, int width ); + CBaseLayer* prepareQ( CBaseLayer* input ); + CBaseLayer* prepareKV( CBaseLayer* input, bool isK ); + CBaseLayer* prepareOutput( CBaseLayer* input ); +}; + +NEOML_API CLayerWrapper MultiheadAttentionPerformer( + int headCount, int hiddenSize, int outputSize, int activationKernel, int randomFeaturesCount, bool casual ); + +} // namespace NeoML diff --git a/NeoML/include/NeoML/NeoML.h b/NeoML/include/NeoML/NeoML.h index e445c8c84f..50eff3ab59 100644 --- a/NeoML/include/NeoML/NeoML.h +++ b/NeoML/include/NeoML/NeoML.h @@ -1,4 +1,4 @@ -/* Copyright © 2017-2023 ABBYY +/* Copyright © 2017-2024 ABBYY Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -115,6 +115,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -130,6 +131,7 @@ limitations under the License. #include #include #include +#include #include #include #include diff --git a/NeoML/src/CMakeLists.txt b/NeoML/src/CMakeLists.txt index 5b53f9c019..f6063663de 100644 --- a/NeoML/src/CMakeLists.txt +++ b/NeoML/src/CMakeLists.txt @@ -117,6 +117,7 @@ set(NeoML_SOURCES Dnn/Layers/DotProductLayer.cpp Dnn/Layers/EnumBinarizationLayer.cpp Dnn/Layers/FocalLossLayer.cpp + Dnn/Layers/FavorAttentionPerformerLayer.cpp Dnn/Layers/FullyConnectedSourceLayer.cpp Dnn/Layers/GlobalMaxPoolingLayer.cpp Dnn/Layers/GlobalSumPoolingLayer.cpp @@ -132,6 +133,7 @@ set(NeoML_SOURCES Dnn/Layers/MaxOverTimePoolingLayer.cpp Dnn/Layers/MobileNetV3BlockLayer.cpp Dnn/Layers/ModelWrapperLayer.cpp + Dnn/Layers/MultiheadAttentionPerformerLayer.cpp Dnn/Layers/ObjectNormalizationLayer.cpp Dnn/Layers/Onnx/OnnxEltwiseLayer.cpp Dnn/Layers/Onnx/OnnxCastLayer.cpp @@ -377,6 +379,7 @@ set(NeoML_HEADERS ../include/NeoML/Dnn/Layers/DotProductLayer.h ../include/NeoML/Dnn/Layers/EnumBinarizationLayer.h ../include/NeoML/Dnn/Layers/FocalLossLayer.h + ../include/NeoML/Dnn/Layers/FavorAttentionPerformerLayer.h ../include/NeoML/Dnn/Layers/FullyConnectedSourceLayer.h ../include/NeoML/Dnn/Layers/GlobalMaxPoolingLayer.h ../include/NeoML/Dnn/Layers/GlobalSumPoolingLayer.h @@ -392,6 +395,7 @@ set(NeoML_HEADERS ../include/NeoML/Dnn/Layers/MaxOverTimePoolingLayer.h ../include/NeoML/Dnn/Layers/MobileNetV3BlockLayer.h ../include/NeoML/Dnn/Layers/ModelWrapperLayer.h + ../include/NeoML/Dnn/Layers/MultiheadAttentionPerformerLayer.h ../include/NeoML/Dnn/Layers/MultiHingeLossLayer.h ../include/NeoML/Dnn/Layers/ObjectNormalizationLayer.h ../include/NeoML/Dnn/Layers/Onnx/OnnxEltwiseLayer.h diff --git a/NeoML/src/Dnn/Dnn.cpp b/NeoML/src/Dnn/Dnn.cpp index 4f5fe19876..606dd5dfbb 100644 --- a/NeoML/src/Dnn/Dnn.cpp +++ b/NeoML/src/Dnn/Dnn.cpp @@ -1,4 +1,4 @@ -/* Copyright © 2017-2023 ABBYY +/* Copyright © 2017-2024 ABBYY Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -72,6 +72,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -88,6 +89,7 @@ limitations under the License. #include #include #include +#include #include #include #include @@ -349,6 +351,7 @@ REGISTER_NEOML_LAYER( CCtcDecodingLayer, "FmlCnnCtcDecodingLayer" ) REGISTER_NEOML_LAYER( CCtcLossLayer, "FmlCnnCtcLossLayer" ) REGISTER_NEOML_LAYER( CDotProductLayer, "FmlCnnDotProductLayer" ) REGISTER_NEOML_LAYER( CEnumBinarizationLayer, "FmlCnnEnumBinarizationLayer" ) +REGISTER_NEOML_LAYER( CFavorAttentionPerformerLayer, "NeoMLDnnFavorAttentionPerformerLayer" ) REGISTER_NEOML_LAYER( CGlobalMaxPoolingLayer, "FmlCnnGlobalMaxPoolingLayer" ) REGISTER_NEOML_LAYER( CGrnLayer, "NeoMLDnnGrnLayer" ) REGISTER_NEOML_LAYER( CGruLayer, "FmlCnnGruLayer" ) @@ -360,6 +363,7 @@ REGISTER_NEOML_LAYER( CLoraFullyConnectedLayer, "NeoMLDnnLoraFullyConnectedLayer REGISTER_NEOML_LAYER( CMaxOverTimePoolingLayer, "FmlCnnMaxOverTimePoolingLayer" ) REGISTER_NEOML_LAYER( CMobileNetV3PreSEBlockLayer, "NeoMLDnnMobileNetV3PreSEBlockLayer" ) REGISTER_NEOML_LAYER( CMobileNetV3PostSEBlockLayer, "NeoMLDnnMobileNetV3PostSEBlockLayer" ) +REGISTER_NEOML_LAYER( CMultiheadAttentionPerformerLayer, "NeoMLDnnMultiheadAttentionPerformerLayer" ) REGISTER_NEOML_LAYER( CMultiHingeLossLayer, "FmlCnnMultyHingeLossLayer" ) REGISTER_NEOML_LAYER( CMultiSquaredHingeLossLayer, "FmlCnnMultySquaredHingeLossLayer" ) REGISTER_NEOML_LAYER( CPixelToImageLayer, "FmlCnnPixelToImageLayerClass" ) diff --git a/NeoML/src/Dnn/Layers/FavorAttentionPerformerLayer.cpp b/NeoML/src/Dnn/Layers/FavorAttentionPerformerLayer.cpp new file mode 100644 index 0000000000..9b4e942be0 --- /dev/null +++ b/NeoML/src/Dnn/Layers/FavorAttentionPerformerLayer.cpp @@ -0,0 +1,646 @@ +/* Copyright © 2023-2024 ABBYY + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--------------------------------------------------------------------------------------------------------------*/ + +#include +#pragma hdrstop + +#include +#include +#include +#include +#include + +namespace NeoML { + +// Very big value for normalization's denominator +static constexpr float bigConstant = 1e8f; +// Small positive constant for numerical stability, used to bound away from zero kernel values +static constexpr float numericalStabilizer = 0.001f; +// M_PI value from math.h +static constexpr double piValue = 3.14159265358979323846; + +//--------------------------------------------------------------------------------------------------------------------- + +// Favor Attention Performer descriptor +struct CFavorAttentionDesc final { +public: + CFavorAttentionDesc( IMathEngine& mathEngine, + const CBlobDesc& qDesc, const CBlobDesc& kDesc, const CBlobDesc& vDesc, const CBlobDesc& outputDesc, + int randomFeaturesCount, CFavorAttentionPerformerLayer::TAKernel activation, bool causal ); + + // Computes FAVOR normalized attention + void FavorAttention( const CConstFloatHandle& query, const CConstFloatHandle& key, const CConstFloatHandle& value, + const CFloatHandle& output ); + // Computes gradient of FAVOR normalized attention + void FavorAttentionBackward( + const CConstFloatHandle& query, const CConstFloatHandle& key, const CConstFloatHandle& value, + const CFloatHandle& queryDiff, const CFloatHandle& keyDiff, const CFloatHandle& valueDiff, + const CConstFloatHandle& output ); + +private: + IMathEngine& mathEngine; + + const int B; + const int LQ; // seqQ + const int L; // seqTo + const int H; + const int M; + const int D; + + const int dim; // number of rows and columns of the resulting 2D-tensor + const int randomFeaturesCount; // Number of random features to be used (relevant only if projection_matrix set) + const CFavorAttentionPerformerLayer::TAKernel activation; // Transformation produces kernel features for attention + const bool causal; // whether attention is auto-regressive or not + const bool projectionMatrixType; // Either random projection matrix will be applied (for SoftMax should be true) + + CFloatHandleVar projectionTransposedMatrix; + CFloatHandleVar ratio; + CFloatHandleVar reluUpperThreshold; + CFloatHandleVar dataNormalizer; + CFloatHandleVar numericalStabilizer; + CFloatHandleVar constHalf; + + // Computes random features for the ReLU-kernel from + // Args: + // input: input data tensor of the shape [B,H,L,D], where: + // B - batch dimension, H - heads, L - attention dimensions, D - features. + // output: corresponding kernel feature map. + // projection: random Gaussian matrix of shape [M,D], where M stands + // for the number of random features and each DxD sub-block has pairwise orthogonal rows. + // Returns corresponding kernel feature map. + void reluKernel( const CConstFloatHandle& in, const CFloatHandle& out, const CConstFloatHandle* proj, bool isQuery ); + void reluKernelBackward( const CConstFloatHandle& input, CConstFloatHandle& outputDiff, const CFloatHandle& inputDiff, + const CConstFloatHandle* projection, bool isQuery ); + // Computes random features for the softmax kernel using FAVOR+ mechanism from + // Args: + // input: input data tensor of the shape [B,H,L,D], where: + // B - batch dimension, H - heads, L - attention dimensions, D - features. + // output: corresponding kernel feature map. + // projection: random Gaussian matrix of shape [M,D], where M stands + // for the number of random features and each DxD sub-block has pairwise orthogonal rows. + // query: indicates whether input data is a query or key tensor. + void softmaxKernel( const CConstFloatHandle& in, const CFloatHandle& out, const CConstFloatHandle* proj, bool isQuery ); + void softmaxKernelBackward( const CConstFloatHandle& input, CConstFloatHandle& outputDiff, const CFloatHandle& inputDiff, + const CConstFloatHandle* projection, bool isQuery ); + + // Constructs random Q matrix of QR-factorization random a 2D-tensor, initilizaed by normal distribution. + // Args : + // Q -- returned result matrix + // seed -- randomization initialization + // mean and sigma -- parameters of generation float random values + void createFromUnstructuredBlock( const CFloatHandle& Q, int seed, float mean = 0.f, float sigma = 1.f ); + // Constructs a 2D-tensor which is a product of the form G_1 * ... * G_k, where G_i is a Givens random rotation. + // The resulting tensor mimics a matrix taken uniformly at random form the orthogonal group. + // Args : + // Q -- returned result matrix + // seed -- randomization initialization + // min and max -- parameters of generation float random values + void createProductsOfGivensRotations( const CFloatHandle& Q, int seed, float min = FLT_MIN, float max = FLT_MAX ); + // + void createMultiplier( const CFloatHandle& multiplier, bool scaling, int seed, float mean = 0.f, float sigma = 1.f ); + // Constructs a matrix of random orthogonal projections. + // Each projection vector has direction chosen uniformly at randomand either deterministic length \sqrt{ d } + // or length taken from the \chi( d ) distribution (in the latter case marginal distributions of the projections are + // d - dimensional Gaussian vectors with associated identity covariance matrix ). + // Args: + // m: number of random projections. + // d: dimensionality of each random projection. + // seed: random seed used to construct projections. + // scaling: True if all the random projections need to be renormalized to have + // length \sqrt{ d }, False if the lengths of random projections should follow \chi( d ) distribution. + // structMode: if True then products of Givens rotations will be used to construct random orthogonal matrix. + // This bypasses Gram-Schmidt orthogonalization. + // Returns the matrix of random projections of the shape[m, d]. + void createProjectionMatrix( const CFloatHandle& projectionMatrixTransposed, int seed = 0, + bool scaling = CFavorAttentionPerformerLayer::Scaling, + CFavorAttentionPerformerLayer::TRandomMaxrixStructMode structMode = CFavorAttentionPerformerLayer::StructMode ); + + // Computes not-normalized FAVOR noncausal attention AV. + // Args: query and key tensors of a shape [B,H,M,L] and value tensor of a shape [B,H,D,L]. + // Returns Not - normalized FAVOR noncausal attention AV of a shape [B,H,M,D]. + void nonCausalNumerator( const CConstFloatHandle& qs, const CConstFloatHandle& ks, const CConstFloatHandle& vs, + const CFloatHandle& result, const CFloatHandle& temp ); + void nonCausalNumeratorBackward( const CConstFloatHandle& qs, const CConstFloatHandle& ks, const CConstFloatHandle& vs, + const CFloatHandle& result, const CFloatHandle& temp ); + + // Computes FAVOR normalizer in noncausal attention. + // Args: query and key tensors of a shape [B,H,M,L]. + // Returns FAVOR normalizer in noncausal attention of a shape [B,H,L]. + void nonCausalDenominator( const CConstFloatHandle& qs, const CConstFloatHandle& ks, + const CFloatHandle& result, const CFloatHandle& temp ); + void nonCausalDenominatorBackward( const CConstFloatHandle& qs, const CConstFloatHandle& ks, + const CFloatHandle& result, const CFloatHandle& temp ); + + //Computes not-normalized FAVOR causal attention A_{masked}V. + // Args: query and key tensors of a shape [B,H,M,L] and value tensor of a shape [B,H,D,L]. + // Returns Not - normalized FAVOR causal attention A_{masked}V of a shape [B,H,M,D]. + void causalNumerator( const CConstFloatHandle& qs, const CConstFloatHandle& ks, const CConstFloatHandle& vs, + const CFloatHandle& result, const CFloatHandle& temp ); + void causalNumeratorBackward( const CConstFloatHandle& qs, const CConstFloatHandle& ks, const CConstFloatHandle& vs, + const CFloatHandle& result, const CFloatHandle& temp ); + + //Computes FAVOR normalizer in causal attention + // Args: query and key tensors of a shape [B,H,M,L]. + // Returns FAVOR normalizer in causal attention of a shape [B,H,L]. + void causalDenominator( const CConstFloatHandle& qs, const CConstFloatHandle& ks, + const CFloatHandle& result, const CFloatHandle& temp ); + void causalDenominatorBackward( const CConstFloatHandle& qs, const CConstFloatHandle& ks, + const CFloatHandle& result, const CFloatHandle& temp ); +}; + +//--------------------------------------------------------------------------------------------------------------------- + +CFavorAttentionDesc::CFavorAttentionDesc( IMathEngine& mathEngine, + const CBlobDesc& qDesc, const CBlobDesc& kDesc, const CBlobDesc& vDesc, const CBlobDesc& outputDesc, + int randomFeaturesCount, CFavorAttentionPerformerLayer::TAKernel activation, bool causal ) : + mathEngine( mathEngine ), + B( qDesc.BatchWidth() ), + LQ( qDesc.Width() ), // seqQ + L( vDesc.Width() ), // seqTo + H( qDesc.ListSize() ), //heads_count + M( qDesc.Channels() ), // data_per_head + D( vDesc.Channels() ), // data_per_head + dim( M ), + randomFeaturesCount( randomFeaturesCount ), + activation( activation ), + causal( causal ), + projectionMatrixType( randomFeaturesCount > 0 ), + projectionTransposedMatrix( mathEngine, projectionMatrixType ? ( randomFeaturesCount * dim ) : 0 ), + ratio( mathEngine ), + reluUpperThreshold( mathEngine ), + dataNormalizer( mathEngine ), + numericalStabilizer( mathEngine ), + constHalf( mathEngine ) +{ + NeoAssert( B == kDesc.BatchWidth() ); + NeoAssert( B == vDesc.BatchWidth() ); + NeoAssert( H == kDesc.ListSize() ); + NeoAssert( H == vDesc.ListSize() ); + NeoAssert( L == kDesc.Width() ); + NeoAssert( M == kDesc.Channels() ); + + NeoAssert( outputDesc.BatchWidth() == B ); + NeoAssert( outputDesc.ListSize() == H ); + NeoAssert( outputDesc.Width() == LQ ); + NeoAssert( outputDesc.Channels() == D ); + + ratio.SetValue( static_cast( 1. / std::sqrt( randomFeaturesCount ) ) ); + reluUpperThreshold.SetValue( 0.f ); + dataNormalizer.SetValue( static_cast( 1. / std::sqrt( std::sqrt( dim ) ) ) ); + numericalStabilizer.SetValue( NeoML::numericalStabilizer ); + constHalf.SetValue( 0.5f ); +} + +void CFavorAttentionDesc::createFromUnstructuredBlock( const CFloatHandle& Q, int seed, float mean, float sigma ) +{ + CRandom random( seed ); + const int unstructuredBlockSize = dim * dim; + const CFloatHandle& unstructuredBlock = Q; + for( int i = 0; i < unstructuredBlockSize; ++i ) { + unstructuredBlock.SetValueAt( i, static_cast( random.Normal( mean, sigma ) ) ); + } + CFloatHandleStackVar qTransposed( mathEngine, static_cast( unstructuredBlockSize ) ); + CFloatHandle qTransposedHandle = qTransposed.GetHandle(); + mathEngine.QRFactorization( dim, dim, unstructuredBlock, &qTransposedHandle, + /*R*/nullptr, /*inplace*/false, /*returnQ*/true, /*returnR*/false ); + mathEngine.TransposeMatrix( /*batchSize*/1, qTransposedHandle, + /*height*/dim, /*mid*/1, /*width*/dim, /*channels*/1, Q, unstructuredBlockSize ); +} + +void CFavorAttentionDesc::createProductsOfGivensRotations( const CFloatHandle& Q, int seed, float min, float max ) +{ + CRandom random( seed ); + auto getQ = [&]( int i, int j ) -> float { return Q.GetValueAt( i * dim + j ); }; + auto setQ = [&]( int i, int j, float v ) { Q.SetValueAt( i * dim + j, v ); }; + + const int numGivensRotations = static_cast( dim * std::ceil( std::log( dim ) ) ); + for( int i = 0; i < numGivensRotations; ++i ) { + const float randomAngle = static_cast( piValue * random.Uniform( min, max ) ); + const float sinA = std::sin( randomAngle ); + const float cosA = std::cos( randomAngle ); + + const int randomIndexMin = random.UniformInt( 0, dim - 1 ); + const int randomIndexMax = random.UniformInt( randomIndexMin + 1, dim ); + for( int j = 0; j < dim; ++j ) { + const float tmpMin = cosA * getQ( randomIndexMin, j ) + sinA * getQ( randomIndexMax, j ); + const float tmpMax = -sinA * getQ( randomIndexMin, j ) + cosA * getQ( randomIndexMax, j ); + setQ( randomIndexMin, j, tmpMin ); + setQ( randomIndexMax, j, tmpMax ); + } + } +} + +void CFavorAttentionDesc::createMultiplier( const CFloatHandle& multiplier, bool scaling, int seed, float mean, float sigma ) +{ + if( scaling == true ) { + mathEngine.VectorFill( multiplier, static_cast( std::sqrt( dim ) ), randomFeaturesCount ); + } else if( scaling == false ) { + CFloatHandleStackVar tempUnstructuredBlock( mathEngine, static_cast( dim ) ); + CFloatHandle unstructuredBlock = tempUnstructuredBlock.GetHandle(); + + CArray values; + values.SetSize( dim ); + + CRandom random( seed ); + for( int feature = 0; feature < randomFeaturesCount; ++feature ) { + for( int i = 0; i < dim; ++i ) { + values[i] = static_cast( random.Normal( mean, sigma ) ); + } + mathEngine.DataExchangeRaw( unstructuredBlock, values.GetPtr(), dim ); + mathEngine.VectorEltwiseMultiply( unstructuredBlock, unstructuredBlock, unstructuredBlock, dim ); + mathEngine.VectorSum( unstructuredBlock, dim, multiplier + feature ); + } + mathEngine.VectorSqrt( multiplier, multiplier, randomFeaturesCount ); + } +} + +void CFavorAttentionDesc::createProjectionMatrix( const CFloatHandle& projectionMatrixTransposed, int seed, + bool scaling, CFavorAttentionPerformerLayer::TRandomMaxrixStructMode structMode ) +{ + const int m = randomFeaturesCount; + const size_t qSize = static_cast( dim ) * dim; + const int finalSize = m * dim; + CFloatHandleStackVar finalMatrix( mathEngine, static_cast( finalSize ) ); + + const int numFullBlocks = m / dim; + const int remainingRows = m - numFullBlocks * dim; + int current_seed = seed; + CFloatHandle Q = finalMatrix.GetHandle(); + for( int i = 0; i < numFullBlocks; ++i ) { + if( structMode == CFavorAttentionPerformerLayer::TRandomMaxrixStructMode::QMatrix ) { + createProductsOfGivensRotations( Q, seed ); + } else { + createFromUnstructuredBlock( Q, current_seed++ ); + } + Q += qSize; + } + + if( remainingRows > 0 ) { + CFloatHandleStackVar tempQVar( mathEngine, qSize ); + CFloatHandle tempQ = tempQVar.GetHandle(); + if( structMode == CFavorAttentionPerformerLayer::TRandomMaxrixStructMode::QMatrix ) { + createProductsOfGivensRotations( tempQ, seed ); + } else { + createFromUnstructuredBlock( tempQ, current_seed++ ); + } + mathEngine.VectorCopy( Q, tempQ, remainingRows * dim ); + } + + CFloatHandleStackVar projectionMatrix( mathEngine, static_cast( finalSize ) ); + { + CFloatHandleStackVar mulVal( mathEngine, static_cast( m ) ); + CFloatHandle multiplier = mulVal.GetHandle(); + createMultiplier( multiplier, scaling, current_seed ); + + mathEngine.MultiplyDiagMatrixByMatrix( multiplier, m, finalMatrix, dim, projectionMatrix, finalSize ); + } + mathEngine.TransposeMatrix( /*batchSize*/1, projectionMatrix, + /*height*/m, /*mid*/1, /*width*/dim, /*channels*/1, projectionMatrixTransposed, finalSize ); +} + +void CFavorAttentionDesc::nonCausalNumerator( const CConstFloatHandle& qs, const CConstFloatHandle& ks, + const CConstFloatHandle& vs, const CFloatHandle& result, const CFloatHandle& kvs ) +{ + mathEngine.MultiplyTransposedMatrixByMatrix( B * H, //bhlm,bhld->bhmd + ks, L, M, + vs, D, + kvs, B * H * M * D ); + mathEngine.MultiplyMatrixByMatrix( B * H, //bhlm,bhmd->bhld + qs, LQ, M, + kvs, D, + result, B * H * LQ * D ); +} + +void CFavorAttentionDesc::nonCausalNumeratorBackward( const CConstFloatHandle& /*qs*/, const CConstFloatHandle& /*ks*/, + const CConstFloatHandle& /*vs*/, const CFloatHandle& /*result*/, const CFloatHandle& /*temp*/ ) +{ + NeoAssert( false ); +} + +void CFavorAttentionDesc::nonCausalDenominator( const CConstFloatHandle& qs, const CConstFloatHandle& ks, + const CFloatHandle& result, const CFloatHandle& ksum ) +{ + //all_ones = tf.ones( [ks.shape[0]] ); //[L,B,H,M] + CFloatHandleStackVar allOnes( mathEngine, L ); + mathEngine.VectorFill( allOnes, 1, L ); + + CFloatHandle ksumPtr = ksum; + CConstFloatHandle ksPtr = ks; + for( int b = 0; b < B * H; ++b ) { + mathEngine.MultiplyTransposedMatrixByMatrix( /*batchSize*/1, //bhlm,l->bhm + ksPtr, L, M, + allOnes, 1, + ksumPtr, B * H * M ); + ksPtr += L * M; + ksumPtr += M; + } + mathEngine.MultiplyMatrixByMatrix( B * H, //bhlm,bhm->bhl + qs, LQ, M, + ksum, 1, + result, B * H * LQ ); +} + +void CFavorAttentionDesc::nonCausalDenominatorBackward( const CConstFloatHandle& /*qs*/, const CConstFloatHandle& /*ks*/, + const CFloatHandle& /*result*/, const CFloatHandle& /*temp*/ ) +{ + NeoAssert( false ); +} + +void CFavorAttentionDesc::causalNumerator( const CConstFloatHandle& /*qs*/, const CConstFloatHandle& /*ks*/, + const CConstFloatHandle& /*vs*/, const CFloatHandle& /*result*/, const CFloatHandle& ) +{ + NeoAssert( false ); // TODO +} + +void CFavorAttentionDesc::causalNumeratorBackward( const CConstFloatHandle& /*qs*/, const CConstFloatHandle& /*ks*/, + const CConstFloatHandle& /*vs*/, const CFloatHandle& /*result*/, const CFloatHandle& /*temp*/ ) +{ + NeoAssert( false ); +} + +void CFavorAttentionDesc::causalDenominator( const CConstFloatHandle& /*qs*/, const CConstFloatHandle& /*ks*/, + const CFloatHandle& /*result*/, const CFloatHandle& ) +{ + NeoAssert( false ); +} + +void CFavorAttentionDesc::causalDenominatorBackward( const CConstFloatHandle& /*qs*/, const CConstFloatHandle& /*ks*/, + const CFloatHandle& /*result*/, const CFloatHandle& /*temp*/ ) +{ + NeoAssert( false ); +} + +void CFavorAttentionDesc::FavorAttention( + const CConstFloatHandle& query, const CConstFloatHandle& key, const CConstFloatHandle& value, + const CFloatHandle& output ) +{ + CConstFloatHandle projectionTransposedConstHandle = projectionTransposedMatrix.GetHandle(); + CConstFloatHandle* projectionTransposed = nullptr; + if( projectionMatrixType == true ) { + CFloatHandleStackVar reduce( mathEngine ); + mathEngine.VectorSum( query, B * H * LQ * M, reduce ); + const int seed = static_cast( std::ceil( std::abs( reduce.GetValue() * NeoML::bigConstant ) ) ); + CFloatHandle projectionTransposedHandle = projectionTransposedMatrix.GetHandle(); + createProjectionMatrix( projectionTransposedHandle, seed ); + projectionTransposed = &projectionTransposedConstHandle; + } + + CFloatHandleStackVar tempQueryPrime( mathEngine, B * H * LQ * M ); + CFloatHandleStackVar tempKeyPrime( mathEngine, B * H * L * M ); + + CFloatHandle queryPrime = tempQueryPrime; + CFloatHandle keyPrime = tempKeyPrime; + + if( activation == CFavorAttentionPerformerLayer::TAKernel::SoftMax ) { + softmaxKernel( query, queryPrime, projectionTransposed, true ); //[B,H,L,M] + softmaxKernel( key, keyPrime, projectionTransposed, false ); //[B,H,L,M] + } else { + NeoAssert( activation == CFavorAttentionPerformerLayer::TAKernel::ReLU ); + reluKernel( query, queryPrime, projectionTransposed, true ); //[B,H,L,M] + reluKernel( key, keyPrime, projectionTransposed, false ); //[B,H,L,M] + } + + CFloatHandleStackVar temp( mathEngine, B * H * M * D ); + CFloatHandle attentionNorm = tempKeyPrime; + + if( causal ) { + causalNumerator( queryPrime, keyPrime, value, output, temp ); //[B,H,LQ,D] + causalDenominator( queryPrime, keyPrime, attentionNorm, temp ); //[B,H,LQ] + } else { + nonCausalNumerator( queryPrime, keyPrime, value, output, temp ); //[B,H,LQ,D] + nonCausalDenominator( queryPrime, keyPrime, attentionNorm, temp ); //[B,H,LQ] + } + mathEngine.MatrixColumnsEltwiseDivide( output, /*height*/( B * H * LQ ), /*width*/D, attentionNorm, output ); +} + +void CFavorAttentionDesc::FavorAttentionBackward( + const CConstFloatHandle& /*query*/, const CConstFloatHandle& /*key*/, const CConstFloatHandle& /*value*/, + const CFloatHandle& /*queryDiff*/, const CFloatHandle& /*keyDiff*/, const CFloatHandle& /*valueDiff*/, + const CConstFloatHandle& /*output*/ ) +{ + NeoAssert( false ); +} + +void CFavorAttentionDesc::reluKernel( const CConstFloatHandle& input, const CFloatHandle& output, + const CConstFloatHandle* projectionTransposed, bool isQuery ) +{ + const int LZ = isQuery ? LQ : L; + const int size = B * H * LZ * M; + if( projectionTransposed == nullptr ) { + mathEngine.VectorReLU( input, output, size, reluUpperThreshold ); + mathEngine.VectorAddValue( output, output, size, numericalStabilizer ); + } else { + CConstFloatHandle inputPtr = input; + CFloatHandle outputPtr = output; + for( int i = 0; i < B * H; ++i ) { + mathEngine.MultiplyMatrixByTransposedMatrix( /*batchSize*/1, //bhlm,dm->bhld + inputPtr, LZ, M, + *projectionTransposed, dim, + outputPtr, size ); + inputPtr += LZ * M; + outputPtr += LZ * dim; + } + mathEngine.VectorMultiply( output, output, size, ratio ); + mathEngine.VectorReLU( output, output, size, reluUpperThreshold ); + mathEngine.VectorAddValue( output, output, size, numericalStabilizer ); + } +} + +void CFavorAttentionDesc::reluKernelBackward( const CConstFloatHandle& input, CConstFloatHandle& outputDiff, + const CFloatHandle& inputDiff, const CConstFloatHandle* projectionTransposed, bool isQuery ) +{ + const int LZ = isQuery ? LQ : L; + const int size = B * H * LZ * M; + if( projectionTransposed == nullptr ) { + mathEngine.VectorReLUDiffOp( /*outputBlob*/input, outputDiff, inputDiff, size, reluUpperThreshold ); + } else { + NeoAssert( false ); + } +} + +void CFavorAttentionDesc::softmaxKernel( const CConstFloatHandle& input, const CFloatHandle& output, + const CConstFloatHandle* projectionTransposed, bool isQuery ) +{ + const int LZ = isQuery ? LQ : L; + const int size = B * H * LZ * M; + ASSERT_EXPR( projectionTransposed != nullptr ); + + CFloatHandleStackVar tempInput( mathEngine, size ); + mathEngine.VectorMultiply( input, tempInput, size, dataNormalizer ); + + CConstFloatHandle inputPtr = tempInput; + CFloatHandle outputPtr = output; + for( int i = 0; i < B * H; ++i ) { + mathEngine.MultiplyMatrixByTransposedMatrix( /*batchSize*/1, //bhlm,dm->bhld + inputPtr, LZ, M, + *projectionTransposed, dim, + outputPtr, size ); + inputPtr += LZ * M; + outputPtr += LZ * dim; + } + mathEngine.VectorEltwiseMultiply( tempInput, tempInput, tempInput, size ); + + const int reduceSize = size / M; + CFloatHandleStackVar tempDiag( mathEngine, reduceSize ); + mathEngine.VectorSumAlongDimension( tempInput, /*before*/reduceSize, /*size*/M, /*after*/1, tempDiag ); + mathEngine.VectorMultiply( tempDiag, tempDiag, reduceSize, constHalf ); + + CFloatHandleStackVar tempReduceMax( mathEngine, reduceSize ); + mathEngine.FindMaxValueInRows( output, reduceSize, M, tempReduceMax, reduceSize ); + if( !isQuery ) { + mathEngine.FindMaxValueInRows( tempReduceMax, B * H, L, tempReduceMax, reduceSize ); + mathEngine.SubVectorFromMatrixColumns( output, output, B * H, L * M, tempReduceMax ); + } else { + mathEngine.SubVectorFromMatrixColumns( output, output, reduceSize, M, tempReduceMax ); + } + mathEngine.SubVectorFromMatrixColumns( output, output, reduceSize, M, tempDiag ); + mathEngine.VectorExp( output, output, size ); + + mathEngine.VectorAddValue( output, output, size, numericalStabilizer ); + mathEngine.VectorMultiply( output, output, size, ratio ); +} + +void CFavorAttentionDesc::softmaxKernelBackward( const CConstFloatHandle& /*input*/, CConstFloatHandle& /*outputDiff*/, + const CFloatHandle& /*inputDiff*/, const CConstFloatHandle* /*projectionTransposed*/, bool /*isQuery*/ ) +{ + NeoAssert( false ); +} + +//--------------------------------------------------------------------------------------------------------------------- + +CFavorAttentionPerformerLayer::CFavorAttentionPerformerLayer( IMathEngine& mathEngine, const char* name ) : + CBaseLayer( mathEngine, ( name == nullptr ) ? "CDnnFavorAttentionPerformerLayer" : name, /*isLearnable*/false ) +{} + +CFavorAttentionPerformerLayer::~CFavorAttentionPerformerLayer() +{ + destroyFavorAttentionDesc(); +} + +void CFavorAttentionPerformerLayer::destroyFavorAttentionDesc() +{ + if( desc != nullptr ) { + delete desc; + desc = nullptr; + } +} + +static const int FavorAttentionPerformerLayerVersion = 0; + +void CFavorAttentionPerformerLayer::Serialize( CArchive& archive ) +{ + ( void ) archive.SerializeVersion( FavorAttentionPerformerLayerVersion ); + CBaseLayer::Serialize( archive ); + + archive.Serialize( randomFeaturesCount ); + archive.SerializeEnum( activation ); + archive.Serialize( causal ); + + if( archive.IsLoading() ) { + destroyFavorAttentionDesc(); + } +} + +void CFavorAttentionPerformerLayer::SetRandomFeaturesCount( int _randomFeaturesCount ) +{ + NeoAssert( _randomFeaturesCount >= 0 ); + if( randomFeaturesCount != _randomFeaturesCount ) { + randomFeaturesCount = _randomFeaturesCount; + destroyFavorAttentionDesc(); + } +} + +void CFavorAttentionPerformerLayer::SetActivationKernel( int _activation ) +{ + TAKernel newActivation = static_cast( _activation ); + NeoAssert( newActivation == TAKernel::SoftMax || newActivation == TAKernel::ReLU ); + + if( activation != newActivation ) { + activation = newActivation; + destroyFavorAttentionDesc(); + } +} + +void CFavorAttentionPerformerLayer::SetCausal( bool _causal ) +{ + if( causal != _causal ) { + causal = _causal; + destroyFavorAttentionDesc(); + } +} + +void CFavorAttentionPerformerLayer::Reshape() +{ + CheckInputs(); + CheckLayerArchitecture( GetInputCount() == 3, "Favor Attention layer inputs count should be 3" ); + CheckLayerArchitecture( GetOutputCount() == 1, "Favor Attention layer outputs count should be 1" ); + + // For each layer element there is a channel in the output blob + outputDescs[0] = inputDescs[TI_Q]; + outputDescs[0].SetDimSize( BD_Channels, inputDescs[TI_V].Channels() ); + + destroyFavorAttentionDesc(); +} + +void CFavorAttentionPerformerLayer::RunOnce() +{ + if( desc == nullptr ) { + desc = new CFavorAttentionDesc( MathEngine(), + inputBlobs[TI_Q]->GetDesc(), inputBlobs[TI_K]->GetDesc(), inputBlobs[TI_V]->GetDesc(), + outputBlobs[0]->GetDesc(), randomFeaturesCount, activation, causal ); + } + + CFloatHandle query = inputBlobs[TI_Q]->GetData(); // [B, n_head, seq_Q, d_k] + CFloatHandle key = inputBlobs[TI_K]->GetData(); // [B, n_head, seq_to, d_k] + CFloatHandle value = inputBlobs[TI_V]->GetData(); // [B, n_head, seq_to, d_k] + CFloatHandle output = outputBlobs[0]->GetData(); // [B, n_head, seq_Q, d_k] + + // Linearly project the query, key and value using different learned projections. + // Splitting heads is automatically done during the linear projections --> + // [batchSize, sequenceLength, headCount, sizePerHead] --> (names = [B,L,H,M]) + // (Tensor's [BatchWidth, ListSize, Width, Channels] inited) + desc->FavorAttention( query, key, value, output ); +} + +void CFavorAttentionPerformerLayer::BackwardOnce() +{ + NeoAssert( desc != nullptr ); + + CFloatHandle queryDiff = inputDiffBlobs[0]->GetData(); + CFloatHandle keyDiff = inputDiffBlobs[1]->GetData(); + CFloatHandle valueDiff = inputDiffBlobs[2]->GetData(); + CFloatHandle outputDiff = outputDiffBlobs[0]->GetData(); + + CConstFloatHandle query = inputBlobs[0]->GetData(); + CConstFloatHandle key = inputDiffBlobs[1]->GetData(); + CConstFloatHandle value = inputDiffBlobs[2]->GetData(); + + desc->FavorAttentionBackward( query, key, value, queryDiff, keyDiff, valueDiff, outputDiff ); +} + +//--------------------------------------------------------------------------------------------------------------------- + +NEOML_API CLayerWrapper FavorAttentionPerformer( + int randomFeaturesCount, int activation, bool causal ) +{ + return CLayerWrapper( "FavorAttentionPerformer", + [=]( CFavorAttentionPerformerLayer* layer ) { + layer->SetRandomFeaturesCount( randomFeaturesCount ); + layer->SetActivationKernel( activation ); + layer->SetCausal( causal ); + } ); +} + +} // namespace NeoML diff --git a/NeoML/src/Dnn/Layers/MultiheadAttentionLayer.cpp b/NeoML/src/Dnn/Layers/MultiheadAttentionLayer.cpp index 534c4e4cd6..59524c3407 100644 --- a/NeoML/src/Dnn/Layers/MultiheadAttentionLayer.cpp +++ b/NeoML/src/Dnn/Layers/MultiheadAttentionLayer.cpp @@ -1,4 +1,4 @@ -/* Copyright © 2017-2020 ABBYY Production LLC +/* Copyright © 2017-2024 ABBYY Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -45,42 +45,54 @@ void CMultiheadAttentionLayer::SetHeadCount( int _headCount ) { NeoAssert( _headCount >= 1 ); - headCount = _headCount; - DeleteAllLayers(); + if( headCount != _headCount ) { + headCount = _headCount; + DeleteAllLayers(); + } } void CMultiheadAttentionLayer::SetHiddenSize( int _hiddenSize ) { NeoAssert( _hiddenSize >= 1 ); - hiddenSize = _hiddenSize; - DeleteAllLayers(); + if( hiddenSize != _hiddenSize ) { + hiddenSize = _hiddenSize; + DeleteAllLayers(); + } } void CMultiheadAttentionLayer::SetDropoutRate( float _dropoutRate ) { - dropoutRate = _dropoutRate; - DeleteAllLayers(); + if( dropoutRate != _dropoutRate ) { + dropoutRate = _dropoutRate; + DeleteAllLayers(); + } } void CMultiheadAttentionLayer::SetUseMask( bool newValue ) { - useMask = newValue; - DeleteAllLayers(); + if( useMask != newValue ) { + useMask = newValue; + DeleteAllLayers(); + } } void CMultiheadAttentionLayer::SetMaskType( TMaskType _maskType ) { - maskType = _maskType; - DeleteAllLayers(); + if( maskType != _maskType ) { + maskType = _maskType; + DeleteAllLayers(); + } } void CMultiheadAttentionLayer::SetOutputSize( int _outputSize ) { NeoAssert( _outputSize > 0 ); - outputSize = _outputSize; - DeleteAllLayers(); + if( outputSize != _outputSize ) { + outputSize = _outputSize; + DeleteAllLayers(); + } } void CMultiheadAttentionLayer::SetCompatibilityMode( bool value ) diff --git a/NeoML/src/Dnn/Layers/MultiheadAttentionPerformerLayer.cpp b/NeoML/src/Dnn/Layers/MultiheadAttentionPerformerLayer.cpp new file mode 100644 index 0000000000..736f65a283 --- /dev/null +++ b/NeoML/src/Dnn/Layers/MultiheadAttentionPerformerLayer.cpp @@ -0,0 +1,321 @@ +/* Copyright © 2023-2024 ABBYY + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +--------------------------------------------------------------------------------------------------------------*/ + +#include +#pragma hdrstop + +#include +#include +#include +//#include +#include +#include +#include + +namespace NeoML { + +//--------------------------------------------------------------------------------------------------------------------- + +CMultiheadAttentionPerformerLayer::CMultiheadAttentionPerformerLayer( IMathEngine& mathEngine ) : + CCompositeLayer( mathEngine ), + activationKernel( 0 ), + randomFeaturesCount( 0 ), + casual( true ), + headCount( 1 ), + hiddenSize( 8 ), + outputSize( 8 ) +{} + +void CMultiheadAttentionPerformerLayer::SetActivationKernel( int _activationKernel, int _randomFeaturesCount, bool _casual ) +{ + NeoAssert( _activationKernel == 0 || _activationKernel == 1 ); + NeoAssert( _randomFeaturesCount >= 0 ); + + if( activationKernel != _activationKernel + || randomFeaturesCount != _randomFeaturesCount + || casual != _casual ) + { + activationKernel = _activationKernel; + randomFeaturesCount = _randomFeaturesCount; + casual = _casual; + + DeleteAllLayers(); + } +} + +void CMultiheadAttentionPerformerLayer::SetHeadCount( int _headCount ) +{ + NeoAssert( _headCount >= 1 ); + if( headCount != _headCount ) { + headCount = _headCount; + DeleteAllLayers(); + } +} + +void CMultiheadAttentionPerformerLayer::SetHiddenSize( int _hiddenSize ) +{ + NeoAssert( _hiddenSize >= 1 ); + if( hiddenSize != _hiddenSize ) { + hiddenSize = _hiddenSize; + DeleteAllLayers(); + } +} + +void CMultiheadAttentionPerformerLayer::SetOutputSize( int _outputSize ) +{ + NeoAssert( _outputSize > 0 ); + if( outputSize != _outputSize ) { + outputSize = _outputSize; + DeleteAllLayers(); + } +} + +static const int MultiheadAttentionPerformerLayerVersion = 0; + +void CMultiheadAttentionPerformerLayer::Serialize( CArchive& archive ) +{ + ( void ) archive.SerializeVersion( MultiheadAttentionPerformerLayerVersion ); + CCompositeLayer::Serialize( archive ); + + archive.Serialize( activationKernel ); + archive.Serialize( randomFeaturesCount ); + archive.Serialize( casual ); + archive.Serialize( headCount ); + archive.Serialize( hiddenSize ); + archive.Serialize( outputSize ); +} + +void CMultiheadAttentionPerformerLayer::Reshape() +{ + CheckInputs(); + CheckLayerArchitecture( GetInputCount() == 3, "MultiheadAttentionPerformer layer inputs count should be 3" ); + CheckLayerArchitecture( GetOutputCount() == 1, "MultiheadAttentionPerformer layer outputs count should be 1" ); + + if( !isCreated() ) { + create(); + } + CFullyConnectedLayer* Q = CheckCast( GetLayer( "Q" ) ); + const bool uninitializedWeights = ( Q->Weights() == nullptr ); + + CCompositeLayer::Reshape(); + + if( uninitializedWeights ) { + // Glorot initialization + + // For layers for linearly projecting the queries, keys, and values initialization + const int inputSize = inputDescs[I_Q].Channels(); + const float attentionGlorotLimit = static_cast( std::sqrt( 6.0 / ( inputSize + hiddenSize ) ) ); + CDnnUniformInitializer attentionInitializer( GetDnn()->Random(), -attentionGlorotLimit, attentionGlorotLimit ); + + attentionInitializer.InitializeLayerParams( *( Q->Weights() ), /*unused*/0 ); + attentionInitializer.InitializeLayerParams( *( Q->FreeTerms() ), /*unused*/0 ); + + CFullyConnectedLayer* K = CheckCast( GetLayer( "K" ) ); + attentionInitializer.InitializeLayerParams( *( K->Weights() ), /*unused*/0 ); + attentionInitializer.InitializeLayerParams( *( K->FreeTerms() ), /*unused*/0 ); + + CFullyConnectedLayer* V = CheckCast( GetLayer( "V" ) ); + attentionInitializer.InitializeLayerParams( *( V->Weights() ), /*unused*/0 ); + attentionInitializer.InitializeLayerParams( *( V->FreeTerms() ), /*unused*/0 ); + + // Output layer + const float outputGlorotLimit = static_cast( std::sqrt( 6.0 / ( hiddenSize + hiddenSize ) ) ); + CDnnUniformInitializer outputInitializer( GetDnn()->Random(), -outputGlorotLimit, outputGlorotLimit ); + CFullyConnectedLayer* Out = CheckCast( GetLayer( "Out.Dense" ) ); + outputInitializer.InitializeLayerParams( *( Out->Weights() ), /*unused*/0 ); + outputInitializer.InitializeLayerParams( *( Out->FreeTerms() ), /*unused*/0 ); + } +} + +// Recreates the layer if forceRebuild is true or it doesn't contain sublayers +void CMultiheadAttentionPerformerLayer::Rebuild( bool forceRebuild ) +{ + if( forceRebuild && isCreated() ) { + DeleteAllLayers(); + } + if ( !isCreated() ) { + create(); + } +} + +// Creates layer with new parameters +// Here and further blob sizes are shown as [BatchWidth, ListSize, Width, Channels] +void CMultiheadAttentionPerformerLayer::create() +{ + NeoAssert( headCount > 0 ); + NeoAssert( hiddenSize % headCount == 0 ); + + // Applying W_Q, W_K and W_V to the corresponding inputs + // [B, seq_Q, 1, hiddenSize] + CBaseLayer* Q = multiplyInputByMatrixWeights( hiddenSize, "Q", I_Q ); + + // [B, seq_to, 1, hiddenSize] + CBaseLayer* K = multiplyInputByMatrixWeights( hiddenSize, "K", I_K ); + CBaseLayer* V = multiplyInputByMatrixWeights( hiddenSize, "V", I_V ); + + // [B, n_head, seq_Q, d_k] + Q = prepareQ( Q ); + // [B, n_head, seq_to, d_k] + K = prepareKV( K, true ); + // [B, n_head, seq_to, d_k] + V = prepareKV( V, false ); + + CPtr favor = new CFavorAttentionPerformerLayer( MathEngine(), "favor" ); + favor->SetActivationKernel( activationKernel ); + favor->SetRandomFeaturesCount( randomFeaturesCount ); + favor->SetCausal( casual ); + favor->Connect( CFavorAttentionPerformerLayer::TI_Q, *Q ); + favor->Connect( CFavorAttentionPerformerLayer::TI_K, *K ); + favor->Connect( CFavorAttentionPerformerLayer::TI_V, *V ); + AddLayer( *favor ); + + // [B, seq_Q, 1, hidden_size] + CPtr output = prepareOutput( favor ); + output = multiplyByMatrixWeights( output, outputSize ); + + SetOutputMapping( /*O_Output*/0, *output ); +} + +// Multiplies input by trainable weights +CBaseLayer* CMultiheadAttentionPerformerLayer::multiplyInputByMatrixWeights( + int size, const char* name, TInputs input ) +{ + NeoAssert( size > 0 ); + + CPtr fcLayer = new CFullyConnectedLayer( MathEngine(), name ); + fcLayer->SetNumberOfElements( size ); + fcLayer->SetZeroFreeTerm( false ); + AddLayer( *fcLayer ); + + // Connect input with this sublayer + SetInputMapping( input, *fcLayer, 0 ); + + return fcLayer; +} + +// Multiplies by trainable weights +CBaseLayer* CMultiheadAttentionPerformerLayer::multiplyByMatrixWeights( CBaseLayer* input, int width ) +{ + NeoAssert( width >= 0 ); + NeoAssert( input != 0 ); + + CPtr fcLayer = new CFullyConnectedLayer( MathEngine(), "Out.Dense" ); + fcLayer->SetNumberOfElements( width ); + fcLayer->SetZeroFreeTerm( false ); + fcLayer->Connect( *input ); + AddLayer( *fcLayer ); + + return fcLayer; +} + +// [B, n_head, seq_Q, d_k] +CBaseLayer* CMultiheadAttentionPerformerLayer::prepareQ( CBaseLayer* input ) +{ + NeoAssert( input != 0 ); + + // [B, seq_Q, n_head, d_k] + CPtr reshape0 = new CTransformLayer( MathEngine() ); + reshape0->SetName( "Q.reshape0" ); + reshape0->Connect( *input ); + reshape0->SetDimensionRule( BD_BatchLength, CTransformLayer::O_Multiply, 1 ); + reshape0->SetDimensionRule( BD_BatchWidth, CTransformLayer::O_Multiply, 1 ); + reshape0->SetDimensionRule( BD_ListSize, CTransformLayer::O_Multiply, 1 ); + reshape0->SetDimensionRule( BD_Height, CTransformLayer::O_SetSize, 1 ); + reshape0->SetDimensionRule( BD_Width, CTransformLayer::O_SetSize, headCount ); + reshape0->SetDimensionRule( BD_Depth, CTransformLayer::O_SetSize, 1 ); + reshape0->SetDimensionRule( BD_Channels, CTransformLayer::O_SetSize, hiddenSize / headCount ); + AddLayer( *reshape0 ); + + // [B, n_head, seq_Q, d_k] + CPtr transpose0 = new CTransposeLayer( MathEngine() ); + transpose0->SetName( "Q.transpose0" ); + transpose0->SetTransposedDimensions( BD_ListSize, BD_Width ); + transpose0->Connect( *reshape0 ); + AddLayer( *transpose0 ); + + return transpose0; +} + +// [B, n_head, seq_to, d_k] +CBaseLayer* CMultiheadAttentionPerformerLayer::prepareKV( CBaseLayer* input, bool isK ) +{ + NeoAssert( input != 0 ); + + // [B, seq_to, n_head, d_k] + CPtr reshape0 = new CTransformLayer( MathEngine() ); + reshape0->SetName( isK ? "K.reshape0" : "V.reshape0" ); + reshape0->Connect( *input ); + reshape0->SetDimensionRule( BD_BatchLength, CTransformLayer::O_Multiply, 1 ); + reshape0->SetDimensionRule( BD_BatchWidth, CTransformLayer::O_Multiply, 1 ); + reshape0->SetDimensionRule( BD_ListSize, CTransformLayer::O_Multiply, 1 ); + reshape0->SetDimensionRule( BD_Height, CTransformLayer::O_SetSize, 1 ); + reshape0->SetDimensionRule( BD_Width, CTransformLayer::O_SetSize, headCount ); + reshape0->SetDimensionRule( BD_Depth, CTransformLayer::O_SetSize, 1 ); + reshape0->SetDimensionRule( BD_Channels, CTransformLayer::O_SetSize, hiddenSize / headCount ); + AddLayer( *reshape0 ); + + // [B, n_head, seq_to, d_k] + CPtr transpose0 = new CTransposeLayer( MathEngine() ); + transpose0->SetName( isK ? "K.transpose0" : "V.transpose0" ); + transpose0->SetTransposedDimensions( BD_ListSize, BD_Width ); + transpose0->Connect( *reshape0 ); + AddLayer( *transpose0 ); + + return transpose0; +} + +// [B, seq_Q, 1, hidden_size] +CBaseLayer* CMultiheadAttentionPerformerLayer::prepareOutput( CBaseLayer* input ) +{ + NeoAssert( input != 0 ); + + // [B, seq_Q, n_head, d_k] + CPtr transpose0 = new CTransposeLayer( MathEngine() ); + transpose0->SetName( "Out.transpose0.Out" ); + transpose0->SetTransposedDimensions( BD_ListSize, BD_Width ); + transpose0->Connect( *input ); + AddLayer( *transpose0 ); + + // [B, seq_Q, 1, hidden_size] + CPtr reshape0 = new CTransformLayer( MathEngine() ); + reshape0->SetName( "Out.reshape0.Out" ); + reshape0->Connect( *transpose0 ); + reshape0->SetDimensionRule( BD_BatchLength, CTransformLayer::O_Multiply, 1 ); + reshape0->SetDimensionRule( BD_BatchWidth, CTransformLayer::O_Multiply, 1 ); + reshape0->SetDimensionRule( BD_ListSize, CTransformLayer::O_Multiply, 1 ); + reshape0->SetDimensionRule( BD_Height, CTransformLayer::O_SetSize, 1 ); + reshape0->SetDimensionRule( BD_Width, CTransformLayer::O_SetSize, 1 ); + reshape0->SetDimensionRule( BD_Depth, CTransformLayer::O_SetSize, 1 ); + reshape0->SetDimensionRule( BD_Channels, CTransformLayer::O_SetSize, hiddenSize ); + AddLayer( *reshape0 ); + + return reshape0; +} + +//--------------------------------------------------------------------------------------------------------------------- + +CLayerWrapper MultiheadAttentionPerformer( + int headCount, int hiddenSize, int outputSize, int activationKernel, int randomFeaturesCount, bool casual ) +{ + return CLayerWrapper( "MultiheadAttentionPerformer", + [=]( CMultiheadAttentionPerformerLayer* result ) { + result->SetHeadCount( headCount ); + result->SetHiddenSize( hiddenSize ); + result->SetOutputSize( outputSize ); + result->SetActivationKernel( activationKernel, randomFeaturesCount, casual ); + } ); +} + +} // namespace NeoML diff --git a/NeoML/test/data/LayersSerializationTestData/NeoMLDnnFavorAttentionPerformerLayer.arch b/NeoML/test/data/LayersSerializationTestData/NeoMLDnnFavorAttentionPerformerLayer.arch new file mode 100644 index 0000000000..465042cbb6 Binary files /dev/null and b/NeoML/test/data/LayersSerializationTestData/NeoMLDnnFavorAttentionPerformerLayer.arch differ diff --git a/NeoML/test/data/LayersSerializationTestData/NeoMLDnnMultiheadAttentionPerformerLayer.arch b/NeoML/test/data/LayersSerializationTestData/NeoMLDnnMultiheadAttentionPerformerLayer.arch new file mode 100644 index 0000000000..02c2b0b731 Binary files /dev/null and b/NeoML/test/data/LayersSerializationTestData/NeoMLDnnMultiheadAttentionPerformerLayer.arch differ diff --git a/NeoML/test/src/DnnLayersSerializationTest.cpp b/NeoML/test/src/DnnLayersSerializationTest.cpp index e919650a35..4e37f3fc31 100644 --- a/NeoML/test/src/DnnLayersSerializationTest.cpp +++ b/NeoML/test/src/DnnLayersSerializationTest.cpp @@ -1,4 +1,4 @@ -/* Copyright © 2021-2023 ABBYY +/* Copyright © 2021-2024 ABBYY Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -1977,6 +1977,76 @@ GTEST_TEST( SerializeFromFile, MultiheadAttentionLayerSerialization ) // ==================================================================================================================== +// CMultiheadAttentionPerformerLayer + +#ifdef GENERATE_SERIALIZATION_FILES + +static void setSpecificParams( CMultiheadAttentionPerformerLayer& layer ) +{ + layer.SetHeadCount( 5 ); + layer.SetHiddenSize( 25 ); + layer.SetOutputSize( 123 ); + layer.SetActivationKernel( /*ReLU*/1, /*randomFeaturesCount*/7, /*casual*/true ); +} + +GTEST_TEST( SerializeToFile, MultiheadAttentionPerformerLayerSerialization ) +{ + serializeToFile( "NeoMLDnnMultiheadAttentionPerformerLayer" ); +} + +#endif // GENERATE_SERIALIZATION_FILES + +template<> +inline void checkSpecificParams( CMultiheadAttentionPerformerLayer& layer ) +{ + EXPECT_EQ( 5, layer.GetHeadCount() ); + EXPECT_EQ( 25, layer.GetHiddenSize() ); + EXPECT_EQ( 123, layer.GetOutputSize() ); + EXPECT_EQ( /*ReLU*/1, layer.GetActivationKernel() ); + EXPECT_EQ( 7, layer.GetRandomFeaturesCount() ); + EXPECT_EQ( true, layer.GetCasual() ); +} + +GTEST_TEST( SerializeFromFile, MultiheadAttentionPerformerLayerSerialization ) +{ + checkSerializeLayer( "NeoMLDnnMultiheadAttentionPerformerLayer" ); +} + +// ==================================================================================================================== + +// CFavorAttentionPerformerLayer + +#ifdef GENERATE_SERIALIZATION_FILES + +static void setSpecificParams( CFavorAttentionPerformerLayer& layer ) +{ + layer.SetRandomFeaturesCount( 5 ); + layer.SetActivationKernel( /*ReLU*/1 ); + layer.SetCausal( true ); +} + +GTEST_TEST( SerializeToFile, FavorAttentionPerformerLayerSerialization ) +{ + serializeToFile( "NeoMLDnnFavorAttentionPerformerLayer" ); +} + +#endif // GENERATE_SERIALIZATION_FILES + +template<> +inline void checkSpecificParams( CFavorAttentionPerformerLayer& layer ) +{ + EXPECT_EQ( 5, layer.GetRandomFeaturesCount() ); + EXPECT_EQ( /*ReLU*/1, layer.GetActivationKernel() ); + EXPECT_EQ( true, layer.GetCausal() ); +} + +GTEST_TEST( SerializeFromFile, FavorAttentionPerformerLayerSerialization ) +{ + checkSerializeLayer( "NeoMLDnnFavorAttentionPerformerLayer" ); +} + +// ==================================================================================================================== + // CPositionalEmbeddingLayer #ifdef GENERATE_SERIALIZATION_FILES