Skip to content

Commit

Permalink
Added radialBasisFunc and ReLU for regModel (#619)
Browse files Browse the repository at this point in the history
* Added the ReLU activation func.

* Added the radialBasisFunc for regModel.
  • Loading branch information
friedenhe authored Apr 6, 2024
1 parent 1b56851 commit 35e0a4c
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 2 deletions.
71 changes: 69 additions & 2 deletions src/adjoint/DARegression/DARegression.C
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ DARegression::DARegression(
{
regSubDict.readEntry<labelList>("hiddenLayerNeurons", hiddenLayerNeurons_);
regSubDict.readEntry<word>("activationFunction", activationFunction_);
if (activationFunction_ == "ReLU")
{
leakyCoeff_ = regSubDict.lookupOrDefault<scalar>("leakyCoeff", 0.0);
}
}
else if (modelType_ == "radialBasisFunction")
{
nRBFs_ = regSubDict.getLabel("nRBFs");
}
else
{
FatalErrorIn("") << "modelType_: " << modelType_ << " not supported. Options are: neuralNetwork and radialBasisFunction" << abort(FatalError);
}

// initialize parameters and give it large values
Expand Down Expand Up @@ -323,9 +335,16 @@ label DARegression::compute()
{
layerVals[layerI][neuronI] = (1 - exp(-2 * layerVals[layerI][neuronI])) / (1 + exp(-2 * layerVals[layerI][neuronI]));
}
else if (activationFunction_ == "ReLU")
{
if (layerVals[layerI][neuronI] < 0)
{
layerVals[layerI][neuronI] = leakyCoeff_ * layerVals[layerI][neuronI];
}
}
else
{
FatalErrorIn("") << "activationFunction not valid. Options are: sigmoid and tanh" << abort(FatalError);
FatalErrorIn("") << "activationFunction not valid. Options are: sigmoid, tanh, and ReLU" << abort(FatalError);
}
}
}
Expand All @@ -350,9 +369,44 @@ label DARegression::compute()

outputField.correctBoundaryConditions();
}
else if (modelType_ == "radialBasisFunction")
{
List<List<scalar>> inputFields;
inputFields.setSize(inputNames_.size());

this->calcInput(inputFields);

label nInputs = inputNames_.size();

// increment of the parameters for each RBF basis
label dP = 2 * nInputs + 1;

forAll(mesh_.cells(), cellI)
{
scalar outputVal = 0.0;
for (label i = 0; i < nRBFs_; i++)
{
scalar expCoeff = 0.0;
for (label j = 0; j < nInputs; j++)
{
scalar A = (inputFields[j][cellI] - parameters_[dP * i + 2 * j]) * (inputFields[j][cellI] - parameters_[dP * i + 2 * j]);
scalar B = 2 * parameters_[dP * i + 2 * j + 1] * parameters_[dP * i + 2 * j + 1];
expCoeff += A / B;
}
outputVal += parameters_[(dP + 1) * i + dP] * exp(-expCoeff);
}

outputField[cellI] = outputScale_ * (outputVal + outputShift_);
}

// check if the output values are valid otherwise fix/bound them
fail = this->checkOutput(outputField);

outputField.correctBoundaryConditions();
}
else
{
FatalErrorIn("") << "modelType_: " << modelType_ << " not supported. Options are: neuralNetwork" << abort(FatalError);
FatalErrorIn("") << "modelType_: " << modelType_ << " not supported. Options are: neuralNetwork and radialBasisFunction" << abort(FatalError);
}

return fail;
Expand Down Expand Up @@ -397,6 +451,19 @@ label DARegression::nParameters()

return nParameters;
}
else if (modelType_ == "radialBasisFunction")
{
label nInputs = inputNames_.size();

// each RBF has a weight, nInputs mean, and nInputs std
label nParameters = nRBFs_ * (2 * nInputs + 1);

return nParameters;
}
else
{
FatalErrorIn("") << "modelType_: " << modelType_ << " not supported. Options are: neuralNetwork and radialBasisFunction" << abort(FatalError);
}
}

label DARegression::checkOutput(volScalarField& outputField)
Expand Down
6 changes: 6 additions & 0 deletions src/adjoint/DARegression/DARegression.H
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,9 @@ protected:
/// neural network activation function
word activationFunction_;

/// if the ReLU activation function is used we can prescribe a potentially leaky coefficient
scalar leakyCoeff_ = 0.0;

/// the upper bound for the output
scalar outputUpperBound_;

Expand All @@ -93,6 +96,9 @@ protected:
/// default output values
scalar defaultOutputValue_;

/// number of radial basis function
label nRBFs_;

public:
/// Constructors
DARegression(
Expand Down

0 comments on commit 35e0a4c

Please sign in to comment.