Skip to content

Commit

Permalink
Basic support of jax.lax.gather in TessellateIPU.
Browse files Browse the repository at this point in the history
TessellateIPU `gather` integration using popops `popops::MultiSlice<>` vertex.

At the moment, it only supports slice size 1 and the following gather dimensions
configuration:
```python
jax.lax.GatherDimensionNumbers(
    offset_dims=(),
    collapsed_slice_dims=(0,),
    start_index_map=(0,))
```

Note: it does not take advantage of worker threads in the current configuration.
  • Loading branch information
balancap committed Sep 22, 2023
1 parent e254e59 commit bab7e7a
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 0 deletions.
1 change: 1 addition & 0 deletions tessellate_ipu/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
sub_inplace_p,
)
from .tile_lax_dot import IpuConvVertexType
from .tile_lax_gather import gather_p
from .tile_lax_unary import ( # tanh_inplace_p,
abs_inplace_p,
asin_inplace_p,
Expand Down
113 changes: 113 additions & 0 deletions tessellate_ipu/lax/tile_lax_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from typing import Any, Dict, List, Tuple

import numpy as np
from jax.core import Primitive, ShapedArray
from jax.lax import GatherDimensionNumbers, GatherScatterMode, gather_p

from tessellate_ipu.core import (
IpuTileMapEquation,
make_ipu_vertex_attributes,
make_ipu_vertex_in_info,
make_ipu_vertex_name_templated,
make_ipu_vertex_out_info,
register_ipu_tile_primitive,
)
from tessellate_ipu.utils import DTypeLike


def make_gather_vertex_fullname(dtype: DTypeLike) -> str:
"""Generate popops Gather/MultiSlice vertex name."""
basename = "popops::MultiSlice"
return make_ipu_vertex_name_templated(basename, dtype)


def check_gather_dimension_numbers(dimension_numbers: GatherDimensionNumbers):
"""Check `gather` dimension_numbers is supported on TessellateIPU.
At the moment: basically only supporting a single configuration!
We need to expand on this at some point!
"""
dim_numbers_default = GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))
if dimension_numbers != dim_numbers_default:
raise NotImplementedError(f"TessellateIPU `gather` only support dimension numbers: {dim_numbers_default}.")


def ipu_gather_primitive_translation(
p: Primitive,
tiles: Tuple[int, ...],
inavals: List[ShapedArray],
attributes: Dict[str, Any] = None,
) -> IpuTileMapEquation:
"""IPU `gather` primitive translation rule to IPU vertex.
See: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.gather.html
Args:
p: JAX primitive.
tiles: Collection of tiles.
inavals: Input data + start indices arrays.
attributes: Gather operator attributes
Returns:
IPU tile map primitive structure.
"""
# TODO: query for JAX device.
num_context_workers = 6

assert len(inavals) == 2
assert attributes is not None
operand, start_indices = inavals
# Extract gather attributes
dimension_numbers = attributes["dimension_numbers"]
slice_sizes = attributes["slice_sizes"]
# Default values from JAX LAX interface.
indices_are_sorted = attributes.get("indices_are_sorted", False)
unique_indices = attributes.get("unique_indices", False)
mode = attributes.get("mode", GatherScatterMode.PROMISE_IN_BOUNDS)
fill_value = attributes.get("fill_value", None)

# Check gather attributes are supported by TessellateIPU.
assert operand.ndim == 1
assert start_indices.ndim == 2
assert slice_sizes == (1,)
assert (
mode == GatherScatterMode.PROMISE_IN_BOUNDS
), "Only `PROMISE_IN_BOUNDS` gather mode supported in TessellateIPU."
assert start_indices.dtype == np.uint32, "TessellateIPU `gather` only supports `uint32` indices."
check_gather_dimension_numbers(dimension_numbers)
# Gather output aval.
outaval = p.abstract_eval(
*inavals,
dimension_numbers=dimension_numbers,
slice_sizes=slice_sizes,
indices_are_sorted=indices_are_sorted,
unique_indices=unique_indices,
mode=mode,
fill_value=fill_value,
)[0]

vname = make_gather_vertex_fullname(operand.dtype)
# Construct poplibs MultiSlice vertex attributes.
attrs_i32, attrs_f32 = make_ipu_vertex_attributes(
baseOffset=0, # unused?
numBaseElements=operand.size, # Number of elements in input.
maxElementsPerWorker=int(np.ceil(start_indices.size / num_context_workers)),
regionSize=1, # TODO: understand?
splitSingleRegion=True, # Split regions between threads?
)
# TODO: should we use `split offsets` between threads?
# For now: need to do it manually at the Python `tile_map` level.
ipu_prim_info = IpuTileMapEquation(
vname=vname,
pname=p.name,
tiles=tiles,
inputs_info=[make_ipu_vertex_in_info("baseT", operand), make_ipu_vertex_in_info("offsets", start_indices)],
outputs_info=[make_ipu_vertex_out_info("subT", outaval)],
attributes_i32=attrs_i32,
attributes_f32=attrs_f32,
)
return ipu_prim_info


# Register JAX gather primitive.
register_ipu_tile_primitive(gather_p, ipu_gather_primitive_translation)
58 changes: 58 additions & 0 deletions tests/lax/test_tile_lax_gather.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.
from functools import partial

import chex
import jax
import numpy as np
import numpy.testing as npt
from absl.testing import parameterized

from tessellate_ipu import tile_map, tile_put_replicated
from tessellate_ipu.lax import gather_p


class IpuTilePrimitivesLaxGather(chex.TestCase, parameterized.TestCase):
def setUp(self):
super().setUp()
self.device = jax.devices("ipu")[0]
self.num_tiles = self.device.num_tiles
# Not very clean, but for better reproducibility.
np.random.seed(123)

@parameterized.parameters(
{"num_elements": 8, "num_indices": 3},
{"num_elements": 8, "num_indices": 12},
)
def test__tile_map__gather__jitting__proper_result(self, num_elements, num_indices):
tiles = (0,)
data = np.random.randn(num_elements).astype(np.float32)
indices = np.random.randint(low=0, high=num_elements, size=num_indices)
indices = indices.reshape(-1, 1).astype(np.uint32)

# Only supported configuration!
dim_numbers = jax.lax.GatherDimensionNumbers(offset_dims=(), collapsed_slice_dims=(0,), start_index_map=(0,))

def gather_fn(data, indices):
data = tile_put_replicated(data, tiles)
indices = tile_put_replicated(indices, tiles)
return tile_map(
gather_p,
data,
indices,
dimension_numbers=dim_numbers,
slice_sizes=(1,),
mode=jax.lax.GatherScatterMode.PROMISE_IN_BOUNDS,
unique_indices=False,
indices_are_sorted=False,
fill_value=None,
)

cpu_gather_fn = partial(jax.jit, backend="cpu")(gather_fn)
ipu_gather_fn = partial(jax.jit, backend="ipu")(gather_fn)

cpu_output = cpu_gather_fn(data, indices)
ipu_output = ipu_gather_fn(data, indices)

assert ipu_output.tiles == tiles
assert ipu_output.dtype == data.dtype
npt.assert_array_equal(ipu_output, cpu_output)

0 comments on commit bab7e7a

Please sign in to comment.