Skip to content

Commit

Permalink
improve checkpoint / remat concreteness error with static_argnums
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Dec 18, 2024
1 parent 09fdd0d commit 9acd4a9
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 77 deletions.
22 changes: 11 additions & 11 deletions jax/_src/ad_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@
from jax._src import effects
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src.api_util import flatten_fun, shaped_abstractify
from jax._src.api_util import (
flatten_fun, shaped_abstractify, debug_info, fun_sourceinfo, fun_signature)
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
Expand All @@ -41,7 +42,7 @@
from jax._src.lax import convolution as lax_convolution
from jax._src.lib.mlir.dialects import hlo
from jax._src.traceback_util import api_boundary
from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure, keystr
from jax._src.tree_util import tree_flatten, tree_unflatten, tree_structure
from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
safe_zip, merge_lists, weakref_lru_cache)

Expand Down Expand Up @@ -320,10 +321,12 @@ def foo(x, y):
@wraps(fun)
@api_boundary
def fun_remat(*args, **kwargs):
debug = debug_info("checkpoint / remat", fun_sourceinfo(fun),
fun_signature(fun), args, kwargs, static_argnums, ())
fun_, args = _remat_static_argnums(fun, static_argnums, args)
args_flat, in_tree = tree_flatten((args, kwargs))
in_avals = [shaped_abstractify(x) for x in args_flat]
jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals))
jaxpr, consts, out_tree = _trace_to_jaxpr(fun_, in_tree, tuple(in_avals), debug)
out_flat = remat_p.bind(
*consts, *args_flat, jaxpr=jaxpr, prevent_cse=prevent_cse,
differentiated=False, policy=policy)
Expand Down Expand Up @@ -409,9 +412,8 @@ def new_fun(*dyn_args, **kwargs):
# This helper is similar to those in control_flow/common.py, but with
# remat-specific errors.
@weakref_lru_cache
def _trace_to_jaxpr(fun, in_tree, in_avals):
def _trace_to_jaxpr(fun, in_tree, in_avals, debug):
flat_fun, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
debug = pe.debug_info(fun, in_tree, out_tree, True, "checkpoint")
try:
jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
except core.ConcretizationTypeError as e:
Expand Down Expand Up @@ -445,10 +447,9 @@ def f_(*args):
out_tree = lambda: tree_structure(out_shape)
assert len(jaxpr.invars) == len(in_leaves)
dbg = pe.debug_info(f, in_tree, out_tree, True, "saved_residuals")
arg_info = pe.arg_info_all(dbg)
return _saved_residuals(jaxpr, arg_info)
return _saved_residuals(jaxpr, dbg.arg_names) # type: ignore

def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:
def _saved_residuals(jaxpr, arg_names) -> list[tuple[core.AbstractValue, str]]:
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}

Expand All @@ -463,9 +464,8 @@ def _saved_residuals(jaxpr, arg_info) -> list[tuple[core.AbstractValue, str]]:

