Skip to content

Commit

Permalink
improve jnp.mean / jnp.sum / ... error message for out-of-bounds axis…
Browse files Browse the repository at this point in the history
… index
  • Loading branch information
mattjj committed Dec 18, 2024
1 parent 74eca13 commit 00b6e60
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 37 deletions.
2 changes: 1 addition & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1557,7 +1557,7 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
[6, 5]]], dtype=int32)
"""
util.check_arraylike("flip", m)
return _flip(asarray(m), reductions._ensure_optional_axes(axis))
return _flip(asarray(m), reductions._ensure_optional_axes(axis, core.get_aval(m)))

@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
Expand Down
91 changes: 55 additions & 36 deletions jax/_src/numpy/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,16 +166,12 @@ def _reduction(a: ArrayLike, name: str, op: ReductionOp, init_val: ArrayLike,
result = lax.expand_dims(result, pos_dims)
return lax.convert_element_type(result, dtype or result_dtype)

def _canonicalize_axis_allow_named(x, rank):
return maybe_named_axis(x, lambda i: _canonicalize_axis(i, rank), lambda name: name)

def _reduction_dims(a: ArrayLike, axis: Axis):
if axis is None:
return (tuple(range(np.ndim(a))),) * 2
elif not isinstance(axis, (np.ndarray, tuple, list)):
axis = (axis,) # type: ignore[assignment]
canon_axis = tuple(_canonicalize_axis_allow_named(x, np.ndim(a))
for x in axis) # type: ignore[union-attr]
canon_axis = tuple(_canonicalize_axis(x, np.ndim(a)) for x in axis)
if len(canon_axis) != len(set(canon_axis)):
raise ValueError(f"duplicate value in 'axis': {axis}")
canon_pos_axis = tuple(x for x in canon_axis if isinstance(x, int))
Expand Down Expand Up @@ -215,16 +211,26 @@ def _require_integer(operand: ArrayLike) -> Array:
raise ValueError(f"integer argument required; got dtype={arr.dtype}")
return arr

def _ensure_optional_axes(x: Axis) -> Axis:
def _ensure_optional_axes(x: Axis, aval: core.ShapedArray) -> Axis:
fail = object()
def force(x):
if x is None:
return None
try:
return operator.index(x)
except TypeError:
return tuple(i if isinstance(i, str) else operator.index(i) for i in x)
return core.concrete_or_error(
force, x, "The axis argument must be known statically.")
try: return operator.index(x)
except: pass
try: return tuple(operator.index(i) for i in x)
except: pass
return fail
x = core.concrete_or_error(force, x, "The axis argument must be known statically.")
if x is fail:
raise TypeError(f"'axis' argument must be None, int, or sequence of ints, got {x}")
ndim = len(aval.shape)
if x is not None:
x_ = x if isinstance(x, tuple) else (x,)
if not _all(-ndim <= i < ndim for i in x_):
raise ValueError(f"'axis' argument of {x} is out-of-bounds for array of "
f"rank {ndim} (type {aval.str_short(short_dtypes=True)})")
return x


@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims', 'promote_integers'), inline=True)
Expand Down Expand Up @@ -309,7 +315,8 @@ def sum(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
>>> jnp.sum(x, axis=0, keepdims=True, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
"""
return _reduce_sum(a, axis=_ensure_optional_axes(axis), dtype=dtype, out=out,
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduce_sum(a, axis=axis, dtype=dtype, out=out,
keepdims=keepdims, initial=initial, where=where,
promote_integers=promote_integers)

Expand Down Expand Up @@ -397,7 +404,8 @@ def prod(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
[1],
[1]], dtype=int32)
"""
return _reduce_prod(a, axis=_ensure_optional_axes(axis), dtype=dtype,
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduce_prod(a, axis=axis, dtype=dtype,
out=out, keepdims=keepdims, initial=initial, where=where,
promote_integers=promote_integers)

Expand Down Expand Up @@ -482,7 +490,8 @@ def max(a: ArrayLike, axis: Axis = None, out: None = None,
>>> jnp.max(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
"""
return _reduce_max(a, axis=_ensure_optional_axes(axis), out=out,
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduce_max(a, axis=axis, out=out,
keepdims=keepdims, initial=initial, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
Expand Down Expand Up @@ -564,7 +573,8 @@ def min(a: ArrayLike, axis: Axis = None, out: None = None,
>>> jnp.min(x, axis=0, keepdims=True, initial=0, where=where)
Array([[0, 0, 0, 0]], dtype=int32)
"""
return _reduce_min(a, axis=_ensure_optional_axes(axis), out=out,
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduce_min(a, axis=axis, out=out,
keepdims=keepdims, initial=initial, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
Expand Down Expand Up @@ -621,7 +631,8 @@ def all(a: ArrayLike, axis: Axis = None, out: None = None,
>>> jnp.all(x, axis=0, keepdims=True, where=where)
Array([[ True, True, False, False]], dtype=bool)
"""
return _reduce_all(a, axis=_ensure_optional_axes(axis), out=out,
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduce_all(a, axis=axis, out=out,
keepdims=keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'keepdims'), inline=True)
Expand Down Expand Up @@ -678,63 +689,69 @@ def any(a: ArrayLike, axis: Axis = None, out: None = None,
>>> jnp.any(x, axis=0, keepdims=True, where=where)
Array([[ True, False, True, False]], dtype=bool)
"""
return _reduce_any(a, axis=_ensure_optional_axes(axis), out=out,
keepdims=keepdims, where=where)
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduce_any(a, axis=axis, out=out, keepdims=keepdims, where=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_bitwise_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
arr = lax_internal.asarray(a)
init_val = np.array(-1, dtype=dtype or arr.dtype)
return _reduction(arr, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
a = lax_internal.asarray(a)
init_val = np.array(-1, dtype=dtype or a.dtype)
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduction(a, name="reduce_bitwise_and", op=lax.bitwise_and, init_val=init_val, preproc=_require_integer,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_bitwise_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduction(a, name="reduce_bitwise_or", op=lax.bitwise_or, init_val=0, preproc=_require_integer,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_bitwise_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduction(a, name="reduce_bitwise_xor", op=lax.bitwise_xor, init_val=0, preproc=_require_integer,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_logical_and(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduction(a, name="reduce_logical_and", op=lax.bitwise_and, init_val=True, preproc=_cast_to_bool,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_logical_or(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduction(a, name="reduce_logical_or", op=lax.bitwise_or, init_val=False, preproc=_cast_to_bool,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)


@partial(api.jit, static_argnames=('axis', 'keepdims', 'dtype'), inline=True)
def _reduce_logical_xor(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None) -> Array:
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _reduction(a, name="reduce_logical_xor", op=lax.bitwise_xor, init_val=False, preproc=_cast_to_bool,
axis=_ensure_optional_axes(axis), dtype=dtype, out=out, keepdims=keepdims,
axis=axis, dtype=dtype, out=out, keepdims=keepdims,
initial=initial, where_=where)


Expand Down Expand Up @@ -860,8 +877,8 @@ def mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
[2.5],
[6. ]], dtype=float32)
"""
return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
where=where)
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _mean(a, axis, dtype, out, keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'), inline=True)
def _mean(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
Expand Down Expand Up @@ -961,7 +978,8 @@ def average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
>>> jnp.average(x, weights=weights, axis=1)
Array([5.5, 4.5], dtype=float32)
"""
return _average(a, _ensure_optional_axes(axis), weights, returned, keepdims)
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _average(a, axis, weights, returned, keepdims)

@partial(api.jit, static_argnames=('axis', 'returned', 'keepdims'), inline=True)
def _average(a: ArrayLike, axis: Axis = None, weights: ArrayLike | None = None,
Expand Down Expand Up @@ -1099,8 +1117,8 @@ def var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
raise ValueError("ddof and correction can't be provided simultaneously.")
return _var(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _var(a, axis, dtype, out, correction, keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _var(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
Expand Down Expand Up @@ -1237,8 +1255,8 @@ def std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
correction = ddof
elif not isinstance(ddof, int) or ddof != 0:
raise ValueError("ddof and correction can't be provided simultaneously.")
return _std(a, _ensure_optional_axes(axis), dtype, out, correction, keepdims,
where=where)
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _std(a, axis, dtype, out, correction, keepdims, where=where)

@partial(api.jit, static_argnames=('axis', 'dtype', 'keepdims'))
def _std(a: ArrayLike, axis: Axis = None, dtype: DTypeLike | None = None,
Expand Down Expand Up @@ -1293,7 +1311,8 @@ def ptp(a: ArrayLike, axis: Axis = None, out: None = None,
[7],
[6]], dtype=int32)
"""
return _ptp(a, _ensure_optional_axes(axis), out, keepdims)
axis = _ensure_optional_axes(axis, core.get_aval(a))
return _ptp(a, axis, out, keepdims)

@partial(api.jit, static_argnames=('axis', 'keepdims'))
def _ptp(a: ArrayLike, axis: Axis = None, out: None = None,
Expand Down
7 changes: 7 additions & 0 deletions tests/lax_numpy_reducers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -944,5 +944,12 @@ def np_op(x, axis=None, dtype=None, include_initial=False):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)

def test_oob_axis_error_message(self):
with self.assertRaisesRegex(ValueError, r"'axis' argument of 0 is out-of-bounds"):
jnp.mean(0, axis=0)
with self.assertRaisesRegex(ValueError, r"'axis' argument of \(0, 1\) is out-of-bounds"):
jnp.sum(jnp.arange(3), axis=(0, 1))


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 00b6e60

Please sign in to comment.