From d6a9ceb2f0d821d4b42b61fad8db993e0eca910e Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Mon, 25 Sep 2023 21:19:58 +0100 Subject: [PATCH] Fix gather and scatter bug in IPU hardware (#25) IPU model is not fully replicating IPU hardware in the case of gather and scatter vertices, where the `splitSingleRegion` seems to be ignored on the IPU model. Setting back `splitSingleRegion=False` solves the issue. One still needs to investigate which configuration of these vertices is the most optimal. --- tessellate_ipu/lax/tile_lax_gather.py | 2 +- tessellate_ipu/lax/tile_lax_scatter.py | 2 +- tests/lax/test_tile_lax_gather.py | 2 ++ 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tessellate_ipu/lax/tile_lax_gather.py b/tessellate_ipu/lax/tile_lax_gather.py index 211ab9a..33c46e3 100644 --- a/tessellate_ipu/lax/tile_lax_gather.py +++ b/tessellate_ipu/lax/tile_lax_gather.py @@ -93,7 +93,7 @@ def ipu_gather_primitive_translation( numBaseElements=operand.size, # Number of elements in input. maxElementsPerWorker=int(np.ceil(start_indices.size / num_context_workers)), regionSize=1, # TODO: understand? - splitSingleRegion=True, # Split regions between threads? + splitSingleRegion=False, # Split regions between threads? TODO: understand! ) # TODO: should we use `split offsets` between threads? # For now: need to do it manually at the Python `tile_map` level. diff --git a/tessellate_ipu/lax/tile_lax_scatter.py b/tessellate_ipu/lax/tile_lax_scatter.py index 8a60b1f..7466c85 100644 --- a/tessellate_ipu/lax/tile_lax_scatter.py +++ b/tessellate_ipu/lax/tile_lax_scatter.py @@ -192,7 +192,7 @@ def ipu_scatter_primitive_translation( maxElementsPerWorker=int(np.ceil(operand.size / num_context_workers)), regionSize=1, # TODO: understand? indicesAreSorted=False, - splitSingleRegion=True, + splitSingleRegion=False, # Split regions between threads? TODO: understand! ) # For now: need to do it manually at the Python `tile_map` level. ipu_prim_info = IpuTileMapEquation( diff --git a/tests/lax/test_tile_lax_gather.py b/tests/lax/test_tile_lax_gather.py index 0a3af15..006abc1 100644 --- a/tests/lax/test_tile_lax_gather.py +++ b/tests/lax/test_tile_lax_gather.py @@ -5,12 +5,14 @@ import jax import numpy as np import numpy.testing as npt +import pytest from absl.testing import parameterized from tessellate_ipu import tile_map, tile_put_replicated from tessellate_ipu.lax import gather_p +@pytest.mark.ipu_hardware class IpuTilePrimitivesLaxGather(chex.TestCase, parameterized.TestCase): def setUp(self): super().setUp()