improve checkpoint / remat concreteness error with static_argnums #24516

Merged

mattjj (Collaborator) commented Oct 24, 2024

Follow-up to #24511.

from functools import partial
import jax
import jax.numpy as jnp
@partial(jax.remat, static_argnums=(1,))
def f(x, _):
  if x > 0:
    return x
  else:
    return jnp.sin(x)

f(3., 1.)

Before:

Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/froystig.py", line 11, in <module>
    f(3., 1.)
  File "/usr/local/google/home/mattjj/packages/jax/froystig.py", line 6, in f
    if x > 0:
       ^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function new_fun at /usr/local/google/home/mattjj/packages/jax/jax/_src/ad_checkpoint.py:394 for checkpoint. This concrete value was not available in Python because it depends on the value of the argument dyn_args[0].
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Consider using the `static_argnums` parameter for `jax.remat` or `jax.checkpoint`. See the `jax.checkpoint` docstring and its example involving `static_argnums`:
https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html

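Notice how it says "This concrete value was not available in Python because it depends on the value of the argument dyn_args[0]." What!? Both `new_fun` and `dyn_args[0]` are names from the internal wrapper in ad_checkpoint.py, not anything the user wrote.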

After:

Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/froystig.py", line 11, in <module>
    f(3., 1.)
  File "/usr/local/google/home/mattjj/packages/jax/froystig.py", line 6, in f
    if x > 0:
       ^^^^^
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function f at /usr/local/google/home/mattjj/packages/jax/froystig.py:4 for checkpoint / remat. This concrete value was not available in Python because it depends on the value of the argument x.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

Consider using the `static_argnums` parameter for `jax.remat` or `jax.checkpoint`. See the `jax.checkpoint` docstring and its example involving `static_argnums`:
https://jax.readthedocs.io/en/latest/_autosummary/jax.checkpoint.html

Notice how it now says "This concrete value was not available in Python because it depends on the value of the argument x." and points at the user's function f at froystig.py:4.
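For readers who land on this error, here is a rough sketch of two ways the failing pattern is commonly avoided. The function names `maybe_sin` and `f_where` are made up for illustration and are not part of this PR. The first follows the `static_argnums` hint in the message, which only helps when the branch depends on an argument that can be a hashable Python value. The second keeps `x` traced and selects between branches with `jnp.where`.

```python
from functools import partial
import jax
import jax.numpy as jnp

# Option 1: branch on an argument that is marked static.
# use_sin is a plain Python bool here, so `if` works during tracing.
@partial(jax.checkpoint, static_argnums=(1,))
def maybe_sin(x, use_sin):
    if use_sin:
        return jnp.sin(x)
    return x

# Option 2: keep x dynamic and choose elementwise between both branches.
# Both branches are evaluated; jnp.where picks the result.
@jax.checkpoint
def f_where(x):
    return jnp.where(x > 0, x, jnp.sin(x))

y_static = maybe_sin(3.0, False)
y_where = f_where(-1.0)
dy_where = jax.grad(f_where)(3.0)
```

Option 2 is the closer match to the repro above, since there the branch condition depends on the differentiable input `x` itself.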

Also, net reduction in lines! Some cleanup that my past self promised long ago...

@mattjj mattjj requested a review from froystig October 24, 2024 21:53
@mattjj mattjj force-pushed the improve-concreteness-error-in-remat-3 branch 6 times, most recently from 64c28e4 to df7b04d Compare October 25, 2024 03:31
@mattjj mattjj added the pull ready Ready for copybara import and testing label Oct 25, 2024
@mattjj mattjj added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels Dec 17, 2024
@mattjj mattjj force-pushed the improve-concreteness-error-in-remat-3 branch from df7b04d to 7a942eb Compare December 18, 2024 04:18
@mattjj mattjj force-pushed the improve-concreteness-error-in-remat-3 branch from 7a942eb to 9acd4a9 Compare December 18, 2024 04:25
@mattjj mattjj added the better_errors Improve the error reporting label Dec 18, 2024
@mattjj mattjj self-assigned this Dec 18, 2024
@copybara-service copybara-service bot merged commit 3262770 into jax-ml:main Dec 18, 2024
20 checks passed