diff --git a/tessellate_ipu/lax/tile_lax_gather.py b/tessellate_ipu/lax/tile_lax_gather.py index 211ab9a..33c46e3 100644 --- a/tessellate_ipu/lax/tile_lax_gather.py +++ b/tessellate_ipu/lax/tile_lax_gather.py @@ -93,7 +93,7 @@ def ipu_gather_primitive_translation( numBaseElements=operand.size, # Number of elements in input. maxElementsPerWorker=int(np.ceil(start_indices.size / num_context_workers)), regionSize=1, # TODO: understand? - splitSingleRegion=True, # Split regions between threads? + splitSingleRegion=False, # Split regions between threads? TODO: understand! ) # TODO: should we use `split offsets` between threads? # For now: need to do it manually at the Python `tile_map` level. diff --git a/tessellate_ipu/lax/tile_lax_scatter.py b/tessellate_ipu/lax/tile_lax_scatter.py index 8a60b1f..7466c85 100644 --- a/tessellate_ipu/lax/tile_lax_scatter.py +++ b/tessellate_ipu/lax/tile_lax_scatter.py @@ -192,7 +192,7 @@ def ipu_scatter_primitive_translation( maxElementsPerWorker=int(np.ceil(operand.size / num_context_workers)), regionSize=1, # TODO: understand? indicesAreSorted=False, - splitSingleRegion=True, + splitSingleRegion=False, # Split regions between threads? TODO: understand! ) # For now: need to do it manually at the Python `tile_map` level. ipu_prim_info = IpuTileMapEquation( diff --git a/tests/lax/test_tile_lax_gather.py b/tests/lax/test_tile_lax_gather.py index 0a3af15..006abc1 100644 --- a/tests/lax/test_tile_lax_gather.py +++ b/tests/lax/test_tile_lax_gather.py @@ -5,12 +5,14 @@ import jax import numpy as np import numpy.testing as npt +import pytest from absl.testing import parameterized from tessellate_ipu import tile_map, tile_put_replicated from tessellate_ipu.lax import gather_p +@pytest.mark.ipu_hardware class IpuTilePrimitivesLaxGather(chex.TestCase, parameterized.TestCase): def setUp(self): super().setUp()