Skip to content

Commit

Permalink
Increase the minimum NumPy version to v1.25.
Browse files Browse the repository at this point in the history
Per SPEC 0, we drop NumPy v1.24 support on Dec 18, 2024.
  • Loading branch information
hawkinsp committed Dec 18, 2024
1 parent 25524ab commit ee45718
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## Unreleased

* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.

## jax 0.4.38 (Dec 17, 2024)

* Changes:
Expand Down
2 changes: 1 addition & 1 deletion jaxlib/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def has_ext_modules(self):
install_requires=[
'scipy>=1.10',
"scipy>=1.11.1; python_version>='3.12'",
'numpy>=1.24',
'numpy>=1.25',
'ml_dtypes>=0.2.0',
],
url='https://github.com/jax-ml/jax',
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def load_version_module(pkg_path):
install_requires=[
f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',
'ml_dtypes>=0.4.0',
'numpy>=1.24',
'numpy>=1.25',
"numpy>=1.26.0; python_version>='3.12'",
'opt_einsum',
'scipy>=1.10',
Expand Down
6 changes: 0 additions & 6 deletions tests/array_interoperability_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@

import numpy as np

numpy_version = jtu.numpy_version()

config.parse_flags_with_absl()

try:
Expand All @@ -48,10 +46,6 @@
[dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
key=lambda x: x.__name__)

# NumPy didn't support bool as a dlpack type until 1.25.
if jtu.numpy_version() < (1, 25, 0):
numpy_dtypes = [dt for dt in numpy_dtypes if dt != jnp.bool_]

cuda_array_interface_dtypes = [dt for dt in dlpack_dtypes if dt != jnp.bfloat16]

nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
Expand Down

0 comments on commit ee45718

Please sign in to comment.