diff --git a/CHANGELOG.md b/CHANGELOG.md index b6d0f97f439d..2e411e27faee 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,7 +10,9 @@ Remember to align the itemized text with the first line of an item within a list When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md. --> -## jax 0.4.36 +## jax 0.4.37 + +## jax 0.4.36 (Dec 5, 2024) * Breaking Changes * This release lands "stackless", an internal change to JAX's tracing diff --git a/jax/_src/tree_util.py b/jax/_src/tree_util.py index 73cff5aa8042..bb9924f8bb72 100644 --- a/jax/_src/tree_util.py +++ b/jax/_src/tree_util.py @@ -291,13 +291,13 @@ def register_pytree_node( """ if xla_extension_version >= 299: default_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] ) none_leaf_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] ) dispatch_registry.register_node( # type: ignore[call-arg] - nodetype, flatten_func, unflatten_func, flatten_with_keys_func + nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type] ) else: default_registry.register_node(nodetype, flatten_func, unflatten_func) diff --git a/jax/version.py b/jax/version.py index 941b34f1226f..9da3d63f8708 100644 --- a/jax/version.py +++ b/jax/version.py @@ -21,7 +21,7 @@ import pathlib import subprocess -_version = "0.4.36" +_version = "0.4.37" # The following line is overwritten by build scripts in distributions & # releases. Do not modify this manually, or jax/jaxlib build will fail. _release_version: str | None = None diff --git a/setup.py b/setup.py index ea42d625eadc..dfe64c4d83ac 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ _current_jaxlib_version = '0.4.36' # The following should be updated after each new jaxlib release. -_latest_jaxlib_version_on_pypi = '0.4.35' +_latest_jaxlib_version_on_pypi = '0.4.36' _libtpu_version = '0.0.5' _libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup'