diff --git a/tsfc/driver.py b/tsfc/driver.py index 3da8ba91..7297d19c 100644 --- a/tsfc/driver.py +++ b/tsfc/driver.py @@ -267,7 +267,8 @@ def name_multiindex(multiindex, name): return builder.construct_kernel(kernel_name, impero_c, parameters["precision"], index_names, quad_rule) -def compile_expression_at_points(expression, points, coordinates, interface=None, parameters=None, coffee=True): +def compile_expression_at_points(expression, points, coordinates, interface=None, + parameters=None, coffee=True): """Compiles a UFL expression to be evaluated at compile-time known reference points. Useful for interpolating UFL expressions onto function spaces with only point evaluation nodes. @@ -289,10 +290,6 @@ def compile_expression_at_points(expression, points, coordinates, interface=None _.update(parameters) parameters = _ - # No arguments, please! - if extract_arguments(expression): - return ValueError("Cannot interpolate UFL expression with Arguments!") - # Determine whether in complex mode complex_mode = is_complex(parameters["scalar_type"]) @@ -311,6 +308,9 @@ def compile_expression_at_points(expression, points, coordinates, interface=None interface = firedrake_interface_loopy.ExpressionKernelBuilder builder = interface(parameters["scalar_type"]) + arguments = extract_arguments(expression) + argument_multiindices = tuple(builder.create_element(arg.ufl_element()).get_indices() + for arg in arguments) # Replace coordinates (if any) domain = expression.ufl_domain() @@ -333,7 +333,8 @@ def compile_expression_at_points(expression, points, coordinates, interface=None config = dict(interface=builder, ufl_cell=coordinates.ufl_domain().ufl_cell(), precision=parameters["precision"], - point_set=point_set) + point_set=point_set, + argument_multiindices=argument_multiindices) ir, = fem.compile_ufl(expression, point_sum=False, **config) # Deal with non-scalar expressions @@ -343,8 +344,8 @@ def compile_expression_at_points(expression, points, coordinates, interface=None ir = gem.Indexed(ir, tensor_indices) # Build kernel body - return_shape = (len(points),) + value_shape - return_indices = point_set.indices + tensor_indices + return_indices = point_set.indices + tensor_indices + tuple(chain(*argument_multiindices)) + return_shape = tuple(i.extent for i in return_indices) return_var = gem.Variable('A', return_shape) if coffee: return_arg = ast.Decl(parameters["scalar_type"], ast.Symbol('A', rank=return_shape))