Skip to content

Commit

Permalink
ref errors
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Dec 17, 2024
1 parent 4aebe55 commit 45f8d97
Show file tree
Hide file tree
Showing 5 changed files with 163 additions and 75 deletions.
6 changes: 6 additions & 0 deletions jax/_src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,6 +1477,12 @@ def _update_disable_jit_thread_local(val):
upgrade=True,
help='Disable the check from #19009 to enable some custom_vjp hacks.')

# Opt-in flag (off by default) that turns on the error checks for mutable
# arrays which rule out aliasing.
mutable_array_checks = bool_state(
    name='jax_mutable_array_checks',
    default=False,
    upgrade=True,
    help='Enable error checks for mutable arrays that rule out aliasing.')

xla_runtime_errors = bool_state(
name='jax_experimental_unsafe_xla_runtime_errors',
default=False,
Expand Down
2 changes: 2 additions & 0 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1946,6 +1946,8 @@ def mutable_array_abstract_eval(init_aval):
def _mutable_array_impl(init_val):
  """Wrap `init_val` in a `MutableArray`, copying it first when possible.

  The abstract value is computed from the value that was passed in. The value
  stored in the result is `init_val.copy()` when `init_val` has a `copy`
  attribute, and `init_val` itself otherwise.
  """
  from jax._src.state.types import AbstractRef  # pytype: disable=import-error
  aval = get_aval(init_val)
  # TODO(mattjj): improve spelling of 'defensive copy' here, avoid circular dep
  if hasattr(init_val, 'copy'):
    stored = init_val.copy()
  else:
    stored = init_val
  return MutableArray(AbstractRef(aval), stored)

def freeze(ref):
Expand Down
44 changes: 25 additions & 19 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -987,21 +987,11 @@ def partial_eval_jaxpr_custom(
ensure_out_inst: bool | Sequence[bool],
saveable: Callable[..., RematCases_],
) -> tuple[Jaxpr, Jaxpr, list[bool], list[bool], int]:
if type(in_inst) is bool:
in_inst = (in_inst,) * len(jaxpr.invars)
if type(ensure_out_unknowns) is bool:
ensure_out_unknowns = (ensure_out_unknowns,) * len(jaxpr.outvars)
if type(ensure_out_inst) is bool:
ensure_out_inst = (ensure_out_inst,) * len(jaxpr.outvars)
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
_partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
tuple(in_inst),
tuple(ensure_out_unknowns),
tuple(ensure_out_inst), saveable)
if num_res_ref > 0:
raise ValueError(
"Cannot use `partial_eval_jaxpr_custom` with stateful jaxprs.")
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res
*outs, num_res_ref = partial_eval_jaxpr_stateful(
jaxpr, in_unknowns, in_inst, ensure_out_unknowns, ensure_out_inst, saveable)
if num_res_ref:
raise ValueError("Cannot use `partial_eval_jaxpr_custom` with stateful jaxprs.")
return *outs, # type: ignore

