Skip to content

Commit

Permalink
Fix iterator, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
pbrubeck committed Oct 14, 2023
1 parent 564d4d0 commit b2ece28
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 16 deletions.
12 changes: 6 additions & 6 deletions FIAT/recursive_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,14 +45,14 @@ def multiindex_equal(d, k, interior=0):
"""
if d <= 0:
return
m = (d-1) * interior
if k <= m:
kmin = interior
kmax = k - (d-1) * kmin
if kmax < kmin:
return
for i in range(interior, k - interior):
for a in multiindex_equal(d-1, k-i, interior=interior):
for i in range(kmin, kmax):
for a in multiindex_equal(d-1, k-i, interior=kmin):
yield (i,) + a
if m < interior + 1:
yield (k - m,) + (interior,)*(d-1)
yield (kmax,) + (kmin,)*(d-1)


class NodeFamily:
Expand Down
12 changes: 12 additions & 0 deletions test/unit/test_fiat.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,21 @@ def __init__(self, a, b):
"GaussLegendre(I, 0)",
"GaussLegendre(I, 1)",
"GaussLegendre(I, 2)",
"GaussLegendre(T, 0)",
"GaussLegendre(T, 1)",
"GaussLegendre(T, 2)",
"GaussLegendre(S, 0)",
"GaussLegendre(S, 1)",
"GaussLegendre(S, 2)",
"GaussLobattoLegendre(I, 1)",
"GaussLobattoLegendre(I, 2)",
"GaussLobattoLegendre(I, 3)",
"GaussLobattoLegendre(T, 1)",
"GaussLobattoLegendre(T, 2)",
"GaussLobattoLegendre(T, 3)",
"GaussLobattoLegendre(S, 1)",
"GaussLobattoLegendre(S, 2)",
"GaussLobattoLegendre(S, 3)",
"Bubble(I, 2)",
"Bubble(T, 3)",
"Bubble(S, 4)",
Expand Down
12 changes: 7 additions & 5 deletions test/unit/test_gauss_legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,21 @@
import numpy as np


@pytest.mark.parametrize("degree", range(1, 7))
def test_gl_basis_values(degree):
@pytest.mark.parametrize("degree", range(1, 5))
@pytest.mark.parametrize("dim", range(1, 4))
def test_gl_basis_values(dim, degree):
"""Ensure that integrating a simple monomial produces the expected results."""
from FIAT import ufc_simplex, GaussLegendre, make_quadrature

s = ufc_simplex(1)
s = ufc_simplex(dim)
q = make_quadrature(s, degree + 1)

fe = GaussLegendre(s, degree)
tab = fe.tabulate(0, q.pts)[(0,)]
tab = fe.tabulate(0, q.pts)[(0,)*dim]

for test_degree in range(degree + 1):
coefs = [n(lambda x: x[0]**test_degree) for n in fe.dual.nodes]
v = lambda x: x[0]**test_degree
coefs = [n(v) for n in fe.dual.nodes]
integral = np.dot(coefs, np.dot(tab, q.wts))
reference = np.dot([x[0]**test_degree
for x in q.pts], q.wts)
Expand Down
12 changes: 7 additions & 5 deletions test/unit/test_gauss_lobatto_legendre.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,21 @@
import numpy as np


@pytest.mark.parametrize("degree", range(1, 7))
def test_gll_basis_values(degree):
@pytest.mark.parametrize("degree", range(1, 5))
@pytest.mark.parametrize("dim", range(1, 4))
def test_gll_basis_values(dim, degree):
"""Ensure that integrating a simple monomial produces the expected results."""
from FIAT import ufc_simplex, GaussLobattoLegendre, make_quadrature

s = ufc_simplex(1)
s = ufc_simplex(dim)
q = make_quadrature(s, degree + 1)

fe = GaussLobattoLegendre(s, degree)
tab = fe.tabulate(0, q.pts)[(0,)]
tab = fe.tabulate(0, q.pts)[(0,)*dim]

for test_degree in range(degree + 1):
coefs = [n(lambda x: x[0]**test_degree) for n in fe.dual.nodes]
v = lambda x: x[0]**test_degree
coefs = [n(v) for n in fe.dual.nodes]
integral = np.dot(coefs, np.dot(tab, q.wts))
reference = np.dot([x[0]**test_degree
for x in q.pts], q.wts)
Expand Down

0 comments on commit b2ece28

Please sign in to comment.