Skip to content

Commit

Permalink
Merge pull request #25569 from hawkinsp:numpyver
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707570246
  • Loading branch information
Google-ML-Automation committed Dec 18, 2024
2 parents 3f24dfd + ee45718 commit 464e5a2
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 8 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.

## Unreleased

* Changes:
* The minimum NumPy version is now 1.25. NumPy 1.25 will remain the minimum
supported version until June 2025.

## jax 0.4.38 (Dec 17, 2024)

* Changes:
Expand Down
2 changes: 1 addition & 1 deletion jaxlib/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def has_ext_modules(self):
install_requires=[
'scipy>=1.10',
"scipy>=1.11.1; python_version>='3.12'",
'numpy>=1.24',
'numpy>=1.25',
'ml_dtypes>=0.2.0',
],
url='https://github.com/jax-ml/jax',
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def load_version_module(pkg_path):
install_requires=[
f'jaxlib >={_minimum_jaxlib_version}, <={_jax_version}',
'ml_dtypes>=0.4.0',
'numpy>=1.24',
'numpy>=1.25',
"numpy>=1.26.0; python_version>='3.12'",
'opt_einsum',
'scipy>=1.10',
Expand Down
6 changes: 0 additions & 6 deletions tests/array_interoperability_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@

import numpy as np

numpy_version = jtu.numpy_version()

config.parse_flags_with_absl()

try:
Expand All @@ -48,10 +46,6 @@
[dt for dt in jax.dlpack.SUPPORTED_DTYPES if dt != jnp.bfloat16],
key=lambda x: x.__name__)

# NumPy didn't support bool as a dlpack type until 1.25.
if jtu.numpy_version() < (1, 25, 0):
numpy_dtypes = [dt for dt in numpy_dtypes if dt != jnp.bool_]

cuda_array_interface_dtypes = [dt for dt in dlpack_dtypes if dt != jnp.bfloat16]

nonempty_nonscalar_array_shapes = [(4,), (3, 4), (2, 3, 4)]
Expand Down

0 comments on commit 464e5a2

Please sign in to comment.