for i, v in enumerate(jaxpr.invars):
if v in res_vars:
if arg_info is not None:
arg_name, arg_path = arg_info[i]
src = f'from the argument {arg_name}{keystr(arg_path)}'
if arg_names is not None:
src = f'from the argument {arg_names[i]}'
else:
src = 'from the argument at flattened index {i}'
results.append((v.aval, src))
Expand Down
96 changes: 36 additions & 60 deletions jax/_src/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from collections.abc import Callable, Sequence, Hashable
from contextlib import contextmanager
from functools import partial
import inspect
import itertools as it
import operator as op
from typing import Any, NamedTuple, Union
Expand Down Expand Up @@ -46,7 +45,7 @@
InputType, OutputType, get_referent, JaxprEqnContext)
from jax._src.state.types import AbstractRef
from jax._src.tree_util import (PyTreeDef, treedef_tuple, tree_unflatten,
tree_flatten, tree_structure, KeyPath, generate_key_paths,
tree_flatten, tree_structure, generate_key_paths,
keystr)
from jax._src.util import (unzip2, safe_zip, safe_map, toposort, split_list,
merge_lists, partition_list, OrderedSet,
Expand Down Expand Up @@ -1529,8 +1528,7 @@ class DynamicJaxprTracer(core.Tracer):
def __init__(self, trace, aval, line_info=None):
self._trace = trace
self._line_info = line_info
# Needed for UnexpectedTracerError.
self._debug_info = self._trace.frame.debug_info
self._debug_info = self._trace.frame.debug_info # for UnexpectedTracerError
self.aval = aval

def full_lower(self):
Expand All @@ -1551,11 +1549,11 @@ def _origin_msg(self):

origin = ("The error occurred while tracing the function "
f"{dbg.func_src_info or '<unknown>'} for {dbg.traced_for}. ")
arg_info = arg_info_all(dbg)
# TODO(mattjj): figure out when not (invar_pos < len(arg_info))
if invar_pos and arg_info and all(i < len(arg_info) for i in invar_pos):
arg_info = [arg_info[i] for i in invar_pos]
arg_names = [f'{name}{keystr(path)}' for name, path in arg_info]
if invar_pos and dbg.arg_names:
try:
arg_names = [dbg.arg_names[i] for i in invar_pos]
except IndexError:
return "" # TODO(mattjj): figure out when not (invar_pos < len(arg_info))
if len(arg_names) == 1:
arg_info_str = f"the argument {arg_names[0]}"
elif len(arg_names) == 2:
Expand Down Expand Up @@ -1632,7 +1630,7 @@ class JaxprStackFrame:
attrs_tracked: list[tuple[Any, str]]
attrs_inits: list
attrs_vars: list[Var]
debug_info: DebugInfo | None
debug_info: lu.TracingDebugInfo | None

def __init__(self):
self.gensym = core.gensym()
Expand Down Expand Up @@ -2116,64 +2114,42 @@ def _jvp_jaxpr_zeros(f, store, in_zeros, zero_avals, *primal_tangent_avals):
store.store(out_zeros)
return [*out_primals, *out_nz_tangents]

# TODO(mattjj): remove this DebugInfo and helper functions, replace with
# api_util.py versions

class DebugInfo(NamedTuple):
func_src_info: str | None # f'{fun.__name__} at {filename}:{lineno}'
signature: inspect.Signature | None # inspect.signature(fun)
in_tree: PyTreeDef | None # caller/constructor might not have this info
out_tree: Callable[[], PyTreeDef] | None # lazy, not avail at trace time
has_kwargs: bool # whether in_tree corresponds to (args, kwargs) or args
traced_for: str # "jit", "scan", "make_jaxpr", etc

def debug_info(fn: Callable, in_tree: PyTreeDef | None,
out_tree_thunk: Callable[[], PyTreeDef] | None,
has_kwargs: bool, traced_for: str) -> DebugInfo:
sig = api_util.fun_signature(fn)
# Callers should be using linear_util.debug_info instead!
def debug_info(
fn: Callable,
in_tree: PyTreeDef | None,
out_tree_thunk: Callable[[], PyTreeDef] | None,
has_kwargs: bool,
traced_for: str
) -> lu.TracingDebugInfo | None:
src_info = fun_sourceinfo(fn)
return DebugInfo(src_info, sig, in_tree, out_tree_thunk, has_kwargs,
traced_for)

def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> DebugInfo:
"Make a DebugInfo from data available to final-style primitives like pmap."
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)

def arg_info_all(dbg: DebugInfo) -> list[tuple[str, KeyPath]] | None:
ba = None if dbg.in_tree is None else sig_info(dbg)
if ba is None: return None
return [(name, key_path) for name, dummy_arg in ba.arguments.items()
for key_path, _ in generate_key_paths(dummy_arg)]

def sig_info(dbg: DebugInfo) -> inspect.BoundArguments | None:
if dbg.in_tree is None or dbg.signature is None: return None
try:
dummy_args = tree_unflatten(dbg.in_tree, [False] * dbg.in_tree.num_leaves)
except:
return None
args, kwargs = dummy_args if dbg.has_kwargs else (dummy_args, {})
try:
return dbg.signature.bind(*args, **kwargs)
except (TypeError, ValueError):
return None

def result_info(dbg: DebugInfo) -> list[KeyPath] | None:
if dbg.out_tree is None: return None
try:
num_leaves = dbg.out_tree().num_leaves
dummy_result = tree_unflatten(dbg.out_tree(), [False] * num_leaves)
dummy_args = tree_unflatten(in_tree, [False] * in_tree.num_leaves) # type: ignore
args, kwargs = dummy_args if has_kwargs else (dummy_args, {})
ba = api_util.fun_signature(fn).bind(*args, **kwargs) # type: ignore
arg_names = tuple(f'{name}{keystr(path)}' for name, dummy in ba.arguments.items()
for path, _ in generate_key_paths(dummy))
except:
return None
else:
return [path for path, _ in generate_key_paths(dummy_result)]
arg_names = None
def result_paths():
try:
out_tree = out_tree_thunk()
dummy_result = tree_unflatten(out_tree, [False] * out_tree.num_leaves)
except:
return None
return tuple(path for path, _ in generate_key_paths(dummy_result))
return lu.TracingDebugInfo(traced_for, src_info, arg_names, result_paths) # type: ignore

