Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add error checks for refs, behind a flag #25449

Merged
merged 1 commit into from
Dec 18, 2024

Conversation

mattjj
Copy link
Collaborator

@mattjj mattjj commented Dec 12, 2024

This PR adds some error checks for incorrect ustage of mutable array references. The checks are behind a flag, default off, because there are probably some downstream users that need to be updated, and we want to defer doing that until all the checks land.

This PR only adds checks for jit, namely:

  1. that we don't pass in the same mutable array argument twice;
  2. that we don't pass in a mutable array as an argument that we also close over;
  3. that we don't return any mutable arrays.

We need to do the same for all higher-order primitives (HOPs): scan, cond, while_loop, and custom_jvp/vjp. I'm going to leave those as follow-ups so we can iterate on getting the jit errors right, since those will largely follow the same pattern.

Here are some example error messages:

import jax
import jax.numpy as jnp
from jax._src import core

jax.config.update('jax_mutable_array_checks', True)

@jax.jit
def f(x):
  return core.mutable_array(x)

f(jnp.arange(3.))
Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/dougal70.py", line 11, in <module>
    f(jnp.arange(3.))
ValueError: function f at
/usr/local/google/home/mattjj/packages/jax/dougal70.py:7 traced for jit returned
a mutable array reference of type Ref{float32[3]}, but mutable array references
cannot be returned.

The returned mutable array was created on line 
/usr/local/google/home/mattjj/packages/jax/dougal70.py:9:9 (f).

@jax.jit
def f(x):
  return {'hi': [0, core.mutable_array(x), 5]}

f(jnp.arange(3.))
Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/dougal70.py", line 11, in <module>
    f(jnp.arange(3.))
ValueError: function f at
/usr/local/google/home/mattjj/packages/jax/dougal70.py:7 traced for jit returned
a mutable array reference of type Ref{float32[3]} at output tree path ['hi'][1],
but mutable array references cannot be returned.

The returned mutable array was created on line 
/usr/local/google/home/mattjj/packages/jax/dougal70.py:9:20 (f).

@jax.jit
def f(y_ref):
  return y_ref

x_ref = core.mutable_array(x)
f(x_ref)
Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/dougal70.py", line 12, in <module>
    f(x_ref)
ValueError: function f at
/usr/local/google/home/mattjj/packages/jax/dougal70.py:7 traced for jit returned
a mutable array reference of type Ref{float32[3]}, but mutable array references
cannot be returned.

The returned mutable array was passed in as the argument y_ref.

@jax.jit
def f(x_ref, y_ref):
  ...

x_ref = core.mutable_array(jnp.arange(3.))
f(x_ref, x_ref)
Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/dougal70.py", line 12, in <module>
    f(x_ref, x_ref)
ValueError: only one reference to a mutable array may be passed as an argument
to a function, but when tracing f at
/usr/local/google/home/mattjj/packages/jax/dougal70.py:7 for jit the mutable
array reference of type Ref{float32[3]} appeared at both x_ref and y_ref.

@jax.jit
def f(params):
  ...

x_ref = core.mutable_array(jnp.arange(3.))
f({'a': x_ref, 'b': {'c': x_ref}})
Traceback (most recent call last):
  File "/usr/local/google/home/mattjj/packages/jax/dougal70.py", line 12, in <module>
    f({'a': x_ref, 'b': {'c': x_ref}})
ValueError: only one reference to a mutable array may be passed as an argument
to a function, but when tracing f at
/usr/local/google/home/mattjj/packages/jax/dougal70.py:7 for jit the mutable
array reference of type Ref{float32[3]} appeared at both params['a'] and
params['b']['c'].

@mattjj mattjj force-pushed the ref-errors-3 branch 6 times, most recently from 8576614 to 9e58a57 Compare December 17, 2024 22:13
@mattjj mattjj requested a review from dougalm December 17, 2024 22:17
@mattjj mattjj marked this pull request as ready for review December 17, 2024 22:18
@mattjj mattjj added the pull ready Ready for copybara import and testing label Dec 17, 2024
@mattjj mattjj force-pushed the ref-errors-3 branch 10 times, most recently from 9c6586b to 122eed9 Compare December 18, 2024 06:55
@copybara-service copybara-service bot merged commit 96d4a75 into jax-ml:main Dec 18, 2024
19 checks passed
@mattjj mattjj deleted the ref-errors-3 branch December 18, 2024 16:43
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants