From edecb7e09f2db7931bd49e1335a97aca680cd812 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 13 Dec 2024 10:45:56 -0800 Subject: [PATCH] Migrate from jax.core to jax.extend.core for several deprecated symbols A number of symbols from jax.core are deprecated as of recent JAX releases; some of them are newly available in jax.extend.core. PiperOrigin-RevId: 705932044 --- .../_src/curvature_blocks/curvature_block.py | 7 ++- kfac_jax/_src/layers_and_loss_tags.py | 26 ++++---- kfac_jax/_src/tag_graph_matcher.py | 61 ++++++++----------- kfac_jax/_src/tracer.py | 15 ++--- pyproject.toml | 8 +-- tests/models.py | 8 +-- tests/test_graph_matcher.py | 13 ++-- 7 files changed, 64 insertions(+), 74 deletions(-) diff --git a/kfac_jax/_src/curvature_blocks/curvature_block.py b/kfac_jax/_src/curvature_blocks/curvature_block.py index 9d54c29..3d0243f 100644 --- a/kfac_jax/_src/curvature_blocks/curvature_block.py +++ b/kfac_jax/_src/curvature_blocks/curvature_block.py @@ -17,6 +17,7 @@ from typing import Any, Sequence import jax +import jax.extend as jex import jax.numpy as jnp import jax.scipy from kfac_jax._src import layers_and_loss_tags as tags @@ -81,7 +82,7 @@ def name(self) -> str: @property def layer_tag_primitive(self) -> tags.LayerTag: - """The :class:`jax.core.Primitive` corresponding to the block's tag equation.""" + """The :class:`jax.extend.core.Primitive` corresponding to the block's tag equation.""" primitive = self._layer_tag_eq.primitive assert isinstance(primitive, tgm.tags.LayerTag) @@ -89,14 +90,14 @@ def layer_tag_primitive(self) -> tags.LayerTag: return primitive @property - def parameter_variables(self) -> tuple[jax.core.Var, ...]: + def parameter_variables(self) -> tuple[jex.core.Var, ...]: """The parameter variables of the underlying Jax equation.""" param_vars = [] for p in tags.layer_eqn_data(self._layer_tag_eq).params: - assert isinstance(p, jax.core.Var) + assert isinstance(p, jex.core.Var) param_vars.append(p) return tuple(param_vars) diff --git 
a/kfac_jax/_src/layers_and_loss_tags.py b/kfac_jax/_src/layers_and_loss_tags.py index 48ed380..e80a2eb 100644 --- a/kfac_jax/_src/layers_and_loss_tags.py +++ b/kfac_jax/_src/layers_and_loss_tags.py @@ -17,7 +17,7 @@ from typing import Any, Generic, Sequence, TypeVar import jax -from jax import core +import jax.extend as jex # Types for annotation @@ -94,7 +94,7 @@ def get_loss_outputs( return tuple(kwargs[name] for name in meta.parameter_dependants) -class LossTag(core.Primitive): +class LossTag(jex.core.Primitive): """A Jax primitive for tagging K-FAC losses. The primitive is no-op at runtime, however its goal is to tag (annotate) the @@ -103,7 +103,7 @@ class LossTag(core.Primitive): curvature matrix. """ - # Whether the primitive returns multiple outputs (from core.Primitive) + # Whether the primitive returns multiple outputs (from jex.core.Primitive) multiple_results = True def __init__(self): @@ -175,9 +175,9 @@ def _batching( def loss_eqn_parameter_dependants( - eqn: jax.core.JaxprEqn, + eqn: jex.core.JaxprEqn, raise_an_error: bool = True, -) -> list[jax.core.Var]: +) -> list[jex.core.Var]: """Returns the parameter dependants variables from the give loss equation.""" if not isinstance(eqn.primitive, LossTag): if raise_an_error: @@ -192,7 +192,7 @@ def loss_eqn_parameter_dependants( def loss_eqn_construct_loss( - eqn: jax.core.JaxprEqn, + eqn: jex.core.JaxprEqn, *args: Array, ) -> Any: """Constructs an instance of the corresponding :class:`~LossFunction` class.""" @@ -206,7 +206,7 @@ def loss_eqn_construct_loss( return meta.loss_class(**kwargs) -def loss_eqn_class_name(eqn: jax.core.JaxprEqn) -> str: +def loss_eqn_class_name(eqn: jex.core.JaxprEqn) -> str: """The name of the underlying `~LossFunction` class.""" if not isinstance(eqn.primitive, LossTag): @@ -253,7 +253,7 @@ def get_and_verify_layer_meta( return meta -class LayerTag(core.Primitive): +class LayerTag(jex.core.Primitive): """A Jax primitive for tagging K-FAC layers. 
The primitive is no-op at runtime, however its goal is to tag (annotate) the @@ -347,9 +347,9 @@ def _batching( def layer_eqn_data( # pytype: disable=invalid-annotation - eqn: jax.core.JaxprEqn, + eqn: jex.core.JaxprEqn, raise_an_error: bool = True, -) -> LayerData[jax.core.Var]: +) -> LayerData[jex.core.Var]: if isinstance(eqn.primitive, LayerTag): return eqn.primitive.layer_data(eqn.invars, eqn.params, str(eqn)) @@ -360,7 +360,7 @@ def layer_eqn_data( # pytype: disable=invalid-annotation return LayerData(inputs=(), outputs=(), params=()) -def layer_eqn_name(eqn: jax.core.JaxprEqn) -> str: +def layer_eqn_name(eqn: jex.core.JaxprEqn) -> str: meta = get_and_verify_layer_meta(eqn.invars, eqn.params) if meta.name is None: raise ValueError("Layer name must be provided at this stage.") @@ -460,11 +460,11 @@ def register_scale_and_shift( ) -class LossTagEqn(core.JaxprEqn): +class LossTagEqn(jex.core.JaxprEqn): """A class used only for annotation purposes.""" primitive: LossTag -class LayerTagEqn(core.JaxprEqn): +class LayerTagEqn(jex.core.JaxprEqn): """A class used only for annotation purposes.""" primitive: LayerTag diff --git a/kfac_jax/_src/tag_graph_matcher.py b/kfac_jax/_src/tag_graph_matcher.py index 74204fc..86de22f 100644 --- a/kfac_jax/_src/tag_graph_matcher.py +++ b/kfac_jax/_src/tag_graph_matcher.py @@ -23,13 +23,7 @@ from absl import logging import immutabledict import jax - -jax_version = ( - jax.__version_info__ if hasattr(jax, "__version_info__") - else tuple(map(int, jax.__version__.split(".")))) - -if jax_version > (0, 4, 11): - import jax.extend as jax_extend # pylint: disable=g-import-not-at-top +import jax.extend as jex import jax.numpy as jnp # pylint: disable=g-import-not-at-top from kfac_jax._src import layers_and_loss_tags as tags @@ -42,11 +36,11 @@ # Types for annotation Array = utils.Array PyTreeDef = utils.PyTreeDef -Var = jax.core.Var +Var = jex.core.Var Vars = Sequence[Var] -Jaxpr = jax.core.Jaxpr -ClosedJaxpr = jax.core.ClosedJaxpr 
-JaxprEqn = jax.core.JaxprEqn +Jaxpr = jex.core.Jaxpr +ClosedJaxpr = jex.core.ClosedJaxpr +JaxprEqn = jex.core.JaxprEqn JaxprEqns = Sequence[JaxprEqn] T = TypeVar("T") J = TypeVar("J", Jaxpr, ClosedJaxpr) @@ -64,10 +58,7 @@ def eval_jaxpr_eqn(eqn: JaxprEqn, in_values: list[T]) -> list[T]: subfuns, bind_params = eqn.primitive.get_bind_params(eqn.params) - if jax_version > (0, 4, 11): - user_context = jax_extend.source_info_util.user_context - else: - user_context = jax.core.source_info_util.user_context # pytype: disable=module-attr + user_context = jex.source_info_util.user_context with user_context(eqn.source_info.traceback): output = eqn.primitive.bind(*subfuns, *in_values, **bind_params) @@ -245,9 +236,9 @@ class JaxprGraph: it. manual_registrations: Any layer tag equations that have been manually registered. - jaxpr: The underlying :class:`jax.core.Jaxpr` part of ``self.closed_jaxpr``. + jaxpr: The underlying :class:`Jaxpr` part of ``self.closed_jaxpr``. consts: The underlying constants part ``self.closed_jaxpr``. - outvars: The output variables of the underlying :class:`jax.core.Jaxpr` part + outvars: The output variables of the underlying :class:`Jaxpr` part of ``self.closed_jaxpr``. """ name: str @@ -294,7 +285,7 @@ def sub_graph_eqns( eqns.append(next_eqn) for v in next_eqn.invars: - if (not isinstance(v, jax.core.Literal) and v not in root_vars and + if (not isinstance(v, jex.core.Literal) and v not in root_vars and v not in processed_vars and v in self.var_to_creation_op): to_process_eqns.append(self.var_to_creation_op[v]) processed_vars.add(v) @@ -383,7 +374,7 @@ def make_jax_graph( eqns.append(eqn) sub_graph_vars.update( - v for v in eqn.invars if not isinstance(v, jax.core.Literal) + v for v in eqn.invars if not isinstance(v, jex.core.Literal) ) consts_i = [ @@ -461,8 +452,8 @@ class GraphPattern: in_values_preprocessor: A function that can optionally modify the in_vals passed to the tag_primitive, from those that are usually the input to the jaxpr. 
- jaxpr: The underlying :class:`jax.core.Jaxpr` represented by the pattern. - param_vars: The list of :class:`jax.core.Var` that correspond to parameters + jaxpr: The underlying :class:`Jaxpr` represented by the pattern. + param_vars: The list of :class:`Var` that correspond to parameters in the pattern. graph: A :class:`JaxprGraph` representation of the pattern. """ @@ -633,7 +624,7 @@ def add_vars_if_possible( If at least one of the pattern variables is a parameter, but the corresponding graph variable is not or vise-versa, the method does not update the current variables map and returns ``False``. Similarly, if at - least one of the graph variables is a :class:`~jax.core.Literal` (meaning a + least one of the graph variables is a :class:`Literal` (meaning a constant, independent of the function inputs) and the corresponding pattern variable is not an input to the pattern, it returns ``False``. In all other cases it updates the map and returns ``True``. @@ -648,12 +639,12 @@ def add_vars_if_possible( """ for var1, var2 in zip(eqn_vars, graph_vars): - var2_matchable = isinstance(var2, jax.core.Var) and ( + var2_matchable = isinstance(var2, jex.core.Var) and ( var2 in matchable_graph_params) if (var1 in param_variables and not var2_matchable or var1 not in param_variables and var2_matchable or - (isinstance(var2, jax.core.Literal) and var1 not in input_vars)): + (isinstance(var2, jex.core.Literal) and var1 not in input_vars)): return False current_variables_map.update(zip(eqn_vars, graph_vars)) @@ -788,7 +779,7 @@ def match_pattern( for k, v in match_variables_map.items(): if (k not in pattern.graph.jaxpr.invars and - not isinstance(v, jax.core.Literal)): + not isinstance(v, jex.core.Literal)): creation_op = graph.var_to_creation_op[v] @@ -883,14 +874,14 @@ def find_layer_tags_and_patterns( def read_env( - env: dict[jax.core.Var, T], + env: dict[jex.core.Var, T], variables: list[jax.core.Atom], ) -> list[T]: """Reads from the variable-to-array environment during 
tracing.""" result = [] assert isinstance(variables, list) for v in variables: - if isinstance(v, jax.core.Literal): + if isinstance(v, jex.core.Literal): # Literals are values baked into the Jaxpr result.append(v.val) elif isinstance(v, jax.core.DropVar): @@ -901,8 +892,8 @@ def read_env( def write_env( - env: dict[jax.core.Var, T], - variables: list[jax.core.Var], + env: dict[jex.core.Var, T], + variables: list[jex.core.Var], values: list[T], ) -> None: """Writes to the variable-to-array environment during tracing.""" @@ -979,7 +970,7 @@ def clean_jaxpr( final_outvars.append(var) - if not isinstance(var, jax.core.Literal): + if not isinstance(var, jex.core.Literal): dependants.add(var) for eqn in reversed(closed_jaxpr.jaxpr.eqns): @@ -1035,7 +1026,7 @@ def clean_jaxpr( if check: eqns.append(eqn) new_dependants = set(v for v in eqn.invars - if not isinstance(v, jax.core.Literal)) + if not isinstance(v, jex.core.Literal)) dependants = dependants.union(new_dependants) # Dependants should only be invars @@ -1112,7 +1103,7 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J: # We ignore broadcasting of constants if (eqn.primitive.name == "broadcast_in_dim" and - not all(isinstance(v, jax.core.Literal) for v in eqn.invars)): + not all(isinstance(v, jex.core.Literal) for v in eqn.invars)): if eqn.invars[0] in broadcasts_outputs: # Construct a merged equation from the previous and current one @@ -1139,7 +1130,7 @@ def merge_broadcasts_jaxpr(jaxpr: J) -> J: else: for v in eqn.invars: - if not isinstance(v, jax.core.Literal) and v in broadcasts_outputs: + if not isinstance(v, jex.core.Literal) and v in broadcasts_outputs: eqns.append(broadcasts_outputs[v]) eqns.append(eqn) @@ -1688,7 +1679,7 @@ def __init__( ): self._func_graph = func_graph self._tag_locations = tag_locations - self._flat_func = jax.core.jaxpr_as_fun(func_graph.closed_jaxpr) + self._flat_func = jex.core.jaxpr_as_fun(func_graph.closed_jaxpr) self._param_labels = self._compute_parameter_labels() def __call__(self, 
*args, **kwargs): @@ -1770,7 +1761,7 @@ def _auto_register_tags( eqns_for_registration.append(eqn) sub_graph_vars.update( - v for v in eqn.invars if not isinstance(v, jax.core.Literal)) + v for v in eqn.invars if not isinstance(v, jex.core.Literal)) eqns_for_registration = eqns_for_registration[::-1] diff --git a/kfac_jax/_src/tracer.py b/kfac_jax/_src/tracer.py index 15e698a..776dbf9 100644 --- a/kfac_jax/_src/tracer.py +++ b/kfac_jax/_src/tracer.py @@ -18,6 +18,7 @@ from absl import logging import jax +import jax.extend as jex import jax.numpy as jnp from kfac_jax._src import layers_and_loss_tags as tags from kfac_jax._src import loss_functions @@ -32,7 +33,7 @@ Params = utils.Params FuncArgs = utils.FuncArgs FuncOuts = utils.FuncOuts -Var = jax.core.Var +Var = jex.core.Var LossFunction = loss_functions.LossFunction LossFunctionInputs = loss_functions.LossFunctionInputs @@ -80,7 +81,7 @@ def tree_unflatten(cls, aux_data, children): tuple[LayerVjpData[Array], ...], # pytype: disable=invalid-annotation ], ] -JaxprOrClosedJaxpr = jax.core.Jaxpr | jax.core.ClosedJaxpr +JaxprOrClosedJaxpr = jex.core.Jaxpr | jex.core.ClosedJaxpr def shape_and_type(x: Array) -> tuple[Shape, jnp.dtype]: @@ -99,7 +100,7 @@ def make_cache_key( def extract_tags( - jaxpr: jax.core.Jaxpr, + jaxpr: jex.core.Jaxpr, ) -> tuple[tuple[tags.LayerTagEqn, ...], tuple[tags.LossTagEqn, ...]]: """Extracts the layer and the loss tags from the given Jaxpr.""" @@ -199,7 +200,7 @@ class ProcessedJaxpr(utils.Finalizable): def __init__( self, - jaxpr: jax.core.Jaxpr, + jaxpr: jex.core.Jaxpr, consts: list[Any], in_tree: utils.PyTreeDef, params_index: int, @@ -819,16 +820,16 @@ def forward_aux( own_func_args = primal_func_args # Mapping from variable -> value - env: dict[jax.core.Var, Array] = {} + env: dict[jex.core.Var, Array] = {} read = functools.partial(tgm.read_env, env) - def write(variables: list[jax.core.Var], values: list[Array]) -> None: + def write(variables: list[jex.core.Var], values: list[Array]) 
-> None: # if not isinstance(variables, list): # variables = [variables] tgm.write_env(env, variables, values) for v in variables: - if not isinstance(v, jax.core.Literal) and v in aux: + if not isinstance(v, jex.core.Literal) and v in aux: env[v] = env[v] + aux[v] # Bind args and consts to environment diff --git a/pyproject.toml b/pyproject.toml index 0cb1936..11b1e7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,8 +13,8 @@ dependencies = [ "immutabledict>=2.2.1", "numpy>=1.22", "distrax>=0.1.3", - "jax>=0.4.25", - "jaxlib>=0.4.25", + "jax>=0.4.27", + "jaxlib>=0.4.27", "dm-tree>=0.1.7", "optax>=0.1.4", "typing-extensions>=4.0.0" @@ -68,8 +68,8 @@ tests = [ # these should be version pinned? "immutabledict>=2.2.1", "numpy>=1.22", "distrax>=0.1.3", - "jax>=0.4.25", - "jaxlib>=0.4.25", + "jax>=0.4.27", + "jaxlib>=0.4.27", "dm-haiku>=0.0.9", "dm-tree>=0.1.7", "optax>=0.1.4", diff --git a/tests/models.py b/tests/models.py index a0e82bf..f2f0818 100644 --- a/tests/models.py +++ b/tests/models.py @@ -78,7 +78,6 @@ def __init__( super().__init__(*args, **kwargs) def __call__(self, inputs: LayerInputs, *_) -> LayerInputs: # pytype: disable=signature-mismatch # overriding-parameter-name-checks - jax_version = tuple(map(int, jax.__version__.split(".")[:3])) x, layer_values, aux = inputs y = super().__call__(x, precision=jax.lax.Precision.HIGHEST) if aux is not None: @@ -87,11 +86,8 @@ def __call__(self, inputs: LayerInputs, *_) -> LayerInputs: # pytype: disable=s if self._explicit_tagging: params = _extract_params(self, ("w", "b")) - if jax_version < (0, 4, 14): - preferred_element_type = None - else: - assert all(p.dtype == y.dtype for p in params if p is not None) - preferred_element_type = y.dtype + assert all(p.dtype == y.dtype for p in params if p is not None) + preferred_element_type = y.dtype y = tags.register_dense( y, x, *params, diff --git a/tests/test_graph_matcher.py b/tests/test_graph_matcher.py index a9f74fc..e59c014 100644 --- 
a/tests/test_graph_matcher.py +++ b/tests/test_graph_matcher.py @@ -19,6 +19,7 @@ from absl.testing import absltest from absl.testing import parameterized import jax +import jax.extend as jex import jax.numpy as jnp import kfac_jax from tests import models @@ -76,8 +77,8 @@ def check_equation_match(self, eqn1, vars_to_vars, vars_to_eqn): if exclude_param: j1 = eqn1.params[exclude_param] j2 = eqn2.params[exclude_param] - if isinstance(j1, jax.core.ClosedJaxpr): - assert isinstance(j2, jax.core.ClosedJaxpr) + if isinstance(j1, jex.core.ClosedJaxpr): + assert isinstance(j2, jex.core.ClosedJaxpr) self.assertEqual(len(j1.consts), len(j2.consts)) j1 = j1.jaxpr j2 = j2.jaxpr @@ -85,8 +86,8 @@ def check_equation_match(self, eqn1, vars_to_vars, vars_to_eqn): # Check variables for v1, v2 in zip(eqn1.invars, eqn2.invars): - if isinstance(v1, jax.core.Literal): - self.assertIsInstance(v2, jax.core.Literal) + if isinstance(v1, jex.core.Literal): + self.assertIsInstance(v2, jex.core.Literal) self.assertEqual(v1.aval, v2.aval) else: self.assertEqual(v1.aval.shape, v2.aval.shape) @@ -125,8 +126,8 @@ def check_jaxpr_equal(self, jaxpr_1, jaxpr_2, map_output_vars: bool): for v1, v2 in zip(eqn1.outvars, eqn2.outvars): if isinstance(v1, jax.core.DropVar): self.assertIsInstance(v2, jax.core.DropVar) - elif isinstance(v1, jax.core.Literal): - self.assertIsInstance(v2, jax.core.Literal) + elif isinstance(v1, jex.core.Literal): + self.assertIsInstance(v2, jex.core.Literal) self.assertEqual(v1.aval, v2.aval) else: self.assertEqual(v1.aval.shape, v2.aval.shape)