Skip to content

Commit

Permalink
Fixed a bug in RegPar derivs and enabled fixed ref data for variance …
Browse files Browse the repository at this point in the history
…obj (#548)

* Fixed a bug for the regPar total and added constant ref data for variance obj.

* Updated the test ref.
  • Loading branch information
friedenhe authored Dec 31, 2023
1 parent fae63b7 commit f6d9a19
Show file tree
Hide file tree
Showing 9 changed files with 307 additions and 42 deletions.
31 changes: 18 additions & 13 deletions dafoam/pyDAFoam.py
Original file line number Diff line number Diff line change
Expand Up @@ -2275,7 +2275,7 @@ def calcTotalDerivsACT(self, objFuncName, designVarName, designVarType, dFScalin
else:
self.adjTotalDeriv[objFuncName][designVarName][i] = totalDerivSeq[i]

def calcTotalDerivsRegPar(self, objFuncName, designVarName, accumulateTotal=False):
def calcTotalDerivsRegPar(self, objFuncName, designVarName, dFScaling=1.0, accumulateTotal=False):

xDV = self.DVGeo.getValues()
nDVs = len(xDV[designVarName])
Expand All @@ -2284,34 +2284,39 @@ def calcTotalDerivsRegPar(self, objFuncName, designVarName, accumulateTotal=Fals
if nDVs != nParameters:
raise Error("number of parameters not valid!")

# We assume dFdRegPar is always zero
# call the total deriv

xvArray = self.vec2Array(self.xvVec)
wArray = self.vec2Array(self.wVec)
seedArray = self.vec2Array(self.adjVectors[objFuncName])
parameters = xDV[designVarName].copy(order="C")
productArray = np.zeros(nDVs)
totalDerivArray = np.zeros(nDVs)
dFdRegPar = np.zeros(nDVs)

# calc dFdRegPar
self.solverAD.calcdFdRegParAD(
xvArray, wArray, parameters, objFuncName.encode(), designVarName.encode(), dFdRegPar
)
dFdRegPar *= dFScaling

# calculate dRdFieldT*Psi and save it to totalDeriv
self.solverAD.calcdRdRegParTPsiAD(xvArray, wArray, parameters, seedArray, productArray)
self.solverAD.calcdRdRegParTPsiAD(xvArray, wArray, parameters, seedArray, totalDerivArray)
# all reduce because parameters is a global DV
productArray = self.comm.allreduce(productArray, op=MPI.SUM)
totalDerivArray = self.comm.allreduce(totalDerivArray, op=MPI.SUM)

# totalDeriv = dFdRegPar - dRdRegParT*psi
productArray *= -1.0
totalDerivArray = dFdRegPar - totalDerivArray

# assign the total derivative to self.adjTotalDeriv
if self.adjTotalDeriv[objFuncName][designVarName] is None:
self.adjTotalDeriv[objFuncName][designVarName] = np.zeros(nDVs, self.dtype)

# NOTE: productArray is already in Seq
# NOTE: totalDerivArray is already in Seq because we have called all reduce in dFdRegPar
# and after calcdRdRegParTPsiAD

for i in range(nDVs):
if accumulateTotal is True:
self.adjTotalDeriv[objFuncName][designVarName][i] += productArray[i]
self.adjTotalDeriv[objFuncName][designVarName][i] += totalDerivArray[i]
else:
self.adjTotalDeriv[objFuncName][designVarName][i] = productArray[i]
self.adjTotalDeriv[objFuncName][designVarName][i] = totalDerivArray[i]

def solveAdjointUnsteady(self):
"""
Expand Down Expand Up @@ -2514,7 +2519,7 @@ def solveAdjointUnsteady(self):
fieldType = designVarDict[designVarName]["fieldType"]
self.calcTotalDerivsField(objFuncName, designVarName, fieldType, dFScaling, True)
elif designVarDict[designVarName]["designVarType"] == "RegPar":
self.calcTotalDerivsRegPar(objFuncName, designVarName, True)
self.calcTotalDerivsRegPar(objFuncName, designVarName, dFScaling, True)
else:
raise Error("designVarType not valid!")

Expand Down Expand Up @@ -4064,7 +4069,7 @@ def _initOption(self, name, value):
"Expected data type is %-47s \n "
"Received data type is %-47s" % (name, self.defaultOptions[name][0], type(value))
)

def setRegressionParameter(self, idx, val):
"""
Update the regression parameters
Expand Down
61 changes: 51 additions & 10 deletions src/adjoint/DAObjFunc/DAObjFuncVariance.C
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ DAObjFuncVariance::DAObjFuncVariance(
objFuncDict_.readEntry<labelList>("components", components_);
}

timeDependentRefData_ = objFuncDict_.getLabel("timeDependentRefData");

if (daIndex.adjStateNames.found(varName_))
{
objFuncConInfo_ = {{varName_}};
Expand All @@ -75,14 +77,29 @@ DAObjFuncVariance::DAObjFuncVariance(
scalar deltaT = mesh_.time().deltaT().value();
label nTimeSteps = round(endTime / deltaT);

word checkRefDataFolder;
label nRefValueInstances;
if (timeDependentRefData_)
{
// check if we can find the ref data in the endTime folder
checkRefDataFolder = Foam::name(endTime);
nRefValueInstances = nTimeSteps;
}
else
{
// check if we can find the ref data in the 0 folder
checkRefDataFolder = Foam::name(0);
nRefValueInstances = 1;
}

// check if the reference data files exist
isRefData_ = 1;
if (varType_ == "scalar")
{
volScalarField varData(
IOobject(
varName_ + "Data",
Foam::name(endTime),
checkRefDataFolder,
mesh_,
IOobject::READ_IF_PRESENT,
IOobject::NO_WRITE),
Expand All @@ -100,7 +117,7 @@ DAObjFuncVariance::DAObjFuncVariance(
volVectorField varData(
IOobject(
varName_ + "Data",
Foam::name(endTime),
checkRefDataFolder,
mesh_,
IOobject::READ_IF_PRESENT,
IOobject::NO_WRITE),
Expand Down Expand Up @@ -134,15 +151,23 @@ DAObjFuncVariance::DAObjFuncVariance(
{
// varData file found, we need to read in the ref values for all time instances

refValue_.setSize(nTimeSteps);
refValue_.setSize(nRefValueInstances);

// set refValue
if (varType_ == "scalar")
{
for (label n = 0; n < nTimeSteps; n++)
for (label n = 0; n < nRefValueInstances; n++)
{
scalar t = (n + 1) * deltaT;
word timeName = Foam::name(t);
word timeName;
if (timeDependentRefData_)
{
scalar t = (n + 1) * deltaT;
timeName = Foam::name(t);
}
else
{
timeName = Foam::name(0);
}

volScalarField varData(
IOobject(
Expand Down Expand Up @@ -210,10 +235,18 @@ DAObjFuncVariance::DAObjFuncVariance(
}
else if (varType_ == "vector")
{
for (label n = 0; n < nTimeSteps; n++)
for (label n = 0; n < nRefValueInstances; n++)
{
scalar t = (n + 1) * deltaT;
word timeName = Foam::name(t);
word timeName;
if (timeDependentRefData_)
{
scalar t = (n + 1) * deltaT;
timeName = Foam::name(t);
}
else
{
timeName = Foam::name(0);
}

volVectorField varData(
IOobject(
Expand Down Expand Up @@ -339,7 +372,15 @@ void DAObjFuncVariance::calcObjFunc(

const objectRegistry& db = mesh_.thisDb();

label timeIndex = mesh_.time().timeIndex();
label timeIndex;
if (timeDependentRefData_)
{
timeIndex = mesh_.time().timeIndex();
}
else
{
timeIndex = 1;
}

if (varName_ == "wallShearStress")
{
Expand Down
3 changes: 3 additions & 0 deletions src/adjoint/DAObjFunc/DAObjFuncVariance.H
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ protected:
/// whether we find the reference data
label isRefData_;

/// whether the ref data is time dependent if yes we need data in all time folders otherwise get it from the 0 folder
label timeDependentRefData_;

/// DATurbulenceModel object
const DATurbulenceModel& daTurb_;

Expand Down
Loading

0 comments on commit f6d9a19

Please sign in to comment.