diff --git a/tessellate_ipu/core/tile_interpreter.py b/tessellate_ipu/core/tile_interpreter.py index 5f1771f..7326f65 100644 --- a/tessellate_ipu/core/tile_interpreter.py +++ b/tessellate_ipu/core/tile_interpreter.py @@ -274,6 +274,8 @@ def register_ipu_tile_primitive(primitive: Primitive, translation: IpuVertexTran translation: IPU vertex translation rule. """ global _ipu_tile_primitive_registry + if primitive.name in _ipu_tile_primitive_registry: + raise KeyError(f"The primitive '{primitive.name}' is already registered in TessellateIPU.") _ipu_tile_primitive_registry[primitive.name] = (primitive, translation) diff --git a/tessellate_ipu/lax/__init__.py b/tessellate_ipu/lax/__init__.py index 2d86d80..d868705 100644 --- a/tessellate_ipu/lax/__init__.py +++ b/tessellate_ipu/lax/__init__.py @@ -39,6 +39,7 @@ ) from .tile_lax_dot import IpuConvVertexType from .tile_lax_gather import gather_p +from .tile_lax_scatter import scatter_add_p, scatter_max_p, scatter_min_p, scatter_mul_p, scatter_p from .tile_lax_unary import ( # tanh_inplace_p, abs_inplace_p, asin_inplace_p, diff --git a/tessellate_ipu/lax/tile_lax_scatter.py b/tessellate_ipu/lax/tile_lax_scatter.py new file mode 100644 index 0000000..8ff2f65 --- /dev/null +++ b/tessellate_ipu/lax/tile_lax_scatter.py @@ -0,0 +1,228 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
+import logging +from typing import Any, Dict, List, Tuple + +import numpy as np +from jax.core import Primitive, ShapedArray +from jax.lax import ( + GatherScatterMode, + ScatterDimensionNumbers, + scatter_add_p, + scatter_max_p, + scatter_min_p, + scatter_mul_p, + scatter_p, +) + +from tessellate_ipu.core import ( + IpuTileMapEquation, + make_ipu_vertex_attributes, + make_ipu_vertex_constant_info, + make_ipu_vertex_in_info, + make_ipu_vertex_inout_info, + make_ipu_vertex_name_templated, + register_ipu_tile_primitive, +) +from tessellate_ipu.utils import DTypeLike + +_scatter_primitive_to_properties: Dict[Primitive, Any] = { + # scatter_p: (0, "ADD"), + scatter_add_p: (1, "ADD"), + scatter_min_p: (None, "MIN"), + scatter_max_p: (None, "MAX"), + scatter_mul_p: (None, "MUL"), +} +"""IPU translation properties for every JAX LAX scatter primitive. +""" + + +def make_scatter_vertex_fullname(dtype: DTypeLike, opname: str, scale: Any) -> str: + """Generate popops Scatter/MultiUpdateOp vertex name.""" + opname = f"popops::Operation::{opname}" + if scale is not None: + basename = "popops::ScaledMultiUpdateOp" + return make_ipu_vertex_name_templated(basename, dtype, dtype, False, opname) + else: + basename = "popops::MultiUpdateOp" + return make_ipu_vertex_name_templated(basename, dtype, False, opname) + + +def check_scatter_dimension_numbers(dimension_numbers: ScatterDimensionNumbers): + """Check `scatter` dimension_numbers is supported on TessellateIPU. + + At the moment: basically only supporting a single configuration! + We need to expand on this at some point! 
+ """ + dim_numbers_default = ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,) + ) + if dimension_numbers != dim_numbers_default: + raise NotImplementedError(f"TessellateIPU `scatter` only support dimension numbers: {dim_numbers_default}.") + + +def ipu_scatter_op_primitive_translation( + p: Primitive, + tiles: Tuple[int, ...], + inavals: List[ShapedArray], + attributes: Dict[str, Any] = None, +) -> IpuTileMapEquation: + """IPU `scatter_xx` primitive translation rule to IPU vertex. + + See: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scatter.html + + Args: + p: JAX primitive. + tiles: Collection of tiles. + inavals: Operand, scatter indices and updates arrays. + attributes: Scatter operator attributes. + Returns: + IPU tile map primitive structure. + """ + # TODO: query for JAX device. + num_context_workers = 6 + + assert len(inavals) == 3 + assert attributes is not None + operand, scatter_indices, updates = inavals + # Extract scatter attributes + dimension_numbers = attributes["dimension_numbers"] + # Default values from JAX LAX interface. + indices_are_sorted = attributes.get("indices_are_sorted", False) + unique_indices = attributes.get("unique_indices", False) + mode = attributes.get("mode", GatherScatterMode.PROMISE_IN_BOUNDS) + + # Check scatter attributes are supported by TessellateIPU. + assert operand.ndim == 1 + assert scatter_indices.ndim == 2 + assert operand.dtype == updates.dtype + assert scatter_indices.dtype == np.uint32, "TessellateIPU `scatter` only supports `uint32` indices." + if indices_are_sorted: + logging.warning("TessellateIPU `scatter` operation does not make use of `indices_are_sorted` argument.") + if unique_indices: + logging.warning("TessellateIPU `scatter` operation does not make use of `unique_indices` argument.") + assert ( + mode == GatherScatterMode.PROMISE_IN_BOUNDS + ), "Only `PROMISE_IN_BOUNDS` scatter mode supported in TessellateIPU." 
+ check_scatter_dimension_numbers(dimension_numbers) + + # Primitive translation properties. + scale, opname = _scatter_primitive_to_properties[p] + vname = make_scatter_vertex_fullname(operand.dtype, opname, scale) + # Construct poplibs MultiUpdateOp vertex attributes. + attrs_i32, attrs_f32 = make_ipu_vertex_attributes( + baseOffset=0, # unused? + numBaseElements=operand.size, # Number of elements in input. + maxElementsPerWorker=int(np.ceil(operand.size / num_context_workers)), + regionSize=1, # TODO: understand? + indicesAreSorted=False, + ) + + # const unsigned baseOffset; // in the slice dimension + # const unsigned numBaseElements; // in the slice dimension + # const unsigned short regionSize; // stride between slices + # const bool indicesAreSorted; // indices are sorted in increasing order + # const bool splitSingleRegion; // Use in the case of a single offset and + # // alignment constraints are met. + # // in the slice dimension (ceil numBaseElements / numWorkers). + # const unsigned maxElementsPerWorker; + + # Constant `scale` (if required by the vertex). + constants_info = [] + if scale is not None: + constants_info = [make_ipu_vertex_constant_info("scale", np.array(scale, dtype=operand.dtype), vertex_dim2=-1)] + # For now: need to do it manually at the Python `tile_map` level. + ipu_prim_info = IpuTileMapEquation( + vname=vname, + pname=p.name, + tiles=tiles, + inputs_info=[ + make_ipu_vertex_inout_info("baseT", operand), + make_ipu_vertex_in_info("offsets", scatter_indices), + make_ipu_vertex_in_info("subT", updates), + ] + + constants_info, + outputs_info=[make_ipu_vertex_inout_info("baseT", operand)], + attributes_i32=attrs_i32, + attributes_f32=attrs_f32, + ) + return ipu_prim_info + + +def ipu_scatter_primitive_translation( + p: Primitive, + tiles: Tuple[int, ...], + inavals: List[ShapedArray], + attributes: Dict[str, Any] = None, +) -> IpuTileMapEquation: + """IPU `scatter` primitive translation rule to IPU vertex. 
+ + Note: using a specific translation, as the poplibs vertex is different. + See: https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scatter.html + + Args: + p: JAX primitive. + tiles: Collection of tiles. + inavals: Operand, scatter indices and updates arrays. + attributes: Scatter operator attributes. + Returns: + IPU tile map primitive structure. + """ + # TODO: query for JAX device. + num_context_workers = 6 + + assert len(inavals) == 3 + assert attributes is not None + operand, scatter_indices, updates = inavals + # Extract scatter attributes + dimension_numbers = attributes["dimension_numbers"] + # Default values from JAX LAX interface. + indices_are_sorted = attributes.get("indices_are_sorted", False) + unique_indices = attributes.get("unique_indices", False) + mode = attributes.get("mode", GatherScatterMode.PROMISE_IN_BOUNDS) + + # Check scatter attributes are supported by TessellateIPU. + assert operand.ndim == 1 + assert scatter_indices.ndim == 2 + assert operand.dtype == updates.dtype + assert scatter_indices.dtype == np.uint32, "TessellateIPU `scatter` only supports `uint32` indices." + if indices_are_sorted: + logging.warning("TessellateIPU `scatter` operation does not make use of `indices_are_sorted` argument.") + if unique_indices: + logging.warning("TessellateIPU `scatter` operation does not make use of `unique_indices` argument.") + assert ( + mode == GatherScatterMode.PROMISE_IN_BOUNDS + ), "Only `PROMISE_IN_BOUNDS` scatter mode supported in TessellateIPU." + check_scatter_dimension_numbers(dimension_numbers) + + vname = make_ipu_vertex_name_templated("popops::MultiUpdate", operand.dtype) + # Construct poplibs MultiUpdate vertex attributes. + attrs_i32, attrs_f32 = make_ipu_vertex_attributes( + baseOffset=0, # unused? + numBaseElements=operand.size, # Number of elements in input. + maxElementsPerWorker=int(np.ceil(operand.size / num_context_workers)), + regionSize=1, # TODO: understand? 
+ indicesAreSorted=False, + splitSingleRegion=True, + ) + # For now: need to do it manually at the Python `tile_map` level. + ipu_prim_info = IpuTileMapEquation( + vname=vname, + pname=p.name, + tiles=tiles, + inputs_info=[ + make_ipu_vertex_inout_info("baseT", operand), + make_ipu_vertex_in_info("offsets", scatter_indices), + make_ipu_vertex_in_info("subT", updates), + ], + outputs_info=[make_ipu_vertex_inout_info("baseT", operand)], + attributes_i32=attrs_i32, + attributes_f32=attrs_f32, + ) + return ipu_prim_info + + +# Register JAX `scatter` primitives with update op. +for p in _scatter_primitive_to_properties.keys(): + register_ipu_tile_primitive(p, ipu_scatter_op_primitive_translation) +# Specific translation for the simple `scatter` case +register_ipu_tile_primitive(scatter_p, ipu_scatter_primitive_translation) diff --git a/tessellate_ipu/lib/tessellate_ipu_core.cpp b/tessellate_ipu/lib/tessellate_ipu_core.cpp index 0f7722a..f5917cf 100644 --- a/tessellate_ipu/lib/tessellate_ipu_core.cpp +++ b/tessellate_ipu/lib/tessellate_ipu_core.cpp @@ -102,7 +102,7 @@ NB_MODULE(pytessellate_ipu_core, m) { nanobind::arg("shape"), nanobind::arg("dtype"), nanobind::arg("constant_data"), nanobind::arg("slices2d")) .def(nanobind::init(), + IpuType, int64_t, const Base64Data&>(), nanobind::arg("name"), nanobind::arg("iotype"), nanobind::arg("shape"), nanobind::arg("dtype"), nanobind::arg("vertex_dim2") = 0, @@ -118,6 +118,7 @@ NB_MODULE(pytessellate_ipu_core, m) { .def_rw("aval", &VertexIOInfo::aval) .def_rw("constant_data", &VertexIOInfo::constant_data) .def_rw("slices2d", &VertexIOInfo::slices2d) + .def_rw("is_scalar", &VertexIOInfo::is_scalar) .def_prop_ro("shape", [](const VertexIOInfo& v) { return v.aval.shape; }) .def_prop_ro("dtype", [](const VertexIOInfo& v) { return v.aval.dtype; }) .def_prop_ro("is_constant_input", &VertexIOInfo::isConstantInput); diff --git a/tessellate_ipu/lib/tile_map_ops.cpp b/tessellate_ipu/lib/tile_map_ops.cpp index 5460844..6a0cc22 100644 --- 
a/tessellate_ipu/lib/tile_map_ops.cpp +++ b/tessellate_ipu/lib/tile_map_ops.cpp @@ -1,6 +1,7 @@ // Copyright (c) 2022 Graphcore Ltd. All rights reserved. #include "tile_map_ops.hpp" +#include namespace ipu { std::vector TileMapEquation::allocateInputTensors( @@ -91,7 +92,8 @@ void TileMapEquation::add(poplar::Graph& graph, poplar::program::Sequence& prog, // Map/connect vertex input tensors. for (size_t k = 0; k < inputs.size(); ++k) { const auto& info = inputs_info[k]; - graph.connect(v[info.name], info.connectReshape(inputs[k][tidx])); + const auto tensor = info.connectReshape(inputs[k][tidx]); + graph.connect(v[info.name], tensor); } // Map/connect vertex output tensors. for (size_t k = 0; k < outputs.size(); ++k) { diff --git a/tessellate_ipu/lib/tile_map_ops.hpp b/tessellate_ipu/lib/tile_map_ops.hpp index 5ee2e71..c7a72ec 100644 --- a/tessellate_ipu/lib/tile_map_ops.hpp +++ b/tessellate_ipu/lib/tile_map_ops.hpp @@ -61,6 +61,8 @@ struct VertexIOInfo { Base64Data constant_data = Base64Data(); /** Slices, in the case of 2d tensor input. */ std::vector slices2d; + /** Is the vertex IO tensor just a scalar? */ + bool is_scalar = false; /** Default constructors/assignment. */ VertexIOInfo() noexcept = default; @@ -91,8 +93,8 @@ struct VertexIOInfo { * @brief Build a vertex IO info (with vertex second dim info). */ VertexIOInfo(const std::string& _name, VertexIOType _iotype, - const ShapeType& _shape, IpuType _dtype, - std::size_t _vertex_dim2, const Base64Data& _constant_data) + const ShapeType& _shape, IpuType _dtype, int64_t _vertex_dim2, + const Base64Data& _constant_data) : name{_name}, iotype{_iotype}, aval{_shape, _dtype}, @@ -102,6 +104,11 @@ struct VertexIOInfo { slices2d = TensorSlice::makeTensor2dSlices(aval.size() / _vertex_dim2, _vertex_dim2); } + // Negative => code for scalar. + if (_vertex_dim2 < 0) { + is_scalar = true; + } + // Zero => normal flattened case. 
} /** @@ -138,8 +145,20 @@ struct VertexIOInfo { /** * @brief Reshape a tensor to the proper rank for vertex connection. + * + * This bit of logic is necessary as Poplar vertices only support: + * rank 0: i.e. scalar entry; + * rank 1: flattened array; + * rank 2: collection of tensor slices; */ poplar::Tensor connectReshape(const poplar::Tensor& t) const { + if (is_scalar) { + if (t.numElements() != 1) { + throw std::logic_error( + "Expecting a single scalar element to connect to the vertex."); + } + return t.flatten()[0]; + } if (slices2d.empty()) { // Rank 1 (no 2d slices): flatten the IO tensor. return t.flatten(); @@ -159,12 +178,12 @@ struct VertexIOInfo { } }; NLOHMANN_DEFINE_TYPE_NON_INTRUSIVE(VertexIOInfo, name, iotype, aval, - constant_data, slices2d) + constant_data, slices2d, is_scalar) inline bool operator==(const VertexIOInfo& lhs, const VertexIOInfo& rhs) { return lhs.name == rhs.name && lhs.iotype == rhs.iotype && lhs.aval.shape == rhs.aval.shape && lhs.aval.dtype == rhs.aval.dtype; - // TODO: compare 2d slices. + // TODO: compare 2d slices and is_scalar? } /** diff --git a/tests/core/custom_arange_primitive.py b/tests/core/custom_arange_primitive.py index 9c0087c..683ab57 100644 --- a/tests/core/custom_arange_primitive.py +++ b/tests/core/custom_arange_primitive.py @@ -58,7 +58,7 @@ def custom_arange_tile_translation_ipu( outaval = core.ShapedArray(outshape, outdtype) gp_filename = custom_vertex_filename - global_scale_data = np.array([7], dtype=outdtype) + global_scale_data = np.array(7, dtype=outdtype) ipu_dtype = from_numpy_dtype_to_ipu_type(outdtype) vertex_name = f"CustomArangeVertex<{ipu_dtype.name.lower()}>" # Translation rule to IPU vertex @@ -69,7 +69,7 @@ def custom_arange_tile_translation_ipu( # IO vertex infos. 
inputs_info=[ make_ipu_vertex_in_info("scales", inavals[0], vertex_dim2=inavals[0].shape[1]), - make_ipu_vertex_constant_info("global_scale", global_scale_data), + make_ipu_vertex_constant_info("global_scale", global_scale_data, vertex_dim2=-1), ], outputs_info=[make_ipu_vertex_out_info("out", outaval)], # Additional attributes to pass to the vertex diff --git a/tests/core/custom_arange_vertex.cpp b/tests/core/custom_arange_vertex.cpp index 7257aaa..b2ce4aa 100644 --- a/tests/core/custom_arange_vertex.cpp +++ b/tests/core/custom_arange_vertex.cpp @@ -15,13 +15,14 @@ class CustomArangeVertex : public Vertex { // Testing 2d tensor IO supported. Vector>, poplar::VectorLayout::ONE_PTR> scales; // (2, size) // Testing constant vertex tensor. - Input> global_scale; // (1,) + // Input> global_scale; // (1,) + Input global_scale; // (,) scalar Output> out; // (size, ) bool compute() { const auto outsize = out.size(); for (std::size_t idx = 0; idx < outsize; ++idx) { - out[idx] = T(idx) * scales[0][idx] * scales[1][idx] * global_scale[0]; + out[idx] = T(idx) * scales[0][idx] * scales[1][idx] * (*global_scale); } return true; } diff --git a/tests/core/test_tile_interpreter_primitives.py b/tests/core/test_tile_interpreter_primitives.py index 19cdacf..38686de 100644 --- a/tests/core/test_tile_interpreter_primitives.py +++ b/tests/core/test_tile_interpreter_primitives.py @@ -39,8 +39,9 @@ def test__make_ipu_vertex_io_info__proper_result(self): assert info.shape == [1, 2, 3] assert info.dtype == IpuType.HALF assert not info.is_constant_input + assert not info.is_scalar - def test__make_ipu_vertex_constant_info__proper_result(self): + def test__make_ipu_vertex_constant_info__array__proper_result(self): datain = np.array([1, 2, 3, 4], dtype=np.float32) info = make_ipu_vertex_constant_info("constant", datain, vertex_dim2=2) assert isinstance(info, IpuVertexIOInfo) @@ -50,6 +51,22 @@ def test__make_ipu_vertex_constant_info__proper_result(self): assert info.dtype == IpuType.FLOAT 
assert info.is_constant_input assert len(info.slices2d) == 2 + assert not info.is_scalar + + dataout = np.frombuffer(base64.decodebytes(str.encode(info.constant_data.encoded_data)), dtype=datain.dtype) + npt.assert_array_equal(dataout, datain) + + def test__make_ipu_vertex_constant_info__scalar__proper_result(self): + datain = np.array(3, dtype=np.float32) + # vertex_dim2 < 0 indicating scalar entry. + info = make_ipu_vertex_constant_info("constant", datain, vertex_dim2=-1) + assert isinstance(info, IpuVertexIOInfo) + assert info.name == "constant" + assert info.iotype == IpuVertexIOType.In + assert tuple(info.shape) == datain.shape + assert info.dtype == IpuType.FLOAT + assert info.is_constant_input + assert info.is_scalar dataout = np.frombuffer(base64.decodebytes(str.encode(info.constant_data.encoded_data)), dtype=datain.dtype) npt.assert_array_equal(dataout, datain) diff --git a/tests/lax/test_tile_lax_gather.py b/tests/lax/test_tile_lax_gather.py index 4536aaa..0a3af15 100644 --- a/tests/lax/test_tile_lax_gather.py +++ b/tests/lax/test_tile_lax_gather.py @@ -22,6 +22,7 @@ def setUp(self): @parameterized.parameters( {"num_elements": 8, "num_indices": 3}, {"num_elements": 8, "num_indices": 12}, + {"num_elements": 256, "num_indices": 512}, ) def test__tile_map__gather__jitting__proper_result(self, num_elements, num_indices): tiles = (0,) diff --git a/tests/lax/test_tile_lax_scatter.py b/tests/lax/test_tile_lax_scatter.py new file mode 100644 index 0000000..597308f --- /dev/null +++ b/tests/lax/test_tile_lax_scatter.py @@ -0,0 +1,68 @@ +# Copyright (c) 2023 Graphcore Ltd. All rights reserved. 
+from functools import partial + +import chex +import jax +import numpy as np +import numpy.testing as npt +from absl.testing import parameterized + +from tessellate_ipu import tile_map, tile_put_replicated +from tessellate_ipu.lax import scatter_add_p, scatter_max_p, scatter_mul_p, scatter_p + + +class IpuTilePrimitivesLaxScater(chex.TestCase, parameterized.TestCase): + def setUp(self): + super().setUp() + self.device = jax.devices("ipu")[0] + self.num_tiles = self.device.num_tiles + # Not very clean, but for better reproducibility. + np.random.seed(123) + + @parameterized.parameters( + {"num_elements": 8, "num_indices": 3, "scatter_prim": scatter_p}, + {"num_elements": 8, "num_indices": 16, "scatter_prim": scatter_add_p}, + {"num_elements": 8, "num_indices": 16, "scatter_prim": scatter_max_p}, + {"num_elements": 8, "num_indices": 16, "scatter_prim": scatter_mul_p}, + {"num_elements": 8, "num_indices": 3, "scatter_prim": scatter_add_p}, + {"num_elements": 8, "num_indices": 12, "scatter_prim": scatter_add_p}, + {"num_elements": 256, "num_indices": 512, "scatter_prim": scatter_add_p}, + ) + def test__tile_map__scatter__jitting__multi_sizes__proper_result(self, num_elements, num_indices, scatter_prim): + tiles = (0,) + data = np.random.randn(num_elements).astype(np.float32) + indices = np.random.randint(low=0, high=num_elements, size=num_indices) + indices = indices.reshape(-1, 1).astype(np.uint32) + updates = np.random.randn(indices.size).astype(np.float32) + + # Only supported configuration! 
+ scatter_dnums = jax.lax.ScatterDimensionNumbers( + update_window_dims=(), inserted_window_dims=(0,), scatter_dims_to_operand_dims=(0,) + ) + + def scatter_add_fn(data, indices, updates): + data = tile_put_replicated(data, tiles) + indices = tile_put_replicated(indices, tiles) + updates = tile_put_replicated(updates, tiles) + return tile_map( + scatter_prim, + data, + indices, + updates, + dimension_numbers=scatter_dnums, + indices_are_sorted=False, + unique_indices=False, + mode=jax.lax.GatherScatterMode.PROMISE_IN_BOUNDS, + update_jaxpr=None, + update_consts=None, + ) + + cpu_scatter_add_fn = partial(jax.jit, backend="cpu")(scatter_add_fn) + ipu_scatter_add_fn = partial(jax.jit, backend="ipu")(scatter_add_fn) + + cpu_output = cpu_scatter_add_fn(data, indices, updates) + ipu_output = ipu_scatter_add_fn(data, indices, updates) + + assert ipu_output.tiles == tiles + assert ipu_output.dtype == data.dtype + npt.assert_array_equal(ipu_output, cpu_output) diff --git a/tests/lib/test_tessellate_core.py b/tests/lib/test_tessellate_core.py index 6074214..8fbf02e 100644 --- a/tests/lib/test_tessellate_core.py +++ b/tests/lib/test_tessellate_core.py @@ -67,12 +67,12 @@ def test__ipu_vertex_io_info__to_json_str__proper_representation(self): ioinfo = IpuVertexIOInfo(name="in0", iotype=IpuVertexIOType.InOut, shape=[1, 2, 3], dtype=IpuType.FLOAT) assert ( ioinfo.to_json_str() - == '{"aval":{"dtype":12,"shape":[1,2,3]},"constant_data":null,"iotype":2,"name":"in0","slices2d":[]}' + == '{"aval":{"dtype":12,"shape":[1,2,3]},"constant_data":null,"iotype":2,"is_scalar":false,"name":"in0","slices2d":[]}' ) def test__ipu_vertex_io_info__from_json_str__proper_representation(self): ioinfo = IpuVertexIOInfo.from_json_str( - '{"aval":{"dtype":12,"shape":[1,2,3]},"constant_data":null,"iotype":2,"name":"in0","slices2d":[{"begin":10,"end":15}]}' + '{"aval":{"dtype":12,"shape":[1,2,3]},"constant_data":null,"iotype":2,"name":"in0","slices2d":[{"begin":10,"end":15}],"is_scalar":false}' ) assert 
ioinfo.name == "in0" assert ioinfo.iotype == IpuVertexIOType.InOut