Skip to content

Commit

Permalink
Fix the FormSum memory leak (#3897)
Browse files Browse the repository at this point in the history
* Add axpy method

* Add maxpy and tests
---------

Co-authored-by: Connor Ward <[email protected]>
  • Loading branch information
Ig-dolci and connorjward authored Dec 6, 2024
1 parent aa0b077 commit ded8f14
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 2 deletions.
5 changes: 3 additions & 2 deletions firedrake/assemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,8 +469,9 @@ def base_form_assembly_visitor(self, expr, tensor, *args):
return sum(weight * arg for weight, arg in zip(expr.weights(), args))
elif all(isinstance(op, firedrake.Cofunction) for op in args):
V, = set(a.function_space() for a in args)
res = sum([w*op.dat for (op, w) in zip(args, expr.weights())])
return firedrake.Cofunction(V, res)
result = firedrake.Cofunction(V)
result.dat.maxpy(expr.weights(), [a.dat for a in args])
return result
elif all(isinstance(op, ufl.Matrix) for op in args):
res = tensor.petscmat if tensor else PETSc.Mat()
is_set = False
Expand Down
53 changes: 53 additions & 0 deletions pyop2/types/dat.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ctypes
import itertools
import operator
from collections.abc import Sequence

import loopy as lp
import numpy as np
Expand Down Expand Up @@ -492,6 +493,41 @@ def norm(self):
from math import sqrt
return sqrt(self.inner(self).real)

def maxpy(self, scalar: Sequence, x: Sequence) -> None:
"""Compute a sequence of axpy operations.
This is equivalent to calling :meth:`axpy` for each pair of
scalars and :class:`Dat` in the input sequences.
Parameters
----------
scalar :
A sequence of scalars.
x :
A sequence of :class:`Dat`.
"""
if len(scalar) != len(x):
raise ValueError("scalar and x must have the same length")
for alpha_i, x_i in zip(scalar, x):
self.axpy(alpha_i, x_i)

def axpy(self, alpha: float, other: 'Dat') -> None:
"""Compute the operation :math:`y = \\alpha x + y`.
In this case, ``self`` is ``y`` and ``other`` is ``x``.
"""
self._check_shape(other)
if isinstance(other._data, np.ndarray):
if not np.isscalar(alpha):
raise TypeError("alpha must be a scalar")
np.add(
alpha * other.data_ro, self.data_ro,
out=self.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")

def __pos__(self):
pos = Dat(self)
return pos
Expand Down Expand Up @@ -1022,6 +1058,23 @@ def inner(self, other):
ret += s.inner(o)
return ret

def axpy(self, alpha: float, other: 'MixedDat') -> None:
"""Compute the operation :math:`y = \\alpha x + y`.
In this case, ``self`` is ``y`` and ``other`` is ``x``.
"""
self._check_shape(other)
for dat_result, dat_other in zip(self, other):
if isinstance(dat_result._data, np.ndarray):
if not np.isscalar(alpha):
raise TypeError("alpha must be a scalar")
np.add(
alpha * dat_other.data_ro, dat_result.data_ro,
out=dat_result.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")

def _op(self, other, op):
ret = []
if np.isscalar(other):
Expand Down
33 changes: 33 additions & 0 deletions pyop2/types/glob.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ctypes
import operator
import warnings
from collections.abc import Sequence

import numpy as np
from petsc4py import PETSc
Expand Down Expand Up @@ -203,6 +204,38 @@ def inner(self, other):
assert issubclass(type(other), type(self))
return np.dot(self.data_ro, np.conj(other.data_ro))

def maxpy(self, scalar: Sequence, x: Sequence) -> None:
"""Compute a sequence of axpy operations.
This is equivalent to calling :meth:`axpy` for each pair of
scalars and :class:`Dat` in the input sequences.
Parameters
----------
scalar :
A sequence of scalars.
x :
A sequence of `Global`.
"""
if len(scalar) != len(x):
raise ValueError("scalar and x must have the same length")
for alpha_i, x_i in zip(scalar, x):
self.axpy(alpha_i, x_i)

def axpy(self, alpha: float, other: 'Global') -> None:
"""Compute the operation :math:`y = \\alpha x + y`.
In this case, ``self`` is ``y`` and ``other`` is ``x``.
"""
if isinstance(self._data, np.ndarray):
if not np.isscalar(alpha):
raise ValueError("alpha must be a scalar")
np.add(alpha * other.data_ro, self.data_ro, out=self.data_wo)
else:
raise NotImplementedError("Not implemented for GPU")


# must have comm, can be modified in parloop (implies a reduction)
class Global(SetFreeDataCarrier, VecAccessMixin):
Expand Down
16 changes: 16 additions & 0 deletions tests/pyop2/test_dats.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,22 @@ def test_accessing_data_with_halos_increments_dat_version(self, d1):
d1.data_with_halos
assert d1.dat_version == 1

def test_axpy(self, d1):
d2 = op2.Dat(d1.dataset)
d1.data[:] = 0
d2.data[:] = 2
d1.axpy(3, d2)
assert (d1.data_ro == 3 * 2).all()

def test_maxpy(self, d1):
d2 = op2.Dat(d1.dataset)
d3 = op2.Dat(d1.dataset)
d1.data[:] = 0
d2.data[:] = 2
d3.data[:] = 3
d1.maxpy((2, 3), (d2, d3))
assert (d1.data_ro == 2 * 2 + 3 * 3).all()


class TestDatView():

Expand Down

0 comments on commit ded8f14

Please sign in to comment.