Skip to content

Commit

Permalink
Derivative tabulation by solving Vandermonde system with recursive GL…
Browse files Browse the repository at this point in the history
… points
  • Loading branch information
pbrubeck committed Oct 17, 2023
1 parent b52018f commit 58a10c6
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions FIAT/polynomial_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
import numpy
from FIAT import expansions
from FIAT.functional import index_iterator
from FIAT.reference_element import UFCInterval
from FIAT.quadrature import GaussLegendreQuadratureLineRule
from FIAT.recursive_points import RecursivePointSet


def mis(m, n):
Expand Down Expand Up @@ -125,6 +128,7 @@ class ONPolynomialSet(PolynomialSet):
for vector- and tensor-valued sets as well.
"""
point_set = RecursivePointSet(lambda n: GaussLegendreQuadratureLineRule(UFCInterval(), n + 1).get_points())

def __init__(self, ref_el, degree, shape=tuple()):

Expand Down Expand Up @@ -155,23 +159,20 @@ def __init__(self, ref_el, degree, shape=tuple()):
cur_idx = tuple([cur_bf] + list(idx) + [exp_bf])
coeffs[cur_idx] = 1.0
cur_bf += 1

# construct dmats
if degree == 0:
dmats = [numpy.array([[0.0]], "d") for i in range(sd)]
else:
pts = ref_el.make_points(sd, 0, degree + sd + 1)
pts = self.point_set.recursive_points(ref_el.get_vertices(), degree)

v = numpy.transpose(expansion_set.tabulate(degree, pts))
vinv = numpy.linalg.inv(v)

dv = expansion_set.tabulate_derivatives(degree, pts)
dtildes = [[[a[1][i] for a in dvrow] for dvrow in dv]
for i in range(sd)]

dmats = [numpy.dot(vinv, numpy.transpose(dtilde))
dmats = [numpy.linalg.solve(v, numpy.transpose(dtilde))
for dtilde in dtildes]

PolynomialSet.__init__(self, ref_el, degree, embedded_degree,
expansion_set, coeffs, dmats)

Expand Down

0 comments on commit 58a10c6

Please sign in to comment.