Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

jax.nn.one_hot: deprecate non-integer inputs #25590

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax/_src/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
# always registered by the time `accelerate` and `is_acelerated` are called.
register('jax-aval-named-shape')
register('jax-dlpack-import-legacy')
register('jax-nn-one-hot-float-input')
register("jax-numpy-astype-complex-to-real")
register("jax-numpy-array-none")
register('jax-numpy-clip-args')
Expand Down
28 changes: 18 additions & 10 deletions jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
Expand Down Expand Up @@ -645,34 +646,33 @@ def standardize(x: ArrayLike,

# TODO(slebedev): Change the type of `x` to `ArrayLike`.
@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis"))
def _one_hot(x: Any, num_classes: int, *,
def _one_hot(x: Array, num_classes: int, *,
dtype: Any, axis: int | AxisName) -> Array:
num_classes = core.concrete_dim_or_error(
num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
dtype = dtypes.canonicalize_dtype(dtype)
x_arr = jnp.asarray(x)
try:
output_pos_axis = util.canonicalize_axis(axis, x_arr.ndim + 1)
output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
except TypeError:
axis_size = lax.psum(1, axis)
if num_classes != axis_size:
raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
f"but {num_classes} != {axis_size}") from None
axis_idx = lax.axis_index(axis)
return jnp.asarray(x_arr == axis_idx, dtype=dtype)
return jnp.asarray(_dot_product_attention_xla == axis_idx, dtype=dtype)
axis = operator.index(axis) # type: ignore[arg-type]
lhs = lax.expand_dims(x_arr, (axis,))
rhs_shape = [1] * x_arr.ndim
lhs = lax.expand_dims(x, (axis,))
rhs_shape = [1] * x.ndim
rhs_shape.insert(output_pos_axis, num_classes)
if config.sharding_in_types.value:
# TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too?
rhs_sharding = NamedSharding(x_arr.sharding.mesh, P(*[None] * len(rhs_shape)))
rhs_sharding = NamedSharding(x.sharding.mesh, P(*[None] * len(rhs_shape))) # pytype: disable=attribute-error
else:
rhs_sharding = None
rhs = lax.broadcasted_iota(x_arr.dtype, rhs_shape, output_pos_axis,
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
_sharding=rhs_sharding)
return jnp.asarray(lhs == rhs, dtype=dtype)
return (lhs == rhs).astype(dtype)

# TODO(slebedev): Change the type of `x` to `ArrayLike`.
def one_hot(x: Any, num_classes: int, *,
Expand Down Expand Up @@ -703,7 +703,15 @@ def one_hot(x: Any, num_classes: int, *,
num_classes = core.concrete_dim_or_error(
num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
return _one_hot(x, num_classes, dtype=dtype, axis=axis)
x_arr = jnp.asarray(x)
if not jnp.isdtype(x_arr.dtype, "integral"):
# Deprecated 2024-12-18
deprecations.warn(
'jax-nn-one-hot-float-input',
f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}",
stacklevel=1)
x_arr = x_arr.astype('int32')
return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis)


@jax.custom_jvp
Expand Down
4 changes: 2 additions & 2 deletions jax/experimental/jax2tf/tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1959,11 +1959,11 @@ def f2(z, w): # z: f32[a, 5] w: f32[a + b, 5] -> f32[2*a + b, 10]
expect_error=expect_error_associative_scan),
PolyHarness("one_hot", "poly_num_classes",
lambda x, y: jax.nn.one_hot(x, y.shape[0]),
arg_descriptors=[np.arange(16, dtype=_f32), RandArg((16,), _f32)],
arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)],
polymorphic_shapes=[None, "b0, ..."]),
PolyHarness("one_hot", "all_poly",
lambda x, y: jax.nn.one_hot(x, y.shape[0]),
arg_descriptors=[np.arange(16, dtype=_f32), RandArg((16,), _f32)],
arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)],
polymorphic_shapes=["b, ...", "b, ..."]),
PolyHarness("ones", "",
lambda x: jnp.ones(x.shape, dtype=_f32) + x,
Expand Down
12 changes: 11 additions & 1 deletion tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@

import scipy.stats

from jax._src import ad_checkpoint
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import test_util as jtu
from jax._src import ad_checkpoint
from jax._src.interpreters import mlir
from jax._src.lib import cuda_versions
from jax.test_util import check_grads
Expand Down Expand Up @@ -530,6 +531,15 @@ def testOneHotAxis(self):
actual = nn.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
self.assertAllClose(actual, expected, check_dtypes=False)

def testOneHotNonInteger(self):
def assert_warns_or_errors(msg):
if deprecations.is_accelerated("jax-nn-one-hot-float-input"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors("jax.nn.one_hot input should be integer-typed"):
nn.one_hot(jnp.array([1.0]), 3)

def testTanhExists(self):
nn.tanh # doesn't crash

Expand Down
4 changes: 2 additions & 2 deletions tests/shape_poly_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2760,11 +2760,11 @@ def f(x_ref):
expect_error=expect_error_associative_scan),
PolyHarness("one_hot", "poly_num_classes",
lambda x, y: jax.nn.one_hot(x, y.shape[0]),
arg_descriptors=[np.arange(16, dtype=_f32), RandArg((16,), _f32)],
arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)],
polymorphic_shapes=[None, "b0, ..."]),
PolyHarness("one_hot", "all_poly",
lambda x, y: jax.nn.one_hot(x, y.shape[0]),
arg_descriptors=[np.arange(16, dtype=_f32), RandArg((16,), _f32)],
arg_descriptors=[np.arange(16, dtype=_i32), RandArg((16,), _f32)],
polymorphic_shapes=["b, ...", "b, ..."]),
PolyHarness("ones", "",
lambda x: jnp.ones(x.shape, dtype=_f32) + x,
Expand Down
Loading