Skip to content

Commit

Permalink
Test improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmaddison committed Jul 12, 2024
1 parent 3688343 commit 5dec19a
Showing 1 changed file with 27 additions and 3 deletions.
30 changes: 27 additions & 3 deletions tests/regression/test_adjoint_reverse_over_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def test_subfunction(idx):

with reverse_over_forward():
u = Function(space, name="u")
u.sub(idx).interpolate(-2 * X[0])
u.sub(idx).interpolate(X[0] - 0.5)
u_ref = u.copy(deepcopy=True)
zeta = Function(space, name="zeta")
zeta.sub(idx).interpolate(X[0])
Expand All @@ -140,7 +140,7 @@ def test_subfunction(idx):

@pytest.mark.skipcomplex
def test_interpolate():
mesh = UnitSquareMesh(10, 10)
mesh = UnitIntervalMesh(10)
X = SpatialCoordinate(mesh)
space_a = FunctionSpace(mesh, "Lagrange", 1)
space_b = FunctionSpace(mesh, "Lagrange", 2)
Expand All @@ -162,9 +162,33 @@ def test_interpolate():
assemble(6 * inner(u_ref * zeta, test_a) * dx).dat.data_ro)


@pytest.mark.skipcomplex
def test_interpolate_expr():
mesh = UnitIntervalMesh(10)
X = SpatialCoordinate(mesh)
space_a = FunctionSpace(mesh, "Lagrange", 1)
space_b = FunctionSpace(mesh, "Lagrange", 2)
test_a = TestFunction(space_a)

with reverse_over_forward():
u = Function(space_a, name="u").interpolate(X[0] - 0.5)
u_ref = u.copy(deepcopy=True)
zeta = Function(space_a, name="zeta").interpolate(X[0])
u.block_variable.tlm_value = zeta.copy(deepcopy=True)

v = Function(space_b, name="v").interpolate(-3 * u)
J = assemble(v ** 3 * dx)

_ = compute_gradient(J.block_variable.tlm_value, Control(u))
adj_value = u.block_variable.adj_value
assert np.allclose(
adj_value.dat.data_ro,
assemble(-162 * inner(u_ref * zeta, test_a) * dx).dat.data_ro)


@pytest.mark.skipcomplex
def test_project():
mesh = UnitSquareMesh(10, 10)
mesh = UnitIntervalMesh(10)
X = SpatialCoordinate(mesh)
space_a = FunctionSpace(mesh, "Lagrange", 1)
space_b = FunctionSpace(mesh, "Discontinuous Lagrange", 0)
Expand Down

0 comments on commit 5dec19a

Please sign in to comment.