Skip to content

Commit

Permalink
Bump JAX version after release.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 703472753
  • Loading branch information
hawkinsp authored and Google-ML-Automation committed Dec 6, 2024
1 parent 9fc077a commit ba626fa
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 6 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ Remember to align the itemized text with the first line of an item within a list
When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.md.
-->

## jax 0.4.36
## jax 0.4.37

## jax 0.4.36 (Dec 5, 2024)

* Breaking Changes
* This release lands "stackless", an internal change to JAX's tracing
Expand Down
6 changes: 3 additions & 3 deletions jax/_src/tree_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,13 +291,13 @@ def register_pytree_node(
"""
if xla_extension_version >= 299:
default_registry.register_node( # type: ignore[call-arg]
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type]
)
none_leaf_registry.register_node( # type: ignore[call-arg]
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type]
)
dispatch_registry.register_node( # type: ignore[call-arg]
nodetype, flatten_func, unflatten_func, flatten_with_keys_func
nodetype, flatten_func, unflatten_func, flatten_with_keys_func # type: ignore[arg-type]
)
else:
default_registry.register_node(nodetype, flatten_func, unflatten_func)
Expand Down
2 changes: 1 addition & 1 deletion jax/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pathlib
import subprocess

_version = "0.4.36"
_version = "0.4.37"
# The following line is overwritten by build scripts in distributions &
# releases. Do not modify this manually, or jax/jaxlib build will fail.
_release_version: str | None = None
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

_current_jaxlib_version = '0.4.36'
# The following should be updated after each new jaxlib release.
_latest_jaxlib_version_on_pypi = '0.4.35'
_latest_jaxlib_version_on_pypi = '0.4.36'

_libtpu_version = '0.0.5'
_libtpu_nightly_terminal_version = '0.1.dev20241010+nightly.cleanup'
Expand Down

0 comments on commit ba626fa

Please sign in to comment.