diff --git a/tessellate_ipu/core/tile_array.py b/tessellate_ipu/core/tile_array.py index 0979320..5fb8ab2 100644 --- a/tessellate_ipu/core/tile_array.py +++ b/tessellate_ipu/core/tile_array.py @@ -1,8 +1,10 @@ # Copyright (c) 2022 Graphcore Ltd. All rights reserved. +import itertools from dataclasses import dataclass from typing import Any, Sequence, Tuple, Union import chex +import jax.lax import numpy as np from jax.core import ShapedArray from jax.interpreters.xla import DeviceArray @@ -185,6 +187,14 @@ def __getitem__(self, key: Union[SliceType, MultiSliceType]) -> "TileShardedArra check_tile_array_multi_slice(key, self.array.shape) return TileShardedArray(array=self.array[key], tiles=self.tiles[key[0]]) # type:ignore + @classmethod + def concatenate(cls, arrays: Sequence["TileShardedArray"]) -> "TileShardedArray": + """Concatenate tile sharded arrays along the first axis.""" + assert all([isinstance(v, TileShardedArray) for v in arrays]) + outarray = jax.lax.concatenate([v.array for v in arrays], dimension=0) + outtiles = tuple(itertools.chain(*[v.tiles for v in arrays])) + return TileShardedArray(array=outarray, tiles=outtiles) + def tile_put_sharded(array: DeviceArray, tiles: Sequence[int]) -> TileShardedArray: """Shard a JAX array over tiles on the first axis. diff --git a/tessellate_ipu/core/tile_interpreter_vertex_utils.py b/tessellate_ipu/core/tile_interpreter_vertex_utils.py index 2ba3c2a..1389281 100644 --- a/tessellate_ipu/core/tile_interpreter_vertex_utils.py +++ b/tessellate_ipu/core/tile_interpreter_vertex_utils.py @@ -1,6 +1,6 @@ # Copyright (c) 2022 Graphcore Ltd. All rights reserved. import math -from typing import List +from typing import List, Optional import numpy as np from numpy.typing import DTypeLike, NDArray @@ -26,9 +26,14 @@ def make_num_elements_per_worker(N: int, num_workers: int) -> NDArray[np.int32]: def make_ipu_vector1d_worker_offsets( - size: int, vector_size: int = 2, num_workers: int = 6, wdtype: DTypeLike = np.uint16 + size: int, + vector_size: int = 2, + num_workers: int = 6, + wdtype: DTypeLike = np.uint16, + allow_overlap: bool = False, + grain_size: Optional[int] = None, ) -> NDArray[np.int_]: - """Make the QR householder row update worker sizes, i.e. how many + """Make worker sizes/offsets for a 1D array workload, i.e. how many data vectors per worker thread? Args: @@ -36,26 +41,38 @@ def make_ipu_vector1d_worker_offsets( vector_size: Vector size (2: float, 4: half). num_workers: Number of workers. wdtype: Worklists dtype. + allow_overlap: Allowing overlap between workers. Make it easier to deal with remainer term. + grain_size: Optional grain size. vector_size by default. Returns: (6,) number of data vectors per thread. """ + grain_size = grain_size or vector_size + grain_scale = grain_size // vector_size def make_offsets_fn(sizes): sizes = [0] + sizes - offsets = np.cumsum(np.array(sizes, wdtype), dtype=wdtype) + offsets = np.cumsum(np.array(sizes, wdtype) * grain_scale, dtype=wdtype) return offsets - assert size % vector_size == 0 + # TODO: support properly odd size. + assert size % 2 == 0, "Not supporting odd sizing at the moment." + # Base checks! + assert grain_size % vector_size == 0 + assert size >= grain_size, f"Requires at least a size of {grain_size}." + assert ( + size % grain_size == 0 or allow_overlap + ), f"Requires the size, {size}, divisible by the grain size {grain_size}, (or allowing overlap)." + # Base worksize on the first few workers. - base_worksize: int = math.ceil(size / (vector_size * num_workers)) - num_base_workers = size // (vector_size * base_worksize) + base_worksize: int = math.ceil(size / (grain_size * num_workers)) + num_base_workers = size // (grain_size * base_worksize) worker_sizes: List[int] = [base_worksize] * num_base_workers if num_base_workers == num_workers: return make_offsets_fn(worker_sizes) # Remainer term, for the next thread. - rem_worksize = size - base_worksize * vector_size * num_base_workers - rem_worksize = rem_worksize // vector_size + rem_worksize = size - base_worksize * grain_size * num_base_workers + rem_worksize = rem_worksize // grain_size worker_sizes += [rem_worksize] # Fill the rest with zeros. unused_workers = num_workers - num_base_workers - 1 diff --git a/tessellate_ipu/core/vertex/intrinsics_utils.hpp b/tessellate_ipu/core/vertex/intrinsics_utils.hpp index a254777..7860ac2 100644 --- a/tessellate_ipu/core/vertex/intrinsics_utils.hpp +++ b/tessellate_ipu/core/vertex/intrinsics_utils.hpp @@ -64,6 +64,7 @@ ALWAYS_INLINE T ipu_div_by_6(T n) noexcept { */ ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept { // TAS register, used for __builtin_ipu_f32v2axpy. + // TODO: use `__builtin_ipu_uput`? asm volatile( R"l( uput $TAS, %[sv] )l" @@ -72,6 +73,20 @@ ALWAYS_INLINE void __builtin_ipu_put_tas(float v) noexcept { :); } +/** + * @brief Zero AACC registers. + */ +ALWAYS_INLINE void __builtin_ipu_aacc_zero() { + asm (R"( + setzi $a0, 0x8 + uput $FP_CLR, $a0 + )" + : + : + : "$a0"); +} + + /** * @brief IPU cmac f32 instruction. */ diff --git a/tessellate_ipu/core/vertex/ipu_amp.hpp b/tessellate_ipu/core/vertex/ipu_amp.hpp new file mode 100644 index 0000000..0ef7b85 --- /dev/null +++ b/tessellate_ipu/core/vertex/ipu_amp.hpp @@ -0,0 +1,127 @@ +// Copyright (c) 2023 Graphcore Ltd. All rights reserved. +#pragma once +#include + +#include "intrinsics_utils.hpp" +#include "ipu_model_types.hpp" + +namespace ipu { + +/** + * @brief Thin abstraction of the IPU AMP unit(s) and registers, allowing + * to write generic code compiling on IPU model and IPU hardware. + * + * NOTE: zero-cost abstraction on IPU hardware. + * + * The AMP class is modelling AACC registers as well as AMP unit instructions + * on the IPU model, reproducing the expected behaviour of the hardware. + */ +template +class AMP { + public: + // TODO: support half as well. + static_assert(std::is_same_v); + using FPType = T; + /** Number of AACC register available in hw. */ + // TODO: use TFPU_AMP_UNITS_PER_SET and TFPU_AACC_PER_AMP_UNIT; + static constexpr unsigned NumAACC = 16; + + // TODO: random initialization on IPU model of registers. + AMP() noexcept = default; + // No copy + no move allowed! + AMP(const AMP&) = delete; + AMP(AMP&&) = delete; + + /** + * @brief Set the value of the TAS register, used in + * `axpy` operation. + */ + ALWAYS_INLINE void tas(FPType val) noexcept { +#ifdef __IPU__ + __builtin_ipu_put_tas(val); +#else + m_tas = val; +#endif + } + /** + * @brief Zero AACC registers. + */ + ALWAYS_INLINE void aaccZero() noexcept { +#ifdef __IPU__ + __builtin_ipu_aacc_zero(); +#else + for (unsigned idx = 0; idx < NumAACC; ++idx) { + m_aacc[idx] = 0; + } +#endif + } + + /** + * @brief Scaled-add `axpy` intrinsic. Only supported on FP32. + * NOTE: act as 1 stage pipeline, storing result in AACC[0...2] + */ + ALWAYS_INLINE float2 axpy(float2 x, float2 y) noexcept { + using T2 = float2; +#ifdef __IPU__ + // Weird ordering here? Bug in the intrinsic definition? + return __builtin_ipu_f32v2axpy(y, x); +#else + // Simulating pipeline with storing in AACC[0] and AACC[2]. + const auto res = T2{m_aacc[0], m_aacc[2]}; + // FIXME/TODO: understand ordering!? + m_aacc[0] = m_tas * x[0] + y[0]; + m_aacc[2] = m_tas * x[1] + y[1]; + return res; +#endif + } + + /** + * @brief Outer-product `aop` intrinsic. Only supported on FP32. + * Storing results in AACC[0...6] + */ + void aop(float2 x, float2 y) noexcept { +#ifdef __IPU__ + // Note: third argument not used by hw. + __builtin_ipu_f32v2aop(x, y, 0); +#else + // Multiply + accumulate. + m_aacc[0] += x[0] * y[0]; + m_aacc[2] += x[1] * y[0]; + m_aacc[4] += x[0] * y[1]; + m_aacc[6] += x[1] * y[1]; +#endif + } + + /** + * @brief `gina` instruction: get AACC register + propagate. + * FIXME: support non-zero flag/index. + */ + template + float2 gina(float2 val) noexcept { + using T2 = float2; +#ifdef __IPU__ + return __builtin_ipu_f32v2gina(val, 0); +#else + // TODO: implement GINA_IMMFLAGS__SET__GET + const auto res = T2{m_aacc[0], m_aacc[2]}; + // Propagate accumulator states. + for (unsigned idx = 4; idx < NumAACC; idx += 4) { + m_aacc[idx - 4] = m_aacc[idx]; + m_aacc[idx - 2] = m_aacc[idx + 2]; + } + m_aacc[NumAACC - 4] = val[0]; + m_aacc[NumAACC - 2] = val[1]; + return res; +#endif + } + + private: +#ifndef __IPU__ + // Simulating AACC registers on IPU model. + FPType m_aacc[NumAACC]; + // Simulating TAS register on IPU model. + FPType m_tas; +#endif +}; + +} // namespace ipu diff --git a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp index e30de5f..4074aaa 100644 --- a/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp +++ b/tessellate_ipu/core/vertex/tile_jacobi_vertex.cpp @@ -3,6 +3,7 @@ #include #include "intrinsics_utils.hpp" +#include "ipu_amp.hpp" #include "tile_small_dot.hpp" using namespace poplar; @@ -80,10 +81,11 @@ class JacobiSymSchur2 : public Vertex { }; template -void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated, - T* qcol_updated, T* cs, unsigned p, unsigned q, - unsigned short wstart, - unsigned short wend) noexcept { +inline void jacob_update_first_step(const T* pcol, const T* qcol, + T* pcol_updated, T* qcol_updated, T* cs, + unsigned p, unsigned q, + unsigned short wstart, + unsigned short wend) noexcept { using T2 = float2; using IndexType = unsigned short; @@ -106,7 +108,7 @@ void jacob_update_first_step(const T* pcol, const T* qcol, T* pcol_updated, float2* ptr_qcol_updated = reinterpret_cast(qcol_updated) + wstart; // Apply Schur2 cs rotation to p/q columns (optimized kernel). rotation2d_f32(cs_vec, ptr_pcol, ptr_qcol, ptr_pcol_updated, - ptr_qcol_updated, wsize); + ptr_qcol_updated, wsize); // Update main values App, Apq, Aqq pcol_updated[p] = c * c * App - 2 * s * c * Apq + s * s * Aqq; qcol_updated[q] = s * s * App + 2 * s * c * Apq + c * c * Aqq; @@ -184,6 +186,129 @@ class [[poplar::constraint( } }; +/** + * @brief Jacobi update second step, using Schur2 coefficient from + * other pairs of columns. + */ +template +inline void jacobi_update_second_step(const unsigned* rotset_sorted_arr, + const T* cs_arr, const T* pcol, + const T* qcol, T* pcol_updated, + T* qcol_updated, unsigned wstart, + unsigned wend) noexcept { + const unsigned wsize = (wend - wstart) / 2; + // Necessary for generating `rpt` loop. + __builtin_assume(wsize < 4096); + using T2 = float2; + // Increment pointers. NOTE: unrolling creating "4x" factor. + rotset_sorted_arr += 2 * wstart; + const T2* cs_arr_ptr = reinterpret_cast(cs_arr) + wstart; + + // Basic usage of AMP unit with `aop` outer-product :) + ipu::AMP amp; + amp.aaccZero(); + + const T2 zeros{0, 0}; + T2 res, cs0, cs1, Sp0, Sq0, Sp1, Sq1, tmp0, tmp1; + unsigned k0, l0, k1, l1; + + // The loop body is roughly the following equations: + // const T Spk = pcol_ptr[k]; + // const T Spl = pcol_ptr[l]; + // const T Sqk = qcol_ptr[k]; + // const T Sql = qcol_ptr[l]; + + // pcol_updated_ptr[k] = c * Spk - s * Spl; + // pcol_updated_ptr[l] = s * Spk + c * Spl; + // qcol_updated_ptr[k] = c * Sqk - s * Sql; + // qcol_updated_ptr[l] = s * Sqk + c * Sql; + + // Problem: generate poor bundling of operations in the loop. + // Solution: unroll 2 steps + f32v2aop + manual re-ordering. + // NOTE: f32v2aop mostly useful for reducing register pressure, + // as results are stored in AACC registers (not AUX). Just saving 1 compute + // cycle. + + // Pre-loading due to unrolling + reordering. + k0 = ipu::load_postinc(&rotset_sorted_arr, 1); + l0 = ipu::load_postinc(&rotset_sorted_arr, 1); + cs0 = ipu::load_postinc(&cs_arr_ptr, 1); + Sp0 = {pcol[k0], pcol[l0]}; + for (unsigned half_idx = 0; half_idx < wsize; ++half_idx) { + // Pseudo bundling of instructions, to help popc. + { + Sq0[0] = qcol[k0]; + amp.aop(cs0, Sp0); + } + { + k1 = ipu::load_postinc(&rotset_sorted_arr, 1); + tmp0 = amp.template gina<0>(zeros); + } + { + l1 = ipu::load_postinc(&rotset_sorted_arr, 1); + tmp1 = amp.template gina<0>(zeros); + } + { + Sq0[1] = qcol[l0]; + pcol_updated[k0] = tmp0[0] - tmp1[1]; + } + { + pcol_updated[l0] = tmp0[1] + tmp1[0]; + amp.aop(cs0, Sq0); + } + { + cs1 = ipu::load_postinc(&cs_arr_ptr, 1); + tmp0 = amp.template gina<0>(zeros); + } + { + Sp1[0] = pcol[k1]; + tmp1 = amp.template gina<0>(zeros); + } + { + Sp1[1] = pcol[l1]; + qcol_updated[k0] = tmp0[0] - tmp1[1]; + } + // Unrolling: second part. + // NOTE: inputs already (partially) loaded. + { + qcol_updated[l0] = tmp0[1] + tmp1[0]; + amp.aop(cs1, Sp1); + } + { + Sq1[0] = qcol[k1]; + tmp0 = amp.template gina<0>(zeros); + } + { + Sq1[1] = qcol[l1]; + tmp1 = amp.template gina<0>(zeros); + } + { + k0 = ipu::load_postinc(&rotset_sorted_arr, 1); + pcol_updated[k1] = tmp0[0] - tmp1[1]; + } + { + pcol_updated[l1] = tmp0[1] + tmp1[0]; + amp.aop(cs1, Sq1); + } + { + l0 = ipu::load_postinc(&rotset_sorted_arr, 1); + tmp0 = amp.template gina<0>(zeros); + } + { + cs0 = ipu::load_postinc(&cs_arr_ptr, 1); + tmp1 = amp.template gina<0>(zeros); + } + { + Sp0[0] = pcol[k0]; + qcol_updated[k1] = tmp0[0] - tmp1[1]; + } + { + qcol_updated[l1] = tmp0[1] + tmp1[0]; + Sp0[1] = pcol[l0]; + } + } +} + class JacobiUpdateSecondStep : public MultiVertex { public: using T = float; @@ -213,11 +338,14 @@ class JacobiUpdateSecondStep : public MultiVertex { bool compute(unsigned wid) { // Size of the index prefix in pcol and qcol. - constexpr int INDEX_PREFIX = 2; + constexpr unsigned INDEX_PREFIX = 2; // Worker load: start + end vectorized indexes. - const IndexType wstart = worker_offsets[wid]; - const IndexType wend = worker_offsets[wid + 1]; - const IndexType wsize = wend - wstart; + const unsigned wstart = worker_offsets[wid]; + const unsigned wend = worker_offsets[wid + 1]; + + // Forward pq indices. + pcol_updated[0] = pcol[0]; + qcol_updated[0] = qcol[0]; // Use (p, q) = (1, 0) for ignore idx. const unsigned ignore_idx = 2 * rotset_idx_ignored[0]; @@ -229,33 +357,9 @@ class JacobiUpdateSecondStep : public MultiVertex { auto pcol_updated_ptr = pcol_updated.data() + INDEX_PREFIX; auto qcol_updated_ptr = qcol_updated.data() + INDEX_PREFIX; - // Forward pq indices. - pcol_updated[0] = pcol[0]; - qcol_updated[0] = qcol[0]; - - // Parallized loop on update using other columns coefficients - for (IndexType half_idx = 0; half_idx != wsize; ++half_idx) { - // TODO: cleaning pq indices offset. - const unsigned k = rotset_sorted_arr[2 * half_idx + 2 * wstart]; - const unsigned l = rotset_sorted_arr[2 * half_idx + 1 + 2 * wstart]; - - const T c = cs_arr[2 * half_idx + 2 * wstart]; - const T s = cs_arr[2 * half_idx + 1 + 2 * wstart]; - - // 4 coefficients updates! - // TODO: vectorization?! - const T Spk = pcol_ptr[k]; - const T Spl = pcol_ptr[l]; - - const T Sqk = qcol_ptr[k]; - const T Sql = qcol_ptr[l]; - - pcol_updated_ptr[k] = c * Spk - s * Spl; - pcol_updated_ptr[l] = s * Spk + c * Spl; - - qcol_updated_ptr[k] = c * Sqk - s * Sql; - qcol_updated_ptr[l] = s * Sqk + c * Sql; - } + jacobi_update_second_step(rotset_sorted_arr.data(), cs_arr.data(), pcol_ptr, + qcol_ptr, pcol_updated_ptr, qcol_updated_ptr, + wstart, wend); return true; } }; diff --git a/tessellate_ipu/core/vertex/tile_small_dot.hpp b/tessellate_ipu/core/vertex/tile_small_dot.hpp index 26bd338..0b380fd 100644 --- a/tessellate_ipu/core/vertex/tile_small_dot.hpp +++ b/tessellate_ipu/core/vertex/tile_small_dot.hpp @@ -1,5 +1,6 @@ // Copyright (c) 2023 Graphcore Ltd. All rights reserved. #include "intrinsics_utils.hpp" +#include "ipu_amp.hpp" /** * @brief z = a*x + b*y float32 implementation. @@ -35,9 +36,9 @@ inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y, // __builtin_assume(nblocks < 4096); using T2 = float2; const T2 av = {a, a}; - // Using TAS register for one of the scalar. - __ipu_and_ipumodel_tas tas; - tas.put(b); + // Basic AMP usage with TAS + axpy instruction. + ipu::AMP amp; + amp.tas(b); T2 res, xv, yv, zv, tmp; @@ -49,13 +50,11 @@ inline void axplusby_f32(float a, float b, const float2 *x, const float2 *y, // popc should be able to generate an optimal rpt loop. { xv = ipu::load_postinc(&x, 1); - // TODO: fix ordering of arguments in `f32v2axpy`. - tmp = tas.f32v2axpy(res, yv); + tmp = amp.axpy(yv, res); } { yv = ipu::load_postinc(&y, 1); - // TODO: fix ordering of arguments in `f32v2axpy`. - zv = tas.f32v2axpy(tmp, tmp); + zv = amp.axpy(tmp, tmp); } { ipu::store_postinc(&z, zv, 1); @@ -139,7 +138,8 @@ template inline void rotation2d_f32(float2 cs, const float2 *inrow0, const float2 *inrow1, float2 *outrow0, float2 *outrow1, rptsize_t nblocks) { - // TODO: investigate using IPU AMP unit? + // axplusby is using one AMP unit. TODO: investigate using more! axplusby_f32(cs[0], -cs[1], inrow0, inrow1, outrow0, nblocks); + // NOTE: inrow1+0, outrow1 arguments order necessary due to bank constraints! axplusby_f32(cs[0], cs[1], inrow1, inrow0, outrow1, nblocks); } diff --git a/tessellate_ipu/linalg/tile_linalg_jacobi.py b/tessellate_ipu/linalg/tile_linalg_jacobi.py index 5319b56..0a67f0c 100644 --- a/tessellate_ipu/linalg/tile_linalg_jacobi.py +++ b/tessellate_ipu/linalg/tile_linalg_jacobi.py @@ -19,6 +19,7 @@ tile_put_sharded, ) from tessellate_ipu.core.tile_interpreter_vertex_utils import make_ipu_vector1d_worker_offsets +from tessellate_ipu.lax import tile_fill from tessellate_ipu.utils import NDArray Array = Any @@ -69,8 +70,10 @@ def get_jacobi_vertex_gp_filename() -> str: inputs=["cs_arr", "rotset_sorted_arr", "rotset_idx_ignored", "pcol", "qcol"], outputs={"cs_arr": 0, "pcol_updated": 3, "qcol_updated": 4}, constants={ + # NOTE: using grain_size=4 because of partial loop unrolling + # TODO: support overlap properly. "worker_offsets": lambda inavals, *_: make_ipu_vector1d_worker_offsets( - inavals[3].size - INDEX_PREFIX, vector_size=2, wdtype=np.uint16 + inavals[3].size - INDEX_PREFIX, vector_size=2, wdtype=np.uint16, allow_overlap=False, grain_size=4 ) }, gp_filename=get_jacobi_vertex_gp_filename(), @@ -232,6 +235,14 @@ def ipu_jacobi_eigh_body(idx: Array, inputs: Tuple[TileShardedArray, ...]) -> Tu rotset_sorted_sharded, cs_per_tile, Apcols, Aqcols = tile_map( # type:ignore jacobi_update_first_step_p, Apcols, Aqcols ) + # Append zero indices to the rotset, for loop unrolling in `jacobi_update_second_step` + rotset_zeros = tile_fill((2,), 0, dtype=rotset_sorted_sharded.dtype, tiles=(0,)) + # Barrier to make sure communication gets fused into a single block. + rotset_zeros, rotset_sorted_sharded, cs_per_tile = tile_data_barrier( + rotset_zeros, rotset_sorted_sharded, cs_per_tile + ) + rotset_sorted_sharded = TileShardedArray.concatenate([rotset_sorted_sharded, rotset_zeros]) + # Replicate Schur decomposition + rotset across all A tiles: (2*N//2) comms. with jax.named_scope("rotset_sorted_replicated"): rotset_sorted_replicated = tile_put_replicated(rotset_sorted_sharded.array, tiles=Atiles) diff --git a/tests/linalg/test_tile_linalg_jacobi.py b/tests/linalg/test_tile_linalg_jacobi.py index e54dd6c..5c07e0d 100644 --- a/tests/linalg/test_tile_linalg_jacobi.py +++ b/tests/linalg/test_tile_linalg_jacobi.py @@ -178,7 +178,7 @@ def test__jacobi_eigh__single_iteration(self): @unittest.skipUnless(ipu_num_tiles >= 16, "Requires IPU with 16 tiles") def test__jacobi_eigh_raw__proper_eigh_result(self): - N = 8 + N = 12 x = np.random.randn(N, N).astype(np.float32) x = (x + x.T) / 2.0