Skip to content
This repository has been archived by the owner on Dec 6, 2024. It is now read-only.

Commit

Permalink
Merge pull request #192 from firedrakeproject/interpolation-operator
Browse files Browse the repository at this point in the history
Interpolation operator
  • Loading branch information
dham authored Dec 9, 2019
2 parents e1a2c17 + 282dee1 commit 9bd9ec7
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions tsfc/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,8 @@ def name_multiindex(multiindex, name):
return builder.construct_kernel(kernel_name, impero_c, parameters["precision"], index_names, quad_rule)


def compile_expression_at_points(expression, points, coordinates, interface=None, parameters=None, coffee=True):
def compile_expression_at_points(expression, points, coordinates, interface=None,
parameters=None, coffee=True):
"""Compiles a UFL expression to be evaluated at compile-time known
reference points. Useful for interpolating UFL expressions onto
function spaces with only point evaluation nodes.
Expand All @@ -289,10 +290,6 @@ def compile_expression_at_points(expression, points, coordinates, interface=None
_.update(parameters)
parameters = _

# No arguments, please!
if extract_arguments(expression):
return ValueError("Cannot interpolate UFL expression with Arguments!")

# Determine whether in complex mode
complex_mode = is_complex(parameters["scalar_type"])

Expand All @@ -311,6 +308,9 @@ def compile_expression_at_points(expression, points, coordinates, interface=None
interface = firedrake_interface_loopy.ExpressionKernelBuilder

builder = interface(parameters["scalar_type"])
arguments = extract_arguments(expression)
argument_multiindices = tuple(builder.create_element(arg.ufl_element()).get_indices()
for arg in arguments)

# Replace coordinates (if any)
domain = expression.ufl_domain()
Expand All @@ -333,7 +333,8 @@ def compile_expression_at_points(expression, points, coordinates, interface=None
config = dict(interface=builder,
ufl_cell=coordinates.ufl_domain().ufl_cell(),
precision=parameters["precision"],
point_set=point_set)
point_set=point_set,
argument_multiindices=argument_multiindices)
ir, = fem.compile_ufl(expression, point_sum=False, **config)

# Deal with non-scalar expressions
Expand All @@ -343,8 +344,8 @@ def compile_expression_at_points(expression, points, coordinates, interface=None
ir = gem.Indexed(ir, tensor_indices)

# Build kernel body
return_shape = (len(points),) + value_shape
return_indices = point_set.indices + tensor_indices
return_indices = point_set.indices + tensor_indices + tuple(chain(*argument_multiindices))
return_shape = tuple(i.extent for i in return_indices)
return_var = gem.Variable('A', return_shape)
if coffee:
return_arg = ast.Decl(parameters["scalar_type"], ast.Symbol('A', rank=return_shape))
Expand Down

0 comments on commit 9bd9ec7

Please sign in to comment.