improve jnp.mean / jnp.sum / ... error message for out-of-bounds axis index #25155

Open · wants to merge 1 commit into base: main

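In short: each public reduction wrapper now passes the input's abstract value into _ensure_optional_axes, which validates the axis statically and raises a descriptive ValueError for out-of-bounds axes (or a TypeError for non-integer axes) instead of failing later with a less helpful message.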
2 changes: 1 addition & 1 deletion jax/_src/numpy/lax_numpy.py
@@ -1557,7 +1557,7 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
         [6, 5]]], dtype=int32)
  """
  util.check_arraylike("flip", m)
-  return _flip(asarray(m), reductions._ensure_optional_axes(axis))
+  return _flip(asarray(m), reductions._ensure_optional_axes(axis, core.get_aval(m)))

@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
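Since jnp.flip routes its axis argument through the same helper, it inherits the improved message as well. An illustrative session (error text per the new _ensure_optional_axes; a sketch, exact formatting may differ):

import jax.numpy as jnp

jnp.flip(jnp.arange(4), axis=1)
# ValueError: 'axis' argument of 1 is out-of-bounds for array of rank 1 (type i32[4])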
92 changes: 56 additions & 36 deletions jax/_src/numpy/reductions.py
@@ -166,16 +166,13 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
    result = lax.expand_dims(result, pos_dims)
  return lax.convert_element_type(result, dtype or result_dtype)

-def _canonicalize_axis_allow_named(x, rank):
-  return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)
-
def _reduction_dims(a: ArrayLike, axis: Axis):
  if axis is None:
    return (tuple(range(np.ndim(a))),) * 2
  elif not isinstance(axis, (np.ndarray, tuple, list)):
    axis = (axis,)  # type: ignore[assignment]
-  canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
-                     for x in axis)  # type: ignore[union-attr]
+  assert not isinstance(axis, int)
+  canon_axis = tuple(_canonicalize_axis(x, np.ndim(a)) for x in axis)
  if len(canon_axis) != len(set(canon_axis)):
    raise ValueError(f"duplicate value in 'axis': {axis}")
  canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
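For reference, _canonicalize_axis resolves negative axes and rejects out-of-range ones; a minimal standalone sketch of that behavior (not the actual jax._src.util implementation):

def canonicalize_axis(axis: int, ndim: int) -> int:
  # Accept axes in [-ndim, ndim); map negatives to their positive form,
  # e.g. canonicalize_axis(-1, 3) == 2.
  if not -ndim <= axis < ndim:
    raise ValueError(f"axis {axis} is out of bounds for array of dimension {ndim}")
  return axis % ndim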
@@ -215,16 +212,26 @@ def _require_integer(operand: ArrayLike) -> Array:
raise ValueError(f"integer argument required; got dtype={arr.dtype}")
return arr

def _ensure_optional_axes(x: Axis) -> Axis:
def _ensure_optional_axes(x: Axis, aval: core.ShapedArray) -> Axis:
fail = object()
def force(x):
if x is None:
return None
try:
return operator.index(x)
except TypeError:
return tuple(i if isinstance(i, str) else operator.index(i) for i in x)
return core.concrete_or_error(
force, x, "The axis argument must be known statically.")
try: return operator.index(x)
except: pass
try: return tuple(operator.index(i) for i in x)
except: pass
return fail
x = core.concrete_or_error(force, x, "The axis argument must be known statically.")
if x is fail:
raise TypeError(f"'axis' argument must be None, int, or sequence of ints, got {x}")
ndim = len(aval.shape)
if x is not None:
x_ = x if isinstance(x, tuple) else (x,)
if not _all(-ndim <= i < ndim for i in x_):
raise ValueError(f"'axis' argument of {x} is out-of-bounds for array of "
f"rank {ndim} (type {aval.str_short(short_dtypes=True)})")
return x
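With this in place, an invalid axis fails eagerly with a pointed message rather than an opaque internal error. An illustrative session (wording per the code above; a sketch, not verbatim output):

import jax.numpy as jnp

jnp.sum(jnp.ones((2, 3)), axis=2)
# ValueError: 'axis' argument of 2 is out-of-bounds for array of rank 2 (type f32[2,3])

jnp.mean(jnp.ones((2, 3)), axis='rows')
# TypeError: 'axis' argument must be None, int, or sequence of ints, got rows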


@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
@@ -309,7 +316,8 @@ def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
>>> jnp.sum(x, axis=0, keepdims=True, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
"""
return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduce_sum(a, axis=axis, dtype=dtype, out=out,
keepdims=keepdims, initial=initial, where=where,
promote_integers=promote_integers)

@@ -397,7 +405,8 @@ def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
         [1],
         [1]], dtype=int32)
  """
-  return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype,
+  axis = _ensure_optional_axes(axis, core.get_aval(a))
+  return _reduce_prod(a, axis=axis, dtype=dtype,
                      out=out, keepdims=keepdims, initial=initial, where=where,
                      promote_integers=promote_integers)

@@ -482,7 +491,8 @@ def max(a: ArrayLike, axis: Axis = None, out: None = None,
>>> jnp.max(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
"""
return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduce_max(a, axis=axis, out=out,
keepdims=keepdims, initial=initial, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
@@ -564,7 +574,8 @@ def min(a: ArrayLike, axis: Axis = None, out: None = None,
>>> jnp.min(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
"""
return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduce_min(a, axis=axis, out=out,
keepdims=keepdims, initial=initial, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
@@ -621,7 +632,8 @@ def all(a: ArrayLike, axis: Axis = None, out: None = None,
>>> jnp.all(x, axis=0, keepdims=True, where=where)
Array([[ True, True, False, False]], dtype=bool)
"""
return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduce_all(a, axis=axis, out=out,
keepdims=keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
@@ -678,63 +690,69 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None,
>>> jnp.any(x, axis=0, keepdims=True, where=where)
Array([[ True, False, True, False]], dtype=bool)
"""
return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, where=where)
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduce_any(a, axis=axis, out=out, keepdims=keepdims, where=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                        out: None = None, keepdims: bool = False,
                        initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
-  arr = lax_internal.asarray(a)
-  init_val = np.array(-1, dtype=dtype or arr.dtype)
-  return _reduction(arr, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer,
-                    axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
+  a = lax_internal.asarray(a)
+  init_val = np.array(-1, dtype=dtype or a.dtype)
+  axis = _ensure_optional_axes(axis, core.get_aval(a))
+  return _reduction(a, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer,
+                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                       out: None = None, keepdims: bool = False,
                       initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
+  axis = _ensure_optional_axes(axis, core.get_aval(a))
  return _reduction(a, name="reduce_bitwise_or", op=lax.bitwise_or, init_val=0, preproc=_require_integer,
-                    axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
+                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                        out: None = None, keepdims: bool = False,
                        initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
+  axis = _ensure_optional_axes(axis, core.get_aval(a))
  return _reduction(a, name="reduce_bitwise_xor", op=lax.bitwise_xor, init_val=0, preproc=_require_integer,
-                    axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
+                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                        out: None = None, keepdims: bool = False,
                        initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
+  axis = _ensure_optional_axes(axis, core.get_aval(a))
  return _reduction(a, name="reduce_logical_and", op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool,
-                    axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
+                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                       out: None = None, keepdims: bool = False,
                       initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
+  axis = _ensure_optional_axes(axis, core.get_aval(a))
  return _reduction(a, name="reduce_logical_or", op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool,
-                    axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
+                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
                        out: None = None, keepdims: bool = False,
                        initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
+  axis = _ensure_optional_axes(axis, core.get_aval(a))
  return _reduction(a, name="reduce_logical_xor", op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool,
-                    axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
+                    axis=axis, dtype=dtype, out=out, keepdims=keepdims,
                    initial=initial, where_=where)


@@ -860,8 +878,8 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
         [2.5],
         [6. ]], dtype=float32)
  """
-  return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
-               where=where)
+  axis = _ensure_optional_axes(axis, core.get_aval(a))
+  return _mean(a, axis, dtype, out, keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
@@ -961,7 +979,8 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)
"""
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _average(a, axis, weights, returned, keepdims)

@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
@@ -1099,8 +1118,8 @@ def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
    correction = ddof
  elif not isinstance(ddof, int) or ddof != 0:
    raise ValueError("ddof and correction can't be provided simultaneously.")
-  return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
-              where=where)
+  axis = _ensure_optional_axes(axis, core.get_aval(a))
+  return _var(a, axis, dtype, out, correction, keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
@@ -1237,8 +1256,8 @@ def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
    correction = ddof
  elif not isinstance(ddof, int) or ddof != 0:
    raise ValueError("ddof and correction can't be provided simultaneously.")
-  return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
-              where=where)
+  axis = _ensure_optional_axes(axis, core.get_aval(a))
+  return _std(a, axis, dtype, out, correction, keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
@@ -1293,7 +1312,8 @@ def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
         [7],
         [6]], dtype=int32)
  """
-  return _ptp(a, _ensure_optional_axes(axis), out, keepdims)
+  axis = _ensure_optional_axes(axis, core.get_aval(a))
+  return _ptp(a, axis, out, keepdims)

@partial(api.jit, static_argnames=('axis', 'keepdims'))
def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
7 changes: 7 additions & 0 deletions tests/lax_numpy_reducers_test.py
@@ -944,5 +944,12 @@ def np_op(x, axis=None, dtype=None, include_initial=False):
    self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
    self._CompileAndCheck(jnp_fun, args_maker)

+  def test_oob_axis_error_message(self):
+    with self.assertRaisesRegex(ValueError, r"'axis' argument of 0 is out-of-bounds"):
+      jnp.mean(0, axis=0)
+    with self.assertRaisesRegex(ValueError, r"'axis' argument of \(0, 1\) is out-of-bounds"):
+      jnp.sum(jnp.arange(3), axis=(0, 1))
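The new test can be run in isolation with something like python -m pytest tests/lax_numpy_reducers_test.py -k test_oob_axis_error_message (assuming a standard dev checkout), or by executing the file directly via the absltest entry point below.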


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())