Skip to content

Commit

Permalink
[sharding_in_types] Enforce AxisTypes to always exist if set_mesh i…
Browse files Browse the repository at this point in the history
…s used.

Also support `Auto` mode fully or mixed in with `User` mode. This works by overriding the sharding of `Auto` axes in the PartitionSpec with `Unconstrained` in `ShapedArray` constructor. The `ShapedArray` constructor is the central place where we can make such substitutions.

During lowering of shardings with auto axes, we mark the auto dims are `unspecifed_dims`. We don't mark all dims as unspecified because that would enable XLA to shard them even further which is not what we want if some of the dims are user sharded.

PiperOrigin-RevId: 704911253
  • Loading branch information
yashk2810 authored and Google-ML-Automation committed Dec 11, 2024
1 parent e88b578 commit b5e4fd1
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 56 deletions.
2 changes: 2 additions & 0 deletions jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -455,6 +455,7 @@ pytype_strict_library(
":dtypes",
":effects",
":mesh",
":partition_spec",
":pretty_printer",
":source_info_util",
":traceback_util",
Expand Down Expand Up @@ -558,6 +559,7 @@ pytype_strict_library(
":layout",
":op_shardings",
":partial_eval",
":partition_spec",
":path",
":pickle_util",
":sharding",
Expand Down
43 changes: 23 additions & 20 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from jax._src import effects
from jax._src import compute_on
from jax._src import mesh as mesh_lib
from jax._src.partition_spec import PartitionSpec as P, UnconstrainedSingleton
from jax._src.errors import (
ConcretizationTypeError, TracerArrayConversionError, TracerBoolConversionError,
TracerIntegerConversionError, UnexpectedTracerError)
Expand Down Expand Up @@ -1599,13 +1600,30 @@ def _invalid_shape_error(shape: Shape, context: str=""):

return TypeError(msg)

# TODO(yashkatariya): Only works with User/Auto. Generalize it to work with
# Collective too.
def _maybe_modify_sharding(sharding):
if mesh_lib.AxisTypes.Auto not in sharding.mesh.axis_types:
return sharding

new_spec = []
for s in sharding.spec:
if s is None or isinstance(s, UnconstrainedSingleton):
new_spec.append(s)
else:
temp_s = s[0] if isinstance(s, tuple) else s
new_spec.append(
P.UNCONSTRAINED
if sharding.mesh._name_to_type[temp_s] == mesh_lib.AxisTypes.Auto else s)
return sharding.with_spec(new_spec)


def get_sharding(sharding, ndim):
from jax._src.sharding_impls import NamedSharding, PartitionSpec as P # type: ignore
from jax._src.sharding_impls import NamedSharding # type: ignore

if sharding is not None:
assert len(sharding.spec) == ndim
return sharding
return _maybe_modify_sharding(sharding)

context_mesh = mesh_lib.get_abstract_mesh()
# TODO(yashkatariya): Error out and ask users to set the context mesh in their
Expand Down Expand Up @@ -1675,9 +1693,7 @@ def str_short(self, short_dtypes=False):
dt_str = dt_str.replace('void', 'float0')
if hasattr(self, 'sharding') and self.sharding is not None:
shapestr = _get_shape_sharding_str(self.shape, self.sharding.spec)
axis_types = self.sharding.mesh.axis_types
axt = _get_axis_type_str(axis_types) if axis_types is not None else ''
return f'{dt_str}[{shapestr}]{axt}'
return f'{dt_str}[{shapestr}]'
else:
shapestr = ','.join(map(str, self.shape))
return f'{dt_str}[{shapestr}]'
Expand All @@ -1689,26 +1705,13 @@ def _len(self, ignored_tracer):
raise TypeError("len() of unsized object") from err # same as numpy error


def _get_axis_type_str(axis_types):
from jax._src.mesh import AxisTypes # type: ignore

out = []
for t, axes in axis_types.items():
a = f"({','.join(a for a in axes)})" if isinstance(axes, tuple) else axes
if t == AxisTypes.Collective:
out.append(f"C:{a}")
elif t == AxisTypes.User:
out.append(f"U:{a}")
else:
assert t == AxisTypes.Auto
out.append(f"A:{a}")
return f"{{{', '.join(out)}}}"

def _get_shape_sharding_str(shape, spec):
out = []
for s1, s2 in zip(shape, spec):
if s2 is None:
out.append(f"{s1}")
elif isinstance(s2, UnconstrainedSingleton):
out.append(f"{s1}")
elif isinstance(s2, tuple):
ss = ','.join(s for s in s2)
out.append(f"{s1}@({ss})")
Expand Down
16 changes: 12 additions & 4 deletions jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from jax._src.layout import AutoLayout, DeviceLocalLayout
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import AUTO
from jax._src.partition_spec import UnconstrainedSingleton
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
from jax._src.lib.mlir import dialects, ir, passmanager
Expand Down Expand Up @@ -2524,12 +2525,19 @@ def lower_sharding_under_shit(ctx, op, aval, sharding_proto=None):
# Don't emit a wsc under full manual mode to avoid increasing HLO size.
if aval.sharding.mesh._are_all_axes_collective:
return op
if aval.sharding.mesh._are_all_axes_auto:
return op
# TODO(yashkatariya): If all the axes in pspec are AUTO or collective,
# `return op` early and avoid bloating HLO size.
proto = (aval.sharding._to_xla_hlo_sharding(aval.ndim).to_proto()
if sharding_proto is None else sharding_proto)
# TODO(yashkatariya): Enable this
# unspecified_dims = (set(range(aval.ndim))
# if aval.sharding.mesh._any_axis_collective else None)
return wrap_with_sharding_op(ctx, op, aval, proto)
unspecified_dims = None
if aval.sharding.mesh._any_axis_collective:
unspecified_dims = set(range(aval.ndim))
elif aval.sharding.mesh._any_axis_auto:
unspecified_dims = {i for i, s in enumerate(aval.sharding.spec)
if isinstance(s, UnconstrainedSingleton)}
return wrap_with_sharding_op(ctx, op, aval, proto, unspecified_dims)


def set_sharding(op, sharding: xc.OpSharding | sharding_impls.SdyArraySharding):
Expand Down
8 changes: 5 additions & 3 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.partition_spec import PartitionSpec
from jax._src.partition_spec import PartitionSpec, UnconstrainedSingleton
from jax._src.sharding import Sharding as JSharding
from jax._src.sharding_impls import (
ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
Expand Down Expand Up @@ -2123,11 +2123,13 @@ def _concretize_abstract_shardings(shardings, avals, device_assignment):
@lru_cache(maxsize=128)
def _abstract_to_concrete_mesh(abstract_mesh):
return mesh_lib.Mesh(
np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names)
np_dev.reshape(abstract_mesh.axis_sizes), abstract_mesh.axis_names,
axis_types=abstract_mesh.axis_types)

out = []
for s, a in zip(shardings, avals):
if isinstance(s, UnspecifiedValue) and a.sharding is not None:
if (isinstance(s, UnspecifiedValue) and a.sharding is not None and
all(not isinstance(s, UnconstrainedSingleton) for s in a.sharding.spec)):
out.append(NamedSharding(_abstract_to_concrete_mesh(a.sharding.mesh),
a.sharding.spec))
else:
Expand Down
36 changes: 29 additions & 7 deletions jax/_src/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def axis_names_to_types(axis_types) -> dict[str, AxisTypes]:

_mesh_object_dict = {} # type: ignore

MeshAxisType = dict[AxisTypes, str | tuple[str, ...]]

class Mesh(contextlib.ContextDecorator):
"""Declare the hardware resources available in the scope of this manager.
Expand Down Expand Up @@ -178,11 +179,11 @@ class Mesh(contextlib.ContextDecorator):

devices: np.ndarray
axis_names: tuple[MeshAxisName, ...]
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None
axis_types: MeshAxisType | None

def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
axis_names: str | Sequence[MeshAxisName],
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
axis_names: str | Sequence[MeshAxisName], *,
axis_types: MeshAxisType | None = None):
if not isinstance(devices, np.ndarray):
devices = np.array(devices)
if isinstance(axis_names, str):
Expand Down Expand Up @@ -216,7 +217,8 @@ def __new__(cls, devices: np.ndarray | Sequence[xc.Device],
return self

def __reduce__(self):
return (type(self), (self.devices, self.axis_names, self.axis_types))
return (type(self), (self.devices, self.axis_names),
{'axis_types': self.axis_types})

def __eq__(self, other):
if not isinstance(other, Mesh):
Expand Down Expand Up @@ -348,7 +350,7 @@ def local_devices(self):

@functools.cached_property
def abstract_mesh(self):
return AbstractMesh(self.shape_tuple, self.axis_types)
return AbstractMesh(self.shape_tuple, axis_types=self.axis_types)


EMPTY_ENV = ResourceEnv(Mesh(np.empty((), dtype=object), ()))
Expand All @@ -373,8 +375,8 @@ class AbstractMesh:
details.
"""

def __init__(self, shape_tuple: tuple[tuple[str, int], ...],
axis_types: dict[AxisTypes, str | tuple[str, ...]] | None = None):
def __init__(self, shape_tuple: tuple[tuple[str, int], ...], *,
axis_types: MeshAxisType | None = None):
self.shape_tuple = shape_tuple
self.axis_types = axis_types
if self.shape_tuple:
Expand Down Expand Up @@ -434,6 +436,24 @@ def _are_all_axes_collective(self) -> bool:
return False
return all(t == AxisTypes.Collective for t in self.axis_types.keys())

@functools.cached_property
def _are_all_axes_auto(self) -> bool:
if self.axis_types is None:
return False
return all(t == AxisTypes.Auto for t in self.axis_types.keys())

@functools.cached_property
def _any_axis_collective(self) -> bool:
if self.axis_types is None:
return False
return any(t == AxisTypes.Collective for t in self.axis_types.keys())

@functools.cached_property
def _any_axis_auto(self) -> bool:
if self.axis_types is None:
return False
return any(t == AxisTypes.Auto for t in self.axis_types.keys())

@property
def devices(self):
_raise_value_error("devices")
Expand Down Expand Up @@ -474,6 +494,8 @@ def _raise_value_error(name):

@contextlib.contextmanager
def set_abstract_mesh(mesh: AbstractMesh):
if mesh is not None and mesh.axis_types is None:
raise RuntimeError('Please set the AxisTypes of Mesh.')
prev_val = jax_config.abstract_mesh_context_manager.swap_local(mesh)
try:
yield
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/partition_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from __future__ import annotations

class _UnconstrainedPartitionSingleton:
class UnconstrainedSingleton:

def __repr__(self):
return "UNCONSTRAINED"
Expand All @@ -23,7 +23,7 @@ def __repr__(self):
# Unconstrained sentinel value for PartitionSpec, representing a dimension for
# which the user wants XLA to assign the best partitioning.
# TODO(yashkatariya): May rename to AUTO.
_UNCONSTRAINED_PARTITION = _UnconstrainedPartitionSingleton()
_UNCONSTRAINED_PARTITION = UnconstrainedSingleton()


class PartitionSpec(tuple):
Expand Down
19 changes: 17 additions & 2 deletions jax/_src/sharding_impls.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,19 @@ def _check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes):
f"is also found in manual_axes: {_manual_axes}.") from None


