From 5f18075fc558b11ab83aa37589643954133e5708 Mon Sep 17 00:00:00 2001 From: ksagiyam <46749170+ksagiyam@users.noreply.github.com> Date: Tue, 18 Jun 2024 10:26:54 +0100 Subject: [PATCH] composed map: add permute method (#723) --------- Co-authored-by: Connor Ward --- pyop2/codegen/builder.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/pyop2/codegen/builder.py b/pyop2/codegen/builder.py index 89cf31fcf..505dc5d2b 100644 --- a/pyop2/codegen/builder.py +++ b/pyop2/codegen/builder.py @@ -75,7 +75,10 @@ def shape(self): def dtype(self): return self.values.dtype - def indexed(self, multiindex, layer=None, permute=lambda x: x): + def _permute(self, x): + return x + + def indexed(self, multiindex, layer=None): n, i, f = multiindex if layer is not None and self.offset is not None: # For extruded mesh, prefetch the indirections for each map, so that they don't @@ -84,7 +87,7 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x): base_key = None if base_key not in self.prefetch: j = Index() - base = Indexed(self.values, (n, permute(j))) + base = Indexed(self.values, (n, self._permute(j))) self.prefetch[base_key] = Materialise(PackInst(), base, MultiIndex(j)) base = self.prefetch[base_key] @@ -122,17 +125,17 @@ def indexed(self, multiindex, layer=None, permute=lambda x: x): return Indexed(self.prefetch[key], (f, i)), (f, i) else: assert f.extent == 1 or f.extent is None - base = Indexed(self.values, (n, permute(i))) + base = Indexed(self.values, (n, self._permute(i))) return base, (f, i) - def indexed_vector(self, n, shape, layer=None, permute=lambda x: x): + def indexed_vector(self, n, shape, layer=None): shape = self.shape[1:] + shape if self.interior_horizontal: shape = (2, ) + shape else: shape = (1, ) + shape f, i, j = (Index(e) for e in shape) - base, (f, i) = self.indexed((n, i, f), layer=layer, permute=permute) + base, (f, i) = self.indexed((n, i, f), layer=layer) init = Sum(Product(base, Literal(numpy.int32(j.extent))), j) pack = Materialise(PackInst(), init, MultiIndex(f, i, j)) multiindex = tuple(Index(e) for e in pack.shape) @@ -168,13 +171,8 @@ def __init__(self, map_, permutation): self.offset_quotient = map_.offset_quotient self.permutation = NamedLiteral(permutation, parent=self.values, suffix=f"permutation{count}") - def indexed(self, multiindex, layer=None): - permute = lambda x: Indexed(self.permutation, (x,)) - return super().indexed(multiindex, layer=layer, permute=permute) - - def indexed_vector(self, n, shape, layer=None): - permute = lambda x: Indexed(self.permutation, (x,)) - return super().indexed_vector(n, shape, layer=layer, permute=permute) + def _permute(self, x): + return Indexed(self.permutation, (x,)) class CMap(Map):