def debug_info_final(fn: lu.WrappedFun, traced_for: str) -> lu.TracingDebugInfo | None:
in_tree, out_tree, has_kws = flattened_fun_in_tree(fn) or (None, None, False)
return debug_info(fn.f, in_tree, out_tree, has_kws, traced_for)


@profiler.annotate_function
def trace_to_jaxpr_dynamic(
fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
debug_info: DebugInfo | None = None,
debug_info: lu.TracingDebugInfo | None = None,
*,
keep_inputs: list[bool] | None = None,
) -> tuple[Jaxpr, list[AbstractValue], list[Any],
Expand All @@ -2197,7 +2173,7 @@ def trace_to_jaxpr_dynamic(

@profiler.annotate_function
def trace_to_jaxpr_dynamic2(
fun: lu.WrappedFun, debug_info: DebugInfo | None = None
fun: lu.WrappedFun, debug_info: lu.TracingDebugInfo | None = None
) -> tuple[Jaxpr, OutputType, list[Any]]:

trace = DynamicJaxprTrace()
Expand Down
1 change: 0 additions & 1 deletion jax/_src/linear_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,6 @@ def valid_size(d) -> bool:
class TracingDebugInfo(NamedTuple):
# Packages up trace/staging-time debug info about a func and its parameters,
# formed just before staging to a jaxpr and read in trace-time error messages.
# TODO(mattjj): delete partial_eval.DebugInfo, replace all uses with this cls
traced_for: str # e.g. 'jit', 'scan', etc
func_src_info: str | None # e.g. f'{fun.__name__} at {filename}:{lineno}'
arg_names: tuple[str, ...] # e.g. ('args[0]', ... )
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pallas/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def to_block_mapping(
"pallas_call index_map",
)
index_map_src_info = NameAndSrcInfo.from_pallas_call(
None, debug.func_src_info
None, debug.func_src_info # type: ignore
)
with tracing_grid_env(grid, mapped_dims):
jaxpr, out_avals, consts, () = pe.trace_to_jaxpr_dynamic(
Expand Down
4 changes: 0 additions & 4 deletions jax/interpreters/partial_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
ConstFoldRule as ConstFoldRule,
ConstVar as ConstVar,
DCERule as DCERule,
DebugInfo as DebugInfo,
DynamicJaxprTrace as DynamicJaxprTrace,
DynamicJaxprTracer as DynamicJaxprTracer,
ForwardingRule as ForwardingRule,
Expand All @@ -40,7 +39,6 @@
TracerId as TracerId,
Val as Val,
abstract_eval_fun as abstract_eval_fun,
arg_info_all as arg_info_all,
call_padding_rule as call_padding_rule,
call_param_updaters as call_param_updaters,
call_partial_eval_custom_rule as call_partial_eval_custom_rule,
Expand Down Expand Up @@ -79,8 +77,6 @@
partial_eval_wrapper_nounits as partial_eval_wrapper_nounits,
partition_pvals as partition_pvals,
recipe_to_eqn as recipe_to_eqn,
result_info as result_info,
sig_info as sig_info,
trace_to_jaxpr_dynamic as _trace_to_jaxpr_dynamic,
trace_to_jaxpr_dynamic2 as trace_to_jaxpr_dynamic2,
trace_to_jaxpr_nounits as trace_to_jaxpr_nounits,
Expand Down
15 changes: 15 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6507,6 +6507,21 @@ def f(x):
else:
assert False

def test_concreteness_error_includes_user_code_with_static_argnums(self):
@partial(jax.remat, static_argnums=(1,))
def f(x, _):
if x > 0:
return x
else:
return jnp.sin(x)

try:
f(3., 1.)
except TracerBoolConversionError:
self.assertIn('x > 0', traceback.format_exc())
else:
assert False


@jtu.with_config(jax_pprint_use_color=False)
class JaxprTest(jtu.JaxTestCase):
Expand Down

0 comments on commit 9acd4a9

Please sign in to comment.