
Commit ea63aea
Merge pull request #25442 from jakevdp:raise-to-shaped
PiperOrigin-RevId: 705556199
Google-ML-Automation committed Dec 12, 2024
2 parents 3f58337 + 40367a9 commit ea63aea
Showing 25 changed files with 50 additions and 67 deletions.
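
Every change in this commit follows one pattern: calls of the form core.raise_to_shaped(core.get_aval(x)) are reduced to core.get_aval(x), and the small local wrappers around that pattern (get_shaped_aval, abstractify, _abstractify) are deleted or inlined. Below is a minimal sketch, not part of the diff, of why the outer call could be dropped; it assumes, as the diff implies, that core.get_aval already returns a fully raised ShapedArray for concrete values:

import jax.numpy as jnp
from jax._src import core  # internal API

x = jnp.ones((2, 3), dtype=jnp.float32)
aval = core.get_aval(x)  # already a ShapedArray: float32[2,3]
# Before this commit, core.raise_to_shaped(aval) returned the same aval,
# so every raise_to_shaped(get_aval(...)) call site simplifies to get_aval(...).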
1 change: 0 additions & 1 deletion jax/_src/abstract_arrays.py
@@ -28,7 +28,6 @@
 AbstractToken = core.AbstractToken
 abstract_token = core.abstract_token
 canonicalize_shape = core.canonicalize_shape
-raise_to_shaped = core.raise_to_shaped

 numpy_scalar_types: set[type] = {  # pylint: disable=g-bare-generic
     dtypes.int4, np.int8, np.int16, np.int32, np.int64,
6 changes: 3 additions & 3 deletions jax/_src/ad_util.py
@@ -19,7 +19,7 @@

 from jax._src import core
 from jax._src import traceback_util
-from jax._src.core import Primitive, valid_jaxtype, raise_to_shaped, get_aval
+from jax._src.core import Primitive, valid_jaxtype, get_aval
 from jax._src.tree_util import register_pytree_node, tree_map
 from jax._src.typing import Array, ArrayLike
 from jax._src.util import safe_map

@@ -51,7 +51,7 @@ def zeros_like_aval(aval: core.AbstractValue) -> Array:
 aval_zeros_likers: dict[type, Callable[[Any], Array]] = {}

 def zeros_like_jaxval(val):
-  return zeros_like_aval(core.raise_to_shaped(core.get_aval(val)))
+  return zeros_like_aval(core.get_aval(val))

 def instantiate(z: Zero | Array) -> Array:
   if isinstance(z, Zero):

@@ -67,7 +67,7 @@ def __repr__(self) -> str:
     return f'Zero({self.aval})'
   @staticmethod
   def from_primal_value(val: Any) -> Zero:
-    return Zero(raise_to_shaped(get_aval(val)).to_tangent_aval())
+    return Zero(get_aval(val).to_tangent_aval())

 register_pytree_node(Zero, lambda z: ((), z.aval), lambda aval, _: Zero(aval))
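
As a usage-level illustration of the simplified Zero.from_primal_value above — a hedged sketch using internal APIs, so the exact repr may differ:

import jax.numpy as jnp
from jax._src.ad_util import Zero  # internal API

z = Zero.from_primal_value(jnp.zeros((4,), jnp.float32))
print(z)  # roughly: Zero(ShapedArray(float32[4]))
# The tangent aval now comes straight from get_aval(val).to_tangent_aval(),
# with no intermediate raise_to_shaped step.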
4 changes: 2 additions & 2 deletions jax/_src/api.py
@@ -2356,7 +2356,7 @@ def device_put_sharded(shards: Sequence[Any], devices: Sequence[xc.Device]):  #
                        f"len(devices) = {len(devices)}.")

   def _device_put_sharded(*xs):
-    avals = [core.raise_to_shaped(core.get_aval(x)) for x in xs]
+    avals = [core.get_aval(x) for x in xs]
     if not all(a1 == a2 for a1, a2 in zip(avals[:-1], avals[1:])):
       a1, a2 = next((a1, a2) for a1, a2 in zip(avals[:-1], avals[1:])
                     if a1 != a2)

@@ -2418,7 +2418,7 @@ def device_put_replicated(x: Any, devices: Sequence[xc.Device]):  # noqa: F811
                        "a non-empty sequence.")
   def _device_put_replicated(x):
     aval = core.unmapped_aval(len(devices), core.no_axis_name, 0,
-                              core.raise_to_shaped(core.get_aval(x)))
+                              core.get_aval(x))
     assert isinstance(aval, ShapedArray)
     sharding_spec = sharding_specs.create_pmap_sharding_spec(aval.shape)
     if config.pmap_no_rank_reduction.value:
3 changes: 1 addition & 2 deletions jax/_src/api_util.py
@@ -587,8 +587,7 @@ def _dtype(x):

 def _shaped_abstractify_slow(x):
   try:
-    return core.raise_to_shaped(
-        x if isinstance(x, core.AbstractValue) else core.get_aval(x))
+    return x if isinstance(x, core.AbstractValue) else core.get_aval(x)
   except TypeError:
     pass
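
The new single-line body of _shaped_abstractify_slow passes abstract values through unchanged and abstracts concrete ones. A small sketch of the two cases, with a hypothetical helper mirroring that logic (the name shaped_abstractify_sketch is invented here):

import jax.numpy as jnp
from jax._src import core  # internal API

def shaped_abstractify_sketch(x):  # hypothetical mirror of the new logic
  return x if isinstance(x, core.AbstractValue) else core.get_aval(x)

print(shaped_abstractify_sketch(jnp.ones(3)))  # ShapedArray(float32[3]) under default x32
aval = core.ShapedArray((3,), jnp.float32)
print(shaped_abstractify_sketch(aval) is aval)  # True: avals pass through untouched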
23 changes: 10 additions & 13 deletions jax/_src/checkify.py
@@ -387,9 +387,6 @@ def out_axes_thunk():
   error = _reduce_any_error(error)
   return error, out_vals

-def get_shaped_aval(val):
-  return core.raise_to_shaped(core.get_aval(val))
-
 def checkify_jaxpr(jaxpr: core.ClosedJaxpr, enabled_errors,
                    error: Error, *args) -> tuple[Error, list[core.Value]]:
   err_vals, err_tree = jtu.tree_flatten(error)

@@ -760,7 +757,7 @@ def cond_error_check(error: Error, enabled_errors, index, *ops, branches):
   # Get the error-effects out of all branches so the cond can be called with
   # a merged error with all these effects.
   err_vals, err_tree = jtu.tree_flatten(error)
-  in_avals = map(get_shaped_aval, [*err_vals, *ops])
+  in_avals = map(core.get_aval, [*err_vals, *ops])
   def get_error_effects_from_jaxpr(jxpr):
     _, _, effects = jaxpr_to_checkify_jaxpr(jxpr, enabled_errors, err_tree,
                                             *in_avals)

@@ -770,7 +767,7 @@ def get_error_effects_from_jaxpr(jxpr):
   err_vals, err_tree = jtu.tree_flatten(merged_error)

   # Update branch jaxprs to be checkified jaxprs.
-  in_avals = map(get_shaped_aval, [*err_vals, *ops])
+  in_avals = map(core.get_aval, [*err_vals, *ops])
   new_branches, out_trees, _ = unzip3(
       jaxpr_to_checkify_jaxpr(
           jxpr, enabled_errors, err_tree, *in_avals) for jxpr in branches)

@@ -792,19 +789,19 @@ def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
                      num_consts, num_carry, linear, unroll, _split_transpose):

   consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
-  xs_mapped = [core.mapped_aval(length, 0, get_shaped_aval(val)) for val in xs]
+  xs_mapped = [core.mapped_aval(length, 0, core.get_aval(val)) for val in xs]
   # Query body effects to create a merged error containing all effects (such
   # that in and out carried error are of the same type).
   err_vals, err_tree = jtu.tree_flatten(error)
-  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
+  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
   _, _, effects = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                           err_tree, *new_in_aval)

   merged_error = error._add_placeholder_effects(effects)
   err_vals, err_tree = jtu.tree_flatten(merged_error)

   # Create checked-jaxpr, with the needed pre-processing on the inputs.
-  new_in_aval = map(get_shaped_aval, [*err_vals, *consts, *carry]) + xs_mapped
+  new_in_aval = map(core.get_aval, [*err_vals, *consts, *carry]) + xs_mapped
   checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                         err_tree, *new_in_aval)

@@ -840,7 +837,7 @@ def new_body_f(*c_consts_and_vals):
                                  *body_jaxpr.in_avals])
   closed_jaxpr = pe.close_jaxpr(jaxpr)
   err_vals, err_tree = jtu.tree_flatten(error)
-  err_vals = map(get_shaped_aval, err_vals)
+  err_vals = map(core.get_aval, err_vals)
   flat_err_and_in_vals = [*err_vals, *c_consts_avals, *body_jaxpr.in_avals]
   jaxpr, out_tree, error_effects = jaxpr_to_checkify_jaxpr(
       closed_jaxpr, enabled_errors, err_tree, *flat_err_and_in_vals)

@@ -882,7 +879,7 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
   checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)

   cond_in_flat = [*err_vals, *c_consts, *carry]
-  cond_in_flat = map(get_shaped_aval, cond_in_flat)
+  cond_in_flat = map(core.get_aval, cond_in_flat)
   checked_cond_jaxpr, _, _ = jaxpr_to_checkify_jaxpr(cond_jaxpr, enabled_errors,
                                                      err_tree, *cond_in_flat)
   compat_cond_jaxpr_ = ignore_error_output_jaxpr(checked_cond_jaxpr, num_error_vals)

@@ -906,7 +903,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
   # jaxpr to checked_jaxpr
   err_vals, err_tree = jtu.tree_flatten(error)
   new_vals_in = [*err_vals, *vals_in]
-  in_avals = tuple(map(get_shaped_aval, new_vals_in))
+  in_avals = tuple(map(core.get_aval, new_vals_in))
   checked_jaxpr, out_tree, _ = jaxpr_to_checkify_jaxpr(jaxpr, enabled_errors,
                                                        err_tree, *in_avals)

@@ -942,7 +939,7 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
 def remat_error_check(error, enabled_errors, *vals_in, jaxpr, **params):
   err_vals, err_tree = jtu.tree_flatten(error)
   new_vals_in = [*err_vals, *vals_in]
-  in_avals = tuple(map(get_shaped_aval, new_vals_in))
+  in_avals = tuple(map(core.get_aval, new_vals_in))
   checked_jaxpr_, out_tree, _ = jaxpr_to_checkify_jaxpr(
       pe.close_jaxpr(jaxpr), enabled_errors, err_tree, *in_avals)
   checked_jaxpr, () = checked_jaxpr_.jaxpr, checked_jaxpr_.consts

@@ -963,7 +960,7 @@ def shard_map_error_check(
   # Replicated sharding for in errors.
   new_in_names = (*([{}] * num_error_vals), *in_names)
   new_vals_in = [*err_vals, *vals_in]
-  in_avals = list(map(get_shaped_aval, new_vals_in))
+  in_avals = list(map(core.get_aval, new_vals_in))
   for i, v in enumerate(in_avals):
     if not (sharder := core.shard_aval_handlers.get(type(v))):
       raise ValueError(f'Unsupported aval type: {type(v)}')
2 changes: 1 addition & 1 deletion jax/_src/custom_batching.py
@@ -149,7 +149,7 @@ def __call__(self, *args, **kwargs):
                        "using def_vmap.")
     args_flat, in_tree = tree_flatten(args)
     flat_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(self.fun), in_tree)
-    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
+    in_avals = [core.get_aval(x) for x in args_flat]
     debug = pe.debug_info(self.fun, in_tree, out_tree, False, "custom_vmap")
     jaxpr, _, consts, () = pe.trace_to_jaxpr_dynamic(flat_fun, in_avals, debug)
     closed_call = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
9 changes: 3 additions & 6 deletions jax/_src/custom_derivatives.py
@@ -1051,7 +1051,7 @@ def rev(objective_fn, res, g):
   from the closure.
   """
   flat_args, in_tree = tree_flatten(example_args)
-  in_avals = tuple(map(abstractify, flat_args))
+  in_avals = tuple(map(core.get_aval, flat_args))
   if config.check_tracer_leaks.value:
     return _closure_convert_for_avals.__wrapped__(fun, in_tree, in_avals)
   else:

@@ -1111,9 +1111,6 @@ def merge(l1, l2):
     return [next(i2 if snd else i1) for snd in which]
   return out, merge

-def abstractify(x):
-  return core.get_aval(x)
-

 ### Custom transposition

@@ -1209,8 +1206,8 @@ def linear_call(fun: Callable, fun_transpose: Callable, residual_args,
   f_in_tree = treedef_tuple((res_tree, lin_tree))
   f, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), f_in_tree)

-  res_avals = map(abstractify, operands_res)
-  lin_avals = map(abstractify, operands_lin)
+  res_avals = map(core.get_aval, operands_res)
+  lin_avals = map(core.get_aval, operands_lin)
   f_jaxpr, f_consts = _initial_style_jaxpr(f, (*res_avals, *lin_avals))
   f_jaxpr = _close_jaxpr(f_jaxpr)
   out_avals = f_jaxpr.out_avals
2 changes: 1 addition & 1 deletion jax/_src/custom_partitioning.py
@@ -455,7 +455,7 @@ def __call__(self, *args, **kwargs):
     f_, dyn_args = lu.wrap_init(self.fun), args
     args_flat, in_tree = tree_util.tree_flatten(dyn_args)
     flat_fun, out_tree = api_util.flatten_fun_nokwargs(f_, in_tree)
-    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args_flat]
+    in_avals = [core.get_aval(x) for x in args_flat]
     debug = pe.debug_info(self.fun, in_tree, out_tree, False,
                           "custom_partitioning")
     mesh = mesh_lib.thread_resources.env.physical_mesh
2 changes: 1 addition & 1 deletion jax/_src/export/_export.py
@@ -518,7 +518,7 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
   """Returns the shape and dtype of a jax.Array or a j"""
   if isinstance(a, jax.ShapeDtypeStruct):
     return a.shape, a.dtype
-  aval = core.raise_to_shaped(core.get_aval(a))
+  aval = core.get_aval(a)
   return aval.shape, aval.dtype
2 changes: 1 addition & 1 deletion jax/_src/export/shape_poly.py
@@ -1504,7 +1504,7 @@ def shape_and_dtype_jax_array(a) -> tuple[Sequence[int | None], DType]:
   """Returns the shape and dtype of a jax.Array or a j"""
   if isinstance(a, jax.ShapeDtypeStruct):
     return a.shape, a.dtype
-  aval = core.raise_to_shaped(core.get_aval(a))
+  aval = core.get_aval(a)
   return aval.shape, aval.dtype
2 changes: 1 addition & 1 deletion jax/_src/extend/ffi.py
@@ -388,7 +388,7 @@ def ffi_call(
         f"custom_call_api_version < 4; got {custom_call_api_version}.")

   def wrapped(*args: ArrayLike, **kwargs: Any):
-    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
+    in_avals = [core.get_aval(x) for x in args]

     if input_layouts is None:
       static_input_layouts = tuple(map(_convert_layout_for_lowering, in_avals))
5 changes: 1 addition & 4 deletions jax/_src/lax/control_flow/common.py
@@ -37,9 +37,6 @@
 effects.control_flow_allowed_effects.add_type(lax.InOutFeedEffect)


-def _abstractify(x):
-  return core.raise_to_shaped(core.get_aval(x))
-
 def _typecheck_param(prim, param, name, msg_required, pred):
   if not pred:
     msg = (f'invalid {prim} param {name} of type {type(param).__name__}, '

@@ -91,7 +88,7 @@ def _initial_style_jaxprs_with_common_consts(
     return [], [], []

   jaxprs, all_consts, all_out_trees, all_attrs_tracked = zip(*jaxpr_data)
-  all_const_avals = [map(_abstractify, consts) for consts in all_consts]
+  all_const_avals = [map(core.get_aval, consts) for consts in all_consts]
   # If we get a `Ref` in the consts, we know it must come from an outer
   # `run_state`. We also know if shouldn't be boxed up in another tracer.
   # We assert that it is in fact a DynamicJaxprTracer
5 changes: 2 additions & 3 deletions jax/_src/lax/control_flow/conditionals.py
@@ -49,7 +49,6 @@
 import numpy as np

 from jax._src.lax.control_flow.common import (
-    _abstractify,
     _avals_short,
     _check_tree_and_avals,
     _initial_style_jaxprs_with_common_consts,

@@ -135,7 +134,7 @@ def switch(index, branches, *operands):
     return branches[int(index)](*operands)

   ops, ops_tree = tree_flatten(operands)
-  ops_avals = tuple(map(_abstractify, ops))
+  ops_avals = tuple(map(core.get_aval, ops))

   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
       branches, ops_tree, ops_avals, primitive_name='switch')

@@ -227,7 +226,7 @@ def cond(pred, true_fun, false_fun, *operands):
     return false_fun(*operands)

   ops, ops_tree = tree_flatten(operands)
-  ops_avals = tuple(map(_abstractify, ops))
+  ops_avals = tuple(map(core.get_aval, ops))

   jaxprs, consts, out_trees = _initial_style_jaxprs_with_common_consts(
       (true_fun, false_fun), ops_tree, ops_avals, 'cond')
4 changes: 2 additions & 2 deletions jax/_src/lax/control_flow/for_loop.py
@@ -44,7 +44,7 @@
 from jax._src.util import (partition_list, merge_lists, safe_map, safe_zip,
                            split_list, split_dict, weakref_lru_cache)
 from jax._src.lax.control_flow import loops
-from jax._src.lax.control_flow.common import _abstractify, _initial_style_jaxpr
+from jax._src.lax.control_flow.common import _initial_style_jaxpr
 import numpy as np

 ## JAX utilities

@@ -196,7 +196,7 @@ def _create_jaxpr(init):
   init_flat = tree_leaves(init)
   _, in_tree = tree_flatten((init, xs))

-  carry_avals = tuple(map(_abstractify, init_flat))
+  carry_avals = tuple(map(core.get_aval, init_flat))
   jaxpr, _, out_tree = _initial_style_jaxpr(
       f, in_tree, carry_avals + x_avals, "scan")
   return jaxpr, out_tree
8 changes: 4 additions & 4 deletions jax/_src/lax/control_flow/loops.py
@@ -47,7 +47,7 @@
 from jax._src.lax import slicing
 from jax._src.lax import windowed_reductions
 from jax._src.lax.control_flow.common import (
-    _abstractify, _avals_short, _initial_style_jaxpr,
+    _avals_short, _initial_style_jaxpr,
     _initial_style_jaxpr_attrs, _make_closed_jaxpr_attrs, _prune_zeros,
     _typecheck_param)
 from jax._src.lax.other import logaddexp

@@ -275,7 +275,7 @@ def _create_jaxpr(init):
   init_flat, init_tree = tree_flatten(init)
   in_flat, in_tree = tree_flatten((init, xs))

-  carry_avals = tuple(_map(_abstractify, init_flat))
+  carry_avals = tuple(_map(core.get_aval, init_flat))
   jaxpr, consts, out_tree, attrs_tracked = _initial_style_jaxpr_attrs(
       f, in_tree, (*carry_avals, *x_avals), "scan")
   out_tree_children = out_tree.children()

@@ -361,7 +361,7 @@ def _check_carry_type(name, body_fun, in_carry, out_carry_tree, out_avals):
             if p else 'the input carry')
   leaves_and_paths, in_carry_tree = tree_flatten_with_path(in_carry)
   paths, in_carry_flat = unzip2(leaves_and_paths)
-  in_avals = _map(_abstractify, in_carry_flat)
+  in_avals = _map(core.get_aval, in_carry_flat)
   if in_carry_tree != out_carry_tree:
     try:
       out_carry = tree_unflatten(out_carry_tree, out_avals)

@@ -1321,7 +1321,7 @@ def while_loop(cond_fun, body_fun, init_val):

   def _create_jaxpr(init_val):
     init_vals, in_tree = tree_flatten((init_val,))
-    init_avals = tuple(_map(_abstractify, init_vals))
+    init_avals = tuple(_map(core.get_aval, init_vals))
     cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(
         cond_fun, in_tree, init_avals, "while_cond")
     body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(
5 changes: 2 additions & 3 deletions jax/_src/lax/control_flow/solves.py
@@ -32,7 +32,6 @@
 import numpy as np

 from jax._src.lax.control_flow.common import (
-    _abstractify,
     _check_tree,
     _initial_style_jaxpr,
 )

@@ -87,7 +86,7 @@ def custom_root(f, initial_guess, solve, tangent_solve, has_aux=False):
   implicit differentiation assuming ``f(solve(f, initial_guess)) == 0``.
   """
   guess_flat, in_args_tree = tree_flatten((initial_guess,))
-  guess_avals = tuple(_map(_abstractify, guess_flat))
+  guess_avals = tuple(_map(core.get_aval, guess_flat))
   f_jaxpr, f_consts, out_tree = _initial_style_jaxpr(
       f, in_args_tree, guess_avals)

@@ -230,7 +229,7 @@ def custom_linear_solve(
     transpose_solve = solve

   b_flat, in_args_tree = tree_flatten((b,))
-  b_avals = tuple(_map(_abstractify, b_flat))
+  b_avals = tuple(_map(core.get_aval, b_flat))

   tree, = treedef_children(in_args_tree)
2 changes: 1 addition & 1 deletion jax/_src/pallas/mosaic/primitives.py
@@ -139,7 +139,7 @@ def roll(
 @roll_p.def_abstract_eval
 def _roll_abstract_eval(x, shift, **_):
   del shift
-  return jax_core.raise_to_shaped(x)
+  return x


 def _roll_lowering_rule(
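
The pallas change applies the same pattern one level down: abstract-eval rules receive avals that are already shaped, so they can be returned unchanged. A hypothetical primitive illustrating the pattern (identity_p is invented for this sketch and is not part of the diff):

from jax._src import core  # internal API

identity_p = core.Primitive('identity')  # hypothetical primitive

@identity_p.def_abstract_eval
def _identity_abstract_eval(x):
  # x is already a ShapedArray here; no raise_to_shaped needed.
  return x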
(The remaining 8 of the 25 changed files are not shown here.)

0 comments on commit ea63aea
