diff --git a/jax/_src/core.py b/jax/_src/core.py index 2ad1f7b0bdf7..3acd5b83c60d 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1669,7 +1669,7 @@ def get_sharding(sharding, ndim): context_mesh = mesh_lib.get_abstract_mesh() if not context_mesh: - return RuntimeError("Please set the mesh via `jax.set_mesh` API.") + raise RuntimeError("Please set the mesh via `jax.set_mesh` API.") assert sharding is None return NamedSharding(context_mesh, P(*[None] * ndim))