Fix typing for NumPy 2.2
jakevdp committed Dec 10, 2024
1 parent c0f6642 commit a298b50
Showing 8 changed files with 10 additions and 10 deletions.
2 changes: 1 addition & 1 deletion jax/_src/dtypes.py
@@ -419,7 +419,7 @@ def _issubdtype_cached(a: type | np.dtype | ExtendedDType,
     return b_sctype in {a_sctype, np.unsignedinteger, np.integer, np.number, np.generic}

   # Otherwise, fall back to numpy.issubdtype
-  return np.issubdtype(a_sctype, b_sctype)
+  return bool(np.issubdtype(a_sctype, b_sctype))

 can_cast = np.can_cast

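Context for the change above: under NumPy 2.2's stricter stubs, the result of np.issubdtype may no longer be seen by a type checker as the built-in bool that _issubdtype_cached declares, so the commit coerces explicitly. A minimal standalone sketch of the pattern (the helper name here is hypothetical):

import numpy as np

def issubdtype_checked(a: type, b: type) -> bool:
  # bool(...) pins the static return type to the built-in bool, whatever
  # the NumPy stubs infer for np.issubdtype itself.
  return bool(np.issubdtype(a, b))

print(issubdtype_checked(np.uint8, np.unsignedinteger))  # True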
2 changes: 1 addition & 1 deletion jax/_src/interpreters/ad.py
@@ -276,7 +276,7 @@ def write_primal(v, val):
       write_cotangent(eqn.primitive, val_var, ct_out)
     elif eqn.primitive is core.freeze_p:
       val_var, = eqn.outvars
-      ref_var, = eqn.invars
+      ref_var, = eqn.invars  # type: ignore
       ct_in = instantiate_zeros(read_cotangent(val_var))
       write_primal(ref_var, core.mutable_array(ct_in))
       continue
2 changes: 1 addition & 1 deletion jax/_src/interpreters/mlir.py
@@ -2793,7 +2793,7 @@ def _wrapped_callback(*args):  # pylint: disable=function-redefined
 def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
   if minor_to_major is None:
     # Needed for token layouts
-    layout = np.zeros((0,), dtype="int64")
+    layout: np.ndarray = np.zeros((0,), dtype="int64")
   else:
     layout = np.array(minor_to_major, dtype="int64")
   return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())
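Why the annotation helps: the two branches assign layout from np.zeros and np.array, and under NumPy 2.2's stubs their inferred types may no longer unify, so pinning the variable to np.ndarray gives both assignments a common declared type. A standalone sketch under that assumption (the function name is hypothetical):

import numpy as np

def to_layout_array(minor_to_major: list[int] | None) -> np.ndarray:
  if minor_to_major is None:
    # Annotating the first assignment fixes the variable's declared type,
    # so the else-branch assignment checks against np.ndarray as well.
    layout: np.ndarray = np.zeros((0,), dtype="int64")
  else:
    layout = np.array(minor_to_major, dtype="int64")
  return layout

print(to_layout_array(None).shape)  # (0,)
print(to_layout_array([1, 0]))      # [1 0]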
6 changes: 3 additions & 3 deletions jax/_src/mesh_utils.py
@@ -386,7 +386,7 @@ def _create_device_mesh_for_nd_torus_splitting_axes(
         )
       ):
         best_logical_axis_assignment = logical_axis_assignment
-    assignment[:, logical_axis] = best_logical_axis_assignment
+    assignment[:, logical_axis] = best_logical_axis_assignment  # type: ignore  # numpy 2.2

   # Read out the assignment.
   logical_mesh = _generate_logical_mesh(
@@ -597,10 +597,10 @@ def _generate_logical_mesh(
       zip(logical_indices, physical_indices, range(len(logical_indices)))
     )
   )
-  logical_mesh = np.transpose(logical_mesh, transpose_axes)
+  logical_mesh = np.transpose(logical_mesh, transpose_axes)  # type: ignore  # numpy 2.2

   # Reshape to add the trivial dimensions back.
-  logical_mesh = np.reshape(logical_mesh, logical_mesh_shape)
+  logical_mesh = np.reshape(logical_mesh, logical_mesh_shape)  # type: ignore  # numpy 2.2

   return logical_mesh

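The ignores above follow one pattern: NumPy 2.2's stubs track ndarray shape and dtype parameters more precisely, so calls like np.transpose and np.reshape (or slice assignment into an array) can be flagged even though they are fine at runtime. A standalone illustration of the operations being silenced, which runs cleanly:

import numpy as np

mesh = np.arange(8).reshape((2, 4))
mesh = np.transpose(mesh, (1, 0))   # fine at runtime; may be flagged statically
mesh = np.reshape(mesh, (2, 2, 2))  # likewise
print(mesh.shape)                   # (2, 2, 2)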
2 changes: 1 addition & 1 deletion jax/_src/numpy/linalg.py
@@ -1374,7 +1374,7 @@ def _lstsq(a: ArrayLike, b: ArrayLike, rcond: float | None, *,
     x = jnp.empty((n, *b.shape[1:]), dtype=a.dtype)
   else:
     if rcond is None:
-      rcond = jnp.finfo(dtype).eps * max(n, m)
+      rcond = float(jnp.finfo(dtype).eps) * max(n, m)
     else:
       rcond = jnp.where(rcond < 0, jnp.finfo(dtype).eps, rcond)
     u, s, vt = svd(a, full_matrices=False)
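The float(...) coercion keeps rcond a plain Python float: jnp.finfo(dtype).eps is typed as a NumPy scalar under the stubs, and multiplying it by an int would otherwise widen rcond's static type beyond the float | None the signature declares. An illustrative standalone version using np.finfo, which jnp.finfo mirrors:

import numpy as np

n, m = 5, 3
rcond: float = float(np.finfo(np.float32).eps) * max(n, m)
print(rcond)  # ~5.96e-07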
2 changes: 1 addition & 1 deletion jax/_src/op_shardings.py
@@ -100,7 +100,7 @@ def op_sharding_to_numpy_indices(

   for i, idxs in enumerate(itertools.product(*axis_indices)):
     for _ in range(num_replicas):
-      indices[next(device_it)] = idxs
+      indices[next(device_it)] = idxs  # type: ignore  # numpy 2.2
   return indices


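This ignore, and the matching ones in sharding_impls.py and sharding_specs.py below, cover one pattern: storing a Python tuple (or other object) into an element of an object-dtype array. That is valid at runtime, but NumPy 2.2's stubs may reject the assignment statically. A minimal runnable sketch of the pattern:

import itertools
import numpy as np

indices = np.empty(4, dtype=np.object_)
for i, idxs in enumerate(itertools.product(range(2), range(2))):
  indices[i] = idxs  # fine at runtime; a stricter checker may flag it
print(indices)  # [(0, 0) (0, 1) (1, 0) (1, 1)]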
2 changes: 1 addition & 1 deletion jax/_src/sharding_impls.py
@@ -725,7 +725,7 @@ def __repr__(self) -> str:
     ids = self._ids.copy()
     platform_name = self._devices[0].platform.upper()
     for idx, x in np.ndenumerate(ids):
-      ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x))
+      ids[idx] = DeviceIdSet(platform_name, *(self._devices[i].id for i in x))  # type: ignore  # numpy 2.2
     body = np.array2string(ids, prefix=cls_name + '(', suffix=')',
                            max_line_width=100)
     mem = '' if self._memory_kind is None else f', memory_kind={self._memory_kind}'
2 changes: 1 addition & 1 deletion jax/_src/sharding_specs.py
@@ -97,7 +97,7 @@ def _sharding_spec_indices(self, shape: tuple[int, ...]) -> np.ndarray:
   # is used to extract the corresponding shard of the logical array.
   shard_indices = np.empty([math.prod(shard_indices_shape)], dtype=np.object_)
   for i, idxs in enumerate(itertools.product(*axis_indices)):
-    shard_indices[i] = idxs
+    shard_indices[i] = idxs  # type: ignore  # numpy 2.2
   shard_indices = shard_indices.reshape(shard_indices_shape)

   # Ensure that each sharded axis is used exactly once in the mesh mapping