@util.cache(max_size=128, trace_context_in_key=False)
def _check_axis_type_consistency(mesh, parsed_pspec):
if mesh.axis_types is None:
return
for p in parsed_pspec:
if p is not None:
if not all(mesh._name_to_type[p[0]] == mesh._name_to_type[r] for r in p):
raise ValueError(
'AxisTypes should be the same in a tuple subset of PartitionSpec:'
f' {parsed_pspec.get_partition_spec()}. Got subset {p} with axis'
f' types: ({", ".join(str(mesh._name_to_type[r]) for r in p)})')


def hashed_index(x) -> int:
# This works for both `pjit` indices and `pmap` indices (which might
# have an integer instead of a slice).
Expand Down Expand Up @@ -1084,6 +1097,7 @@ def preprocess(mesh, spec, parsed_pspec, _manual_axes=frozenset()):
PartitionSpec() if spec is None else spec,
"NamedSharding spec", allow_unconstrained_dims=True)
_check_mesh_resource_axis(mesh, parsed_pspec, _manual_axes)
_check_axis_type_consistency(mesh, parsed_pspec)
return parsed_pspec


Expand Down Expand Up @@ -1673,7 +1687,8 @@ def _gspmd_to_named_sharding_via_mesh(


def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
*, devices: Sequence[xc.Device] | None = None) -> mesh_lib.Mesh:
*, devices: Sequence[xc.Device] | None = None,
axis_types: mesh_lib.MeshAxisType | None = None) -> mesh_lib.Mesh:
"""Creates an efficient mesh with the shape and axis names specified.
This function attempts to automatically compute a good mapping from a set of
Expand Down Expand Up @@ -1735,4 +1750,4 @@ def make_mesh(axis_shapes: Sequence[int], axis_names: Sequence[str],
mesh_devices = mesh_utils.create_device_mesh(
new_axis_shapes, devices,
allow_split_physical_axes=allow_split_physical_axes)
return mesh_lib.Mesh(mesh_devices, axis_names)
return mesh_lib.Mesh(mesh_devices, axis_names, axis_types=axis_types)
12 changes: 7 additions & 5 deletions jax/_src/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1443,26 +1443,28 @@ def with_and_without_mesh(f):
('Mesh', (('x', 2),), (('i', 'x'),))
))(with_mesh_from_kwargs(f))

def with_user_mesh(sizes, names):
def with_user_mesh(sizes, names, axis_types=None):
axis_types = ({mesh_lib.AxisTypes.User: names}
if axis_types is None else axis_types)
def decorator(fn):
def mesh_fn(*args, **kwargs):
mesh = create_mesh(sizes, names)
mesh = create_mesh(sizes, names, axis_types=axis_types)
with mesh_lib.set_mesh(mesh):
return fn(*args, **kwargs, mesh=mesh)
return mesh_fn
return decorator


def create_mesh(mesh_shape, axis_names, iota_order=False):
def create_mesh(mesh_shape, axis_names, iota_order=False, axis_types=None):
size = math.prod(mesh_shape)
if len(jax.devices()) < size:
raise unittest.SkipTest(f"Test requires {size} global devices.")
if iota_order:
devices = sorted(jax.devices(), key=lambda d: d.id)
mesh_devices = np.array(devices[:size]).reshape(mesh_shape)
return jax.sharding.Mesh(mesh_devices, axis_names)
return jax.sharding.Mesh(mesh_devices, axis_names, axis_types=axis_types)
else:
return jax.make_mesh(mesh_shape, axis_names)
return jax.make_mesh(mesh_shape, axis_names, axis_types=axis_types)

class _cached_property:
null = object()
Expand Down
9 changes: 4 additions & 5 deletions jax/experimental/shard_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
from jax._src import traceback_util
from jax._src import util
from jax._src.core import Tracer
from jax._src.mesh import AbstractMesh, Mesh, AxisTypes, set_abstract_mesh
from jax._src.mesh import (AbstractMesh, Mesh, AxisTypes, set_abstract_mesh,
get_abstract_mesh)
from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, convolution, fft, linalg,
Expand Down Expand Up @@ -536,7 +537,7 @@ def _shard_shaped_array(mesh: Mesh, names: AxisNames, aval: core.AbstractValue
for i, sz in enumerate(aval.shape))
if config.sharding_in_types.value:
new_mesh = AbstractMesh(
mesh.shape_tuple, {AxisTypes.Collective: mesh.axis_names})
mesh.shape_tuple, axis_types={AxisTypes.Collective: mesh.axis_names})
new_sharding = NamedSharding(new_mesh, P(*[None] * aval.ndim))
else:
new_sharding = None
Expand All @@ -548,11 +549,9 @@ def _unshard_shaped_array(mesh: Mesh, names: AxisNames,
assert isinstance(aval, core.ShapedArray)
new_shape = tuple(sz * prod(mesh.shape[n] for n in names.get(i, ()))
for i, sz in enumerate(aval.shape))
# TODO(yashkatariya): Reset the mesh properly based on the input avals if the
# mesh of shard_map specifies collective axes.
if config.sharding_in_types.value:
spec = _names_to_pspec(names)._normalized_spec(aval.ndim)
new_sharding = NamedSharding(AbstractMesh(mesh.shape_tuple), spec)
new_sharding = NamedSharding(get_abstract_mesh(), spec)
else:
new_sharding = None
return aval.update(shape=new_shape, sharding=new_sharding)
Expand Down
Loading

0 comments on commit b5e4fd1

Please sign in to comment.