Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve jnp.mean / jnp.sum / ... error message for out-of-bounds axis index #25155

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Nov 28, 2024

jnp.mean(0, axis=0)

Before:

Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/james10.py", line 4, in <module>
    jnp.mean(0, axis=0)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/numpy/reductions.py", line 844, in mean
    return _mean(a, _ensure_optional_axes(axis), dtype, out, keepdims,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/numpy/reductions.py", line 871, in _mean
    normalizer = core.dimension_as_value(_axis_size(a, axis))
                                         ^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/numpy/reductions.py", line 782, in _axis_size
    size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/numpy/reductions.py", line 782, in <lambda>
    size *= maybe_named_axis(a, lambda i: a_shape[i], lambda name: lax.psum(1, name))
                                          ~~~~~~~^^^
IndexError: tuple index out of range

After:

Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/james10.py", line 4, in <module>
    jnp.mean(0, axis=0)
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/numpy/reductions.py", line 861, in mean
    axis = _ensure_optional_axes(axis, core.get_aval(a))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/google/home/mattjj/packages/jax/jax/_src/numpy/reductions.py", line 216, in _ensure_optional_axes
    raise ValueError(f"'axis' argument of {x} is out-of-bounds for array of "
ValueError: 'axis' argument of 0 is out-of-bounds for array of rank 0 (type i32[])

@mattjj mattjj requested a review from froystig November 28, 2024 00:15
@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Nov 28, 2024
@mattjj mattjj added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels Dec 18, 2024
@mattjj mattjj force-pushed the improve-reduction-error-message branch from 9fcf1e5 to 00b6e60 Compare December 18, 2024 21:25
@mattjj mattjj force-pushed the improve-reduction-error-message branch from 00b6e60 to ed84e45 Compare December 18, 2024 21:29
@mattjj mattjj added the better_errors Improve the error reporting label Dec 18, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
better_errors Improve the error reporting pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants