diff --git a/jax/BUILD b/jax/BUILD
index 31020eb1d385..271f5e95a626 100644
--- a/jax/BUILD
+++ b/jax/BUILD
@@ -455,6 +455,7 @@ pytype_strict_library(
         ":dtypes",
         ":effects",
         ":mesh",
+        ":partition_spec",
         ":pretty_printer",
         ":source_info_util",
         ":traceback_util",
@@ -558,6 +559,7 @@ pytype_strict_library(
         ":layout",
         ":op_shardings",
         ":partial_eval",
+        ":partition_spec",
         ":path",
         ":pickle_util",
         ":sharding",
diff --git a/jax/_src/core.py b/jax/_src/core.py
index 0c2949de07af..ac7ff8bdab76 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -39,6 +39,7 @@
 from jax._src import effects
 from jax._src import compute_on
 from jax._src import mesh as mesh_lib
+from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton
 from jax._src.errors import (
     ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
     TracerIntegerConversionError, UnexpectedTracerError)
@@ -1599,13 +1600,30 @@ def _invalid_shape_error(shape: Shape, context: str=""):
   return TypeError(msg)


+# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
+# Collective too.
+def _maybe_modify_sharding(sharding):
+  if mesh_lib.AxisTypes.Auto not in sharding.mesh.axis_types:
+    return sharding
+
+  new_spec = []
+  for s in sharding.spec:
+    if s is None or isinstance(s, UnconstrainedSingleton):
+      new_spec.append(s)
+    else:
+      temp_s = s[0] if isinstance(s, tuple) else s
+      new_spec.append(
+          P.UNCONSTRAINED
+          if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
+  return sharding.with_spec(new_spec)
+
+
 def get_sharding(sharding, ndim):
-  from jax._src.sharding_impls import NamedSharding, PartitionSpec as P  # type: ignore
+  from jax._src.sharding_impls import NamedSharding  # type: ignore

   if sharding is not None:
     assert len(sharding.spec) == ndim
-    return sharding
+    return _maybe_modify_sharding(sharding)

   context_mesh = mesh_lib.get_abstract_mesh()
   # TODO(yashkatariya): Error out and ask users to set the context mesh in their
@@ -1675,9 +1693,7 @@ def str_short(self, short_dtypes=False):
     dt_str = dt_str.replace('void', 'float0')
     if hasattr(self, 'sharding') and self.sharding is not None:
       shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
-      axis_types = self.sharding.mesh.axis_types
-      axt = _get_axis_type_str(axis_types) if axis_types is not None else ''
-      return f'{dt_str}[{shapestr}]{axt}'
+      return f'{dt_str}[{shapestr}]'
     else:
       shapestr = ','.join(map(str, self.shape))
       return f'{dt_str}[{shapestr}]'
@@ -1689,26 +1705,13 @@ def _len(self, ignored_tracer):
       raise TypeError("len() of unsized object") from err  # same as numpy error


-def _get_axis_type_str(axis_types):
-  from jax._src.mesh import AxisTypes  # type: ignore
-
-  out = []
-  for t, axes in axis_types.items():
-    a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes
-    if t == AxisTypes.Collective:
-      out.append(f"C:{a}")
-    elif t == AxisTypes.User:
-      out.append(f"U:{a}")
-    else:
-      assert t == AxisTypes.Auto
-      out.append(f"A:{a}")
-  return f"{{{', '.join(out)}}}"
-
 def _get_shape_sharding_str(shape, spec):
   out = []
   for s1, s2 in zip(shape, spec):
     if s2 is None:
       out.append(f"{s1}")
+    elif isinstance(s2, UnconstrainedSingleton):
+      out.append(f"{s1}")
     elif isinstance(s2, tuple):
       ss = ','.join(s for s in s2)
       out.append(f"{s1}@({ss})")
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 531177b7244c..a4edbff0d6e2 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -50,6 +50,7 @@
 from jax._src.layout import AutoLayout, DeviceLocalLayout
 from jax._src.sharding import Sharding as JSharding
 from jax._src.sharding_impls import AUTO
+from jax._src.partition_spec import UnconstrainedSingleton
 from jax._src.lib import xla_client as xc
 from jax._src.lib import xla_extension
 from jax._src.lib.mlir import dialects, ir, passmanager
@@ -2524,12 +2525,19 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
   # Don't emit a wsc under full manual mode to avoid increasing HLO size.
   if aval.sharding.mesh._are_all_axes_collective:
     return op
+  if aval.sharding.mesh._are_all_axes_auto:
+    return op
+  # TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
+  # `return op` early and avoid bloating HLO size.
   proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
            if sharding_proto is None else sharding_proto)
-  # TODO(yashkatariya): Enable this
-  # unspecified_dims = (set(range(aval.ndim))
-  #                     if aval.sharding.mesh._any_axis_collective else None)
-  return wrap_with_sharding_op(ctx, op, aval, proto)
+  unspecified_dims = None
+  if aval.sharding.mesh._any_axis_collective:
+    unspecified_dims = set(range(aval.ndim))
+  elif aval.sharding.mesh._any_axis_auto:
+    unspecified_dims = {i for i, s in enumerate(aval.sharding.spec)
+                        if isinstance(s, UnconstrainedSingleton)}
+  return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)


 def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index d48e81b9092c..0f16bab19181 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -63,7 +63,7 @@
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
-from jax._src.partition_spec import PartitionSpec
+from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
 from jax._src.sharding import Sharding as JSharding
 from jax._src.sharding_impls import (
     ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
@@ -2123,11 +2123,13 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):
   @lru_cache(maxsize=128)
   def _abstract_to_concrete_mesh(abstract_mesh):
     return mesh_lib.Mesh(
-        np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names)
+        np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
+        axis_types=abstract_mesh.axis_types)

   out = []
   for s, a in zip(shardings, avals):
-    if isinstance(s, UnspecifiedValue) and a.sharding is not None:
+    if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
+        all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
       out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
                                a.sharding.spec))
     else:
diff --git a/jax/_src/mesh.py b/jax/_src/mesh.py
index c7b8f692055d..480cd64d1cfb 100644
--- a/jax/_src/mesh.py
+++ b/jax/_src/mesh.py
@@ -124,6 +124,7 @@ def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:

 _mesh_object_dict = {}  # type: ignore

+MeshAxisType = dict[AxisTypes, str | tuple[str, ...]]

 class Mesh(contextlib.ContextDecorator):
   """Declare the hardware resources available in the scope of this manager.
@@ -178,11 +179,11 @@ class Mesh(contextlib.ContextDecorator):

   devices: np.ndarray
   axis_names: tuple[MeshAxisName, ...]
-  axis_types: dict[AxisTypes, str | tuple[str, ...]] | None
+  axis_types: MeshAxisType | None

   def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
-              axis_names: str | Sequence[MeshAxisName],
-              axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
+              axis_names: str | Sequence[MeshAxisName], *,
+              axis_types: MeshAxisType | None = None):
     if not isinstance(devices, np.ndarray):
       devices = np.array(devices)
     if isinstance(axis_names, str):
@@ -216,7 +217,8 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
     return self

   def __reduce__(self):
-    return (type(self), (self.devices, self.axis_names, self.axis_types))
+    return (type(self), (self.devices, self.axis_names),
+            {'axis_types': self.axis_types})

   def __eq__(self, other):
     if not isinstance(other, Mesh):
@@ -348,7 +350,7 @@ def local_devices(self):

   @functools.cached_property
   def abstract_mesh(self):
-    return AbstractMesh(self.shape_tuple, self.axis_types)
+    return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)


 EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
@@ -373,8 +375,8 @@ class AbstractMesh:
   details.
   """

-  def __init__(self, shape_tuple: tuple[tuple[str, int], ...],
-               axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
+  def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
+               axis_types: MeshAxisType | None = None):
     self.shape_tuple = shape_tuple
     self.axis_types = axis_types
     if self.shape_tuple:
@@ -434,6 +436,24 @@ def _are_all_axes_collective(self) -> bool:
       return False
     return all(t == AxisTypes.Collective for t in self.axis_types.keys())

+  @functools.cached_property
+  def _are_all_axes_auto(self) -> bool:
+    if self.axis_types is None:
+      return False
+    return all(t == AxisTypes.Auto for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_collective(self) -> bool:
+    if self.axis_types is None:
+      return False
+    return any(t == AxisTypes.Collective for t in self.axis_types.keys())
+
+  @functools.cached_property
+  def _any_axis_auto(self) -> bool:
+    if self.axis_types is None:
+      return False
+    return any(t == AxisTypes.Auto for t in self.axis_types.keys())
+
   @property
   def devices(self):
     _raise_value_error("devices")
@@ -474,6 +494,8 @@ def _raise_value_error(name):

 @contextlib.contextmanager
 def set_abstract_mesh(mesh: AbstractMesh):
+  if mesh is not None and mesh.axis_types is None:
+    raise RuntimeError('Please set the AxisTypes of Mesh.')
   prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
   try:
     yield
diff --git a/jax/_src/partition_spec.py b/jax/_src/partition_spec.py
index f9bc2b60cee9..3af32ea90356 100644
--- a/jax/_src/partition_spec.py
+++ b/jax/_src/partition_spec.py
@@ -14,7 +14,7 @@

 from __future__ import annotations

-class _UnconstrainedPartitionSingleton:
+class UnconstrainedSingleton:

   def __repr__(self):
     return "UNCONSTRAINED"
@@ -23,7 +23,7 @@ def __repr__(self):
 # Unconstrained sentinel value for PartitionSpec, representing a dimension for
 # which the user wants XLA to assign the best partitioning.
 # TODO(yashkatariya): May rename to AUTO.
-_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
+_UNCONSTRAINED_PARTITION = UnconstrainedSingleton()


 class PartitionSpec(tuple):
diff --git a/jax/_src/sharding_impls.py b/jax/_src/sharding_impls.py
index 5e1def1079ac..7eff620c648a 100644
--- a/jax/_src/sharding_impls.py
+++ b/jax/_src/sharding_impls.py
@@ -67,6 +67,19 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
         f"is also found in manual_axes: {_manual_axes}.") from None


+@util.cache(max_size=128, trace_context_in_key=False)
+def _check_axis_type_consistency(mesh, parsed_pspec):
+  if mesh.axis_types is None:
+    return
+  for p in parsed_pspec:
+    if p is not None:
+      if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
+        raise ValueError(
+            'AxisTypes should be the same in a tuple subset of PartitionSpec:'
+            f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
+            f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')
+
+
 def hashed_index(x) -> int:
   # This works for both `pjit` indices and `pmap` indices (which might
   # have an integer instead of a slice).
@@ -1084,6 +1097,7 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
         PartitionSpec() if spec is None else spec,
         "NamedSharding spec", allow_unconstrained_dims=True)
   _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
+  _check_axis_type_consistency(mesh, parsed_pspec)
   return parsed_pspec


@@ -1673,7 +1687,8 @@ def _gspmd_to_named_sharding_via_mesh(


 def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
-              *, devices: Sequence[xc.Device] | None = None) -> mesh_lib.Mesh:
+              *, devices: Sequence[xc.Device] | None = None,
+              axis_types: mesh_lib.MeshAxisType | None = None) -> mesh_lib.Mesh:
   """Creates an efficient mesh with the shape and axis names specified.

   This function attempts to automatically compute a good mapping from a set of
@@ -1735,4 +1750,4 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
   mesh_devices = mesh_utils.create_device_mesh(
       new_axis_shapes, devices,
       allow_split_physical_axes=allow_split_physical_axes)
-  return mesh_lib.Mesh(mesh_devices, axis_names)
+  return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
diff --git a/jax/_src/test_util.py b/jax/_src/test_util.py
index 0bd5c7b139a1..ce418b686603 100644
--- a/jax/_src/test_util.py
+++ b/jax/_src/test_util.py
@@ -1443,26 +1443,28 @@ def with_and_without_mesh(f):
       ('Mesh', (('x', 2),), (('i', 'x'),))
   ))(with_mesh_from_kwargs(f))

-def with_user_mesh(sizes, names):
+def with_user_mesh(sizes, names, axis_types=None):
+  axis_types = ({mesh_lib.AxisTypes.User: names}
+                if axis_types is None else axis_types)
   def decorator(fn):
     def mesh_fn(*args, **kwargs):
-      mesh = create_mesh(sizes, names)
+      mesh = create_mesh(sizes, names, axis_types=axis_types)
       with mesh_lib.set_mesh(mesh):
         return fn(*args, **kwargs, mesh=mesh)
     return mesh_fn
   return decorator

-def create_mesh(mesh_shape, axis_names, iota_order=False):
+def create_mesh(mesh_shape, axis_names, iota_order=False, axis_types=None):
   size = math.prod(mesh_shape)
   if len(jax.devices()) < size:
     raise unittest.SkipTest(f"Test requires {size} global devices.")
   if iota_order:
     devices = sorted(jax.devices(), key=lambda d: d.id)
     mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
-    return jax.sharding.Mesh(mesh_devices, axis_names)
+    return jax.sharding.Mesh(mesh_devices, axis_names, axis_types=axis_types)
   else:
-    return jax.make_mesh(mesh_shape, axis_names)
+    return jax.make_mesh(mesh_shape, axis_names, axis_types=axis_types)


 class _cached_property:
   null = object()
diff --git a/jax/experimental/shard_map.py b/jax/experimental/shard_map.py
index b4609282e2f8..5637668e4afa 100644
--- a/jax/experimental/shard_map.py
+++ b/jax/experimental/shard_map.py
@@ -46,7 +46,8 @@
 from jax._src import traceback_util
 from jax._src import util
 from jax._src.core import Tracer
-from jax._src.mesh import AbstractMesh, Mesh, AxisTypes, set_abstract_mesh
+from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, set_abstract_mesh,
+                           get_abstract_mesh)
 from jax._src.api import _shared_code_pmap, _prepare_pmap
 from jax._src.lax import (lax, parallel as lax_parallel, slicing,
                           windowed_reductions, convolution, fft, linalg,
@@ -536,7 +537,7 @@ def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
                     for i, sz in enumerate(aval.shape))
   if config.sharding_in_types.value:
     new_mesh = AbstractMesh(
-        mesh.shape_tuple, {AxisTypes.Collective: mesh.axis_names})
+        mesh.shape_tuple, axis_types={AxisTypes.Collective: mesh.axis_names})
     new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim))
   else:
     new_sharding = None
@@ -548,11 +549,9 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
   assert isinstance(aval, core.ShapedArray)
   new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
                     for i, sz in enumerate(aval.shape))
-  # TODO(yashkatariya): Reset the mesh properly based on the input avals if the
-  # mesh of shard_map specifies collective axes.
   if config.sharding_in_types.value:
     spec = _names_to_pspec(names)._normalized_spec(aval.ndim)
-    new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec)
+    new_sharding = NamedSharding(get_abstract_mesh(), spec)
   else:
     new_sharding = None
   return aval.update(shape=new_shape, sharding=new_sharding)
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 5bb09043568c..f4469f1a16cc 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -4837,6 +4837,7 @@ def test_dot_general_batch_error(self, mesh):

   @jtu.with_user_mesh((2, 2), ('model', 'data'))
   def test_aval_repr(self, mesh):
+    mesh = mesh.abstract_mesh
     aval = core.ShapedArray((128, 64), np.float32,
                             sharding=NamedSharding(mesh, P('model', 'data')))
     self.assertEqual(aval.str_short(), 'float32[128@model,64@data]')
@@ -4977,21 +4978,21 @@ def f(x):
     lowered_text = f.lower(arr).as_text()
     self.assertIn('@Sharding', lowered_text)

-  def test_broadcasting_nary_error(self):
-    mesh1 = Mesh([jax.devices()[0]], 'x')
-    mesh2 = Mesh([jax.devices()[0]], 'y')
+  @jtu.with_user_mesh((1,), 'x')
+  def test_broadcasting_nary_error(self, mesh):
+    mesh2 = Mesh([jax.devices()[0]], 'y',
+                 axis_types={mesh_lib.AxisTypes.User: 'y'})

-    arr1 = jax.device_put(np.arange(8), NamedSharding(mesh1, P()))
+    arr1 = jax.device_put(np.arange(8), NamedSharding(mesh, P()))
     arr2 = jax.device_put(np.arange(8), NamedSharding(mesh2, P()))

     @jax.jit
     def f(x, y):
       return x + y

-    with config.sharding_in_types(True):
-      with self.assertRaisesRegex(
-          ValueError, "Mesh for all inputs should be equal"):
-        f(arr1, arr2)
+    with self.assertRaisesRegex(
+        ValueError, "Mesh for all inputs should be equal"):
+      f(arr1, arr2)

   @jtu.with_user_mesh((2, 2), ('x', 'y'))
   def test_sin_unop(self, mesh):
@@ -5482,6 +5483,53 @@ def f(x):

     self.assertIn('@Sharding', f.lower(arr).as_text())

+  def test_auto_user(self):
+    mesh = jtu.create_mesh((2, 2), ('x', 'y'),
+                           axis_types={mesh_lib.AxisTypes.Auto: ('x', 'y')})
+    np_inp = np.arange(16.).reshape(8, 2)
+    s = NamedSharding(mesh, P('x', 'y'))
+    arr = jax.device_put(np_inp, s)
+
+    @jax.jit
+    def f(x, x2):
+      y = x * 2
+      z = jnp.sin(y)
+      a = z @ x2
+      return a
+
+    with mesh_lib.set_mesh(mesh):
+      out = f(arr, arr.T)
+      self.assertEqual(out.sharding, NamedSharding(mesh, P('x',)))
+      lowered_text = f.lower(arr, arr.T).as_text()
+      self.assertNotIn('unspecified_dims', lowered_text)
+
+    mesh2 = jtu.create_mesh((2, 2), ('x', 'y'),
+                            axis_types={mesh_lib.AxisTypes.User: 'x',
+                                        mesh_lib.AxisTypes.Auto: 'y'})
+    with mesh_lib.set_mesh(mesh2):
+      arr = jax.device_put(arr, NamedSharding(mesh2, P('x', 'y')))
+      arr2 = jax.device_put(np_inp.T, NamedSharding(mesh2, P('y', None)))
+      out = f(arr, arr2)
+      self.assertEqual(out.sharding, NamedSharding(mesh2, P('x', None)))
+      lowered_text = f.lower(arr, arr2).as_text()
+      self.assertTrue(lowered_text.count("unspecified_dims") == 3)
+
+    mesh3 = jtu.create_mesh((2, 2), ('x', 'y'),
+                            axis_types={mesh_lib.AxisTypes.User: 'y',
+                                        mesh_lib.AxisTypes.Auto: 'x'})
+    with mesh_lib.set_mesh(mesh3):
+      arr = jax.device_put(arr, NamedSharding(mesh3, P('x', 'y')))
+      arr2 = jax.device_put(np_inp.T, NamedSharding(mesh3, P('y', 'x')))
+      out = f(arr, arr2)
+      self.assertEqual(out.sharding, NamedSharding(mesh3, P('x',)))
+      lowered_text = f.lower(arr, arr2).as_text()
+      self.assertTrue(lowered_text.count("unspecified_dims") == 4)
+
+    with self.assertRaisesRegex(
+        ValueError,
+        "AxisTypes should be the same in a tuple subset of PartitionSpec"):
+      NamedSharding(mesh2, P(('x', 'y')))
+

 @jtu.pytest_mark_if_available('multiaccelerator')
 class PJitErrorTest(jtu.JaxTestCase):
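
Usage sketch (not part of the patch itself): a minimal example of the `axis_types` plumbing this diff adds, mirroring the new `test_auto_user` test. It assumes at least four devices, the sharding-in-types mode the tests run under, and the internal entry points the diff itself exercises (`jax._src.mesh.AxisTypes`, `mesh_lib.set_mesh`); these are not stable public API, so treat the snippet as illustrative only.

# Sketch only: 'x' stays user-sharded, 'y' is Auto, so dims mapped to 'y'
# should lower with unspecified_dims and be left for XLA to partition.
import numpy as np
import jax
import jax.numpy as jnp
from jax._src import mesh as mesh_lib  # internal: AxisTypes, set_mesh
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((2, 2), ('x', 'y'),
                     axis_types={mesh_lib.AxisTypes.User: 'x',
                                 mesh_lib.AxisTypes.Auto: 'y'})

@jax.jit
def f(x, x2):
  return jnp.sin(x * 2) @ x2

with mesh_lib.set_mesh(mesh):
  arr = jax.device_put(np.arange(16.).reshape(8, 2),
                       NamedSharding(mesh, P('x', 'y')))
  arr2 = jax.device_put(np.arange(16.).reshape(2, 8),
                        NamedSharding(mesh, P('y', None)))
  out = f(arr, arr2)
  # Per test_auto_user: output is sharded as P('x', None), and the lowered
  # HLO contains unspecified_dims for the Auto ('y') dimensions.
  print(out.sharding)
  print('unspecified_dims' in f.lower(arr, arr2).as_text())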