def partial_eval_jaxpr_stateful(
jaxpr: Jaxpr,
Expand All @@ -1020,10 +1010,9 @@ def partial_eval_jaxpr_stateful(
if saveable is None:
saveable = everything_saveable
jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref = \
_partial_eval_jaxpr_custom_cached(jaxpr, tuple(in_unknowns),
tuple(in_inst),
tuple(ensure_out_unknowns),
tuple(ensure_out_inst), saveable)
_partial_eval_jaxpr_custom_cached(
jaxpr, tuple(in_unknowns), tuple(in_inst), tuple(ensure_out_unknowns),
tuple(ensure_out_inst), saveable)
return jaxpr_known, jaxpr_staged, out_unknowns, out_inst, num_res, num_res_ref

everything_saveable = lambda *_, **__: True
Expand Down Expand Up @@ -2189,12 +2178,29 @@ def trace_to_jaxpr_dynamic(
ans = fun.call_wrapped(*in_tracers)

out_tracers = map(trace.to_jaxpr_tracer, ans)
_check_no_refs(debug_info, [x.aval for x in out_tracers])
jaxpr, consts, attrs_tracked = trace.to_jaxpr(out_tracers)
del trace, fun, in_tracers, out_tracers, ans

config.enable_checks.value and core.check_jaxpr(jaxpr)
return jaxpr, [v.aval for v in jaxpr.outvars], consts, attrs_tracked

def _check_no_refs(
    dbg: DebugInfo | None,
    avals: Sequence[AbstractValue]
) -> None:
  """Raise if any of `avals` is a mutable array reference.

  This is a no-op unless `config.mutable_array_checks` is enabled.

  Args:
    dbg: optional debug info; when given, its `func_src_info` and `traced_for`
      fields are included in the error message.
    avals: abstract values to inspect.

  Raises:
    ValueError: for the first `AbstractRef` found in `avals`.
  """
  if not config.mutable_array_checks.value:
    return
  # Only the first offending aval is reported.
  ref_aval = next((a for a in avals if isinstance(a, AbstractRef)), None)
  if ref_aval is None:
    return
  if dbg is None:
    raise ValueError(
        f"function returned a mutable array reference of type {ref_aval.str_short()}, "
        "but mutable array references cannot be returned.")
  raise ValueError(
      f"{dbg.func_src_info} traced for {dbg.traced_for} returned a mutable "
      f"array reference of type {ref_aval.str_short()}, but mutable array "
      "references cannot be returned.")

@profiler.annotate_function
def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: DebugInfo | None = None
Expand Down
111 changes: 61 additions & 50 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,17 +556,14 @@ def _infer_params_impl(
"pjit does not support kwargs when in_shardings is specified.")

if pjit_mesh is not None:
jit_name = 'pjit'
if (ji.backend or ji.device) and not pjit_mesh.empty:
raise ValueError(
"Mesh context manager should not be used with jit when backend or "
"device is also specified as an argument to jit.")
else:
jit_name = 'jit'

axes_specs = _flat_axes_specs(ji.abstracted_axes, *args, **kwargs)

dbg = debug_info(jit_name, ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
ji.static_argnums, ji.static_argnames)
f = lu.wrap_init(fun)
f, res_paths = result_paths(f)
Expand All @@ -593,6 +590,7 @@ def _infer_params_impl(
in_shardings_leaves = out_shardings_leaves = tuple(leaves)
in_shardings_treedef = out_shardings_treedef = treedef
else:
jit_name = 'pjit' if pjit_mesh is not None else 'jit'
in_shardings_leaves = tuple(
_create_sharding_for_array(pjit_mesh, x, 'in_shardings', jit_name)
for x in ji.in_shardings_leaves)
Expand All @@ -607,35 +605,13 @@ def _infer_params_impl(

in_type: core.InputType | tuple[core.AbstractValue, ...]
if config.dynamic_shapes.value:
assert in_avals is None
in_type = pe.infer_lambda_input_type(axes_specs, explicit_args)
in_avals = tuple(a for a, e in in_type if e)
elif in_avals is None:
avals = []
for i, a in enumerate(explicit_args):
try:
avals.append(shaped_abstractify(a))
except OverflowError as e:
arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg
else f"flattened argument number is {i}")
raise OverflowError(
"An overflow was encountered while parsing an argument to a jitted "
f"computation, whose {arg_path}."
) from e
except TypeError as e:
arg_description = (f"path {dbg.arg_names[i]}" if dbg
else f"flattened argument number {i}")
raise TypeError(
f"Error interpreting argument to {fun} as an abstract array."
f" The problematic value is of type {type(a)} and was passed to"
f" the function at {arg_description}.\n"
"This typically means that a jit-wrapped function was called with a non-array"
" argument, and this argument was not marked as static using the"
" static_argnums or static_argnames parameters of jax.jit."
) from e

in_type = in_avals = tuple(avals)
else:
assert isinstance(in_avals, tuple[core.AbstractValue, ...])
in_type = in_avals
assert in_avals is not None

in_shardings_flat, in_layouts_flat = _process_in_axis_resources(
in_shardings_treedef, in_shardings_leaves,
Expand All @@ -652,6 +628,7 @@ def _infer_params_impl(
flat_fun, in_type, attr_token, dbg,
HashableFunction(res_paths, closure=()),
IgnoreKey(ji.inline))
_check_no_aliased_refs(dbg, (*jaxpr.consts, *consts), explicit_args)
_attr_update(flat_fun, in_type, attr_token, attrs_tracked)

out_shardings_flat, out_layouts_flat = _check_and_canonicalize_out_shardings(
Expand Down Expand Up @@ -693,7 +670,6 @@ def _infer_params_impl(
donated_invars, dbg.arg_names if dbg else None, len(consts),
attrs_tracked, abstract_mesh), args_flat


def get_abstract_mesh_from_avals(in_avals):
if not config.sharding_in_types.value:
return None
Expand All @@ -711,9 +687,7 @@ def get_abstract_mesh_from_avals(in_avals):
class InferParamsCacheEntry:
  """Mutable value object for _infer_params_cached."""
  # A single slot: `pjit_params` is the only attribute an entry can carry.
  __slots__ = ['pjit_params']

  # None until a caller stores the inferred parameters on the entry.
  pjit_params: PjitParams | None

  def __init__(self):
    # A fresh entry starts out empty.
    self.pjit_params = None

Expand Down Expand Up @@ -747,34 +721,71 @@ def _infer_params(
resource_env = None
pjit_mesh = None

skip_cache = config.dynamic_shapes.value
if not skip_cache:
signature, dynargs = jax_jit.parse_arguments(
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
ji.static_argnames, tree_util.default_registry)
try:
avals = tuple(shaped_abstractify(a) for a in dynargs)
except (OverflowError, TypeError):
# If we see something we don't understand, use the slow path.
skip_cache = True

if skip_cache:
if config.dynamic_shapes.value: # if dynamic shapes, don't use the cache
p, args_flat = _infer_params_impl(fun, ji, pjit_mesh, resource_env, args,
kwargs, in_avals=None)
return p, p.consts + args_flat

entry = _infer_params_cached(
fun, ji, signature, avals, pjit_mesh, resource_env)
signature, dynargs = jax_jit.parse_arguments(
args, tuple(kwargs.values()), tuple(kwargs.keys()), ji.static_argnums,
ji.static_argnames, tree_util.default_registry)
dbg = debug_info('jit', ji.fun_sourceinfo, ji.fun_signature, args, kwargs,
ji.static_argnums, ji.static_argnames)
avals = _infer_input_type(fun, dbg, dynargs)
entry = _infer_params_cached(fun, ji, signature, avals, pjit_mesh, resource_env)
if entry.pjit_params is None:
p, args_flat = _infer_params_impl(
fun, ji, pjit_mesh, resource_env, args, kwargs, in_avals=avals)
if p.attrs_tracked:
# If there are attrs_tracked, don't use the cache.
if p.attrs_tracked: # if attrs, don't populate the cache
return p, p.consts + args_flat
else:
entry.pjit_params = p
entry.pjit_params = p
return entry.pjit_params, entry.pjit_params.consts + dynargs

def _infer_input_type(fun, dbg, explicit_args) -> tuple[core.AbstractValue, ...]:
  """Abstractify flat arguments and reject duplicated mutable array refs.

  Args:
    fun: the function being called; only used in the `TypeError` message.
    dbg: optional debug info. When present, `dbg.arg_names`,
      `dbg.func_src_info` and `dbg.traced_for` are used to describe the
      offending arguments; otherwise flat argument indices are reported.
    explicit_args: flat sequence of argument values.

  Returns:
    A tuple holding `shaped_abstractify(x)` for every `x` in `explicit_args`.

  Raises:
    OverflowError: if abstractifying an argument overflows.
    TypeError: if an argument cannot be interpreted as an abstract array.
    ValueError: if the same mutable array is passed as more than one argument.
  """
  avals = []
  # Maps id() of each ref's referent to the flat index where it first appeared.
  refs: dict[int, int] = {}
  for i, x in enumerate(explicit_args):
    try:
      avals.append(a := shaped_abstractify(x))
    except OverflowError as e:
      arg_path = (f"argument path is {dbg.arg_names[i]}" if dbg
                  else f"flattened argument number is {i}")
      raise OverflowError(
          "An overflow was encountered while parsing an argument to a jitted "
          f"computation, whose {arg_path}."
      ) from e
    except TypeError as e:
      arg_description = (f"path {dbg.arg_names[i]}" if dbg
                         else f"flattened argument number {i}")
      raise TypeError(
          f"Error interpreting argument to {fun} as an abstract array."
          f" The problematic value is of type {type(x)} and was passed to"
          f" the function at {arg_description}.\n"
          "This typically means that a jit-wrapped function was called with a non-array"
          " argument, and this argument was not marked as static using the"
          " static_argnums or static_argnames parameters of jax.jit."
      ) from e
    if (isinstance(a, AbstractRef) and
        (dup_idx := refs.setdefault(id(core.get_referent(x)), i)) != i):
      # Build the dbg-dependent pieces separately. A trailing
      # `... if dbg else ...` after implicitly concatenated literals applies to
      # the whole message, which would drop the explanation when dbg is None.
      if dbg:
        context = f"when tracing {dbg.func_src_info} for {dbg.traced_for} "
        where = f"{dbg.arg_names[dup_idx]} and {dbg.arg_names[i]}"
      else:
        context = ""
        where = f"flat index {dup_idx} and flat index {i}"
      raise ValueError(
          "only one reference to a mutable array may be passed as an argument "
          f"to a function, but {context}"
          f"the mutable array reference of type {a.str_short()} appeared at both "
          f"{where}.")
  return tuple(avals)

def _check_no_aliased_refs(dbg, consts, args) -> None:
  """Reject mutable arrays that are both closed over and passed as arguments.

  Args:
    dbg: optional debug info. When present, `dbg.func_src_info`,
      `dbg.traced_for` and `dbg.arg_names` are used in the error message;
      otherwise the flat argument index is reported.
    consts: values closed over by the traced function.
    args: flat sequence of explicit argument values.

  Raises:
    ValueError: if an argument is a mutable array reference whose referent is
      also among `consts`.
  """
  closed_over: set[int] = {id(core.get_referent(c)) for c in consts}
  for i, x in enumerate(args):
    if id(core.get_referent(x)) not in closed_over:
      continue
    a = shaped_abstractify(x)
    if not isinstance(a, AbstractRef):
      # The error below is specifically about mutable array references; a value
      # of any other type that is both closed over and passed as an argument is
      # not reported.
      continue
    # Build the dbg-dependent pieces separately. A trailing
    # `... if dbg else ...` after implicitly concatenated literals applies to
    # the whole message, not just its last piece.
    if dbg:
      context = f"when tracing {dbg.func_src_info} for {dbg.traced_for}, "
      where = f"{dbg.arg_names[i]}"
    else:
      context = ""
      where = f"at flat index {i}"
    raise ValueError(
        f"{context}a mutable "
        f"array reference of type {a.str_short()} was both closed over and "
        f"passed as the argument "
        f"{where}")

def _extract_implicit_args(
in_type: Sequence[tuple[core.AbstractValue, bool]],
Expand Down
75 changes: 69 additions & 6 deletions tests/mutable_array_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,10 @@ def f(x_mut):
jaxpr = jax.make_jaxpr(f)(x_mut)
self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))

# disabling this test for now. TODO(dougalm): re-enable once we add checks to
# ensure mutable arrays aren't returned or duplicated etc.
# def test_staging_error(self):
# x = jnp.zeros(3)
# with self.assertRaises(Exception):
# jax.jit(core.mutable_array)(x)
def test_staging_error(self):
  # Applying `core.mutable_array` directly under `jax.jit` is expected to raise.
  zeros = jnp.zeros(3)
  with self.assertRaises(Exception):
    jax.jit(core.mutable_array)(zeros)

@parameterized.parameters([True, False])
def test_multiple_inputs_and_outputs(self, jit):
Expand Down Expand Up @@ -244,6 +242,71 @@ def f(x):
expected = 2.0
self.assertAllClose(ans, expected, check_dtypes=False)

def test_defensive_copy(self):
  # The array used to initialize a mutable array must remain usable after the
  # resulting ref has been passed to a jitted function.
  x = jnp.arange(3.)
  read_ref = jax.jit(lambda ref: ref[...])
  _ = read_ref(core.mutable_array(x))
  x + 1  # don't crash


@jtu.with_config(jax_mutable_array_checks=True)
class MutableArrayErrorsTest(jtu.JaxTestCase):
  # Error-path tests. Every test in this class runs with the
  # `jax_mutable_array_checks` flag enabled via the decorator above.

  def test_return_from_jit(self):
    # TODO improve error message to say what output was problematic, and maybe
    # where the offending mutable array was created
    with self.assertRaisesRegex(
        ValueError, "traced for jit returned a mutable array reference of type"):
      jax.jit(core.mutable_array)(jnp.arange(3))

  def test_argument_aliases_jit(self):
    # The same ref passed twice; the lambda's parameter names are part of the
    # expected error text.
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, "appeared at both x_ref and y_ref"):
      jax.jit(lambda x_ref, y_ref: x_ref[...] + y_ref[...])(x_ref, x_ref)

  def test_closure_and_argument_aliases_jit(self):
    # The same ref is both closed over by the lambda and passed as `y_ref`.
    x_ref = core.mutable_array(0.)
    with self.assertRaisesRegex(
        ValueError, "closed over and passed as the argument y_ref"):
      jax.jit(lambda y_ref: x_ref[...] + y_ref[...])(x_ref)

  def test_return_from_scan(self):
    # The scan body returns a freshly created ref as its carry.
    with self.assertRaisesRegex(
        ValueError, "traced for scan returned a mutable array reference of type"):
      jax.lax.scan(lambda c, x: (core.mutable_array(c), x), 0, jnp.arange(3))

  # TODO test_argument_aliases_scan
  # TODO test_closure_and_argument_aliases_scan

  def test_return_from_cond(self):
    # Both cond branches return a freshly created ref.
    with self.assertRaisesRegex(
        ValueError, "traced for cond returned a mutable array reference of type"):
      jax.lax.cond(True, lambda: core.mutable_array(1.0), lambda: core.mutable_array(2.0))

  # TODO test_argument_aliases_cond
  # TODO test_closure_and_argument_aliases_cond

  # TODO test_return_from_custom_jvp/vjp
  # TODO test_argument_aliases_custom_jvp/vjp
  # TODO test_closure_and_argument_aliases_custom_jvp/vjp

  # TODO(mattjj): enable when cond works with mutable arrays
  # @parameterized.parameters([False, True])
  # def test_cond_both_branches_close_over_same_mutable_array(self, jit):
  #   # see also test_cond_with_ref_reuse in state_test.py
  #   x_ref = core.mutable_array(0.)
  #   def f(pred):
  #     def true_fun():
  #       x_ref[()] = 1.
  #     def false_fun():
  #       x_ref[()] = 2.
  #     jax.lax.cond(pred, true_fun, false_fun)
  #   if jit:
  #     f = jax.jit(f)
  #   out_true = f(True)
  #   self.assertAllClose(x_ref[...], 1.)
  #   out_false = f(False)
  #   self.assertAllClose(x_ref[...], 2.)


# Script entry point: run this module's tests through absltest using the JAX
# test loader.
if __name__ == '__main__':
  absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 45f8d97

Please sign in to comment.