Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the FormSum memory leak #3897

Merged
merged 12 commits into from
Dec 6, 2024
6 changes: 4 additions & 2 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,10 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
return sum(weight * arg for weight, arg in zip(expr.weights(), args))
elif all(isinstance(op, firedrake.Cofunction) for op in args):
V, = set(a.function_space() for a in args)
res = sum([w*op.dat for (op, w) in zip(args, expr.weights())])
return firedrake.Cofunction(V, res)
result = firedrake.Cofunction(V)
for op, w in zip(args, expr.weights()):
result.dat.axpy(w, op.dat)
return result
elif all(isinstance(op, ufl.Matrix) for op in args):
res = tensor.petscmat if tensor else PETSc.Mat()
is_set = False
Expand Down
29 changes: 29 additions & 0 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,20 @@ def norm(self):
from math import sqrt
return sqrt(self.inner(self).real)

def axpy(self, alpha: float, other: 'Dat') -> None:
connorjward marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the operation :math:`y = \\alpha x + y`.

:arg alpha: a scalar
:arg other: the :class:`Dat` to add to this one
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

"""
if isinstance(other._data, np.ndarray):
connorjward marked this conversation as resolved.
Show resolved Hide resolved
np.add(
alpha * other.data_ro, self.data_ro,
out=self.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")

def __pos__(self):
pos = Dat(self)
return pos
Expand Down Expand Up @@ -1022,6 +1036,21 @@ def inner(self, other):
ret += s.inner(o)
return ret

def axpy(self, alpha: float, other: 'MixedDat') -> None:
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the operation :math:`y = \\alpha x + y`.

:arg alpha: a scalar
:arg other: the :class:`Dat` to add to this one
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved

"""
for dat_result, dat_other in zip(self, other):
connorjward marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(dat_result._data, np.ndarray):
np.add(
alpha * dat_other.data_ro, dat_result.data_ro,
out=dat_result.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")

def _op(self, other, op):
ret = []
if np.isscalar(other):
Expand Down
8 changes: 8 additions & 0 deletions pyop2/types/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,14 @@ def inner(self, other):
assert issubclass(type(other), type(self))
return np.dot(self.data_ro, np.conj(other.data_ro))

def axpy(self, alpha, other):
Ig-dolci marked this conversation as resolved.
Show resolved Hide resolved
"""Compute the operation :math:`y = \\alpha x + y`.
"""
JHopeCollins marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(self._data, np.ndarray):
np.add(alpha * other.data_ro, self.data_ro, out=self.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")


# must have comm, can be modified in parloop (implies a reduction)
class Global(SetFreeDataCarrier, VecAccessMixin):
Expand Down
Loading