Skip to content

Commit

Permalink
jax.nn.one_hot: deprecate non-integer inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 18, 2024
1 parent 74eca13 commit d5768df
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 11 deletions.
1 change: 1 addition & 0 deletions jax/_src/deprecations.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def warn(deprecation_id: str, message: str, stacklevel: int) -> None:
# always registered by the time `accelerate` and `is_acelerated` are called.
register('jax-aval-named-shape')
register('jax-dlpack-import-legacy')
register('jax-nn-one-hot-float-input')
register("jax-numpy-astype-complex-to-real")
register("jax-numpy-array-none")
register('jax-numpy-clip-args')
Expand Down
28 changes: 18 additions & 10 deletions jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from jax import lax
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import dtypes
from jax._src import util
from jax._src.core import AxisName
Expand Down Expand Up @@ -645,34 +646,33 @@ def standardize(x: ArrayLike,

# TODO(slebedev): Change the type of `x` to `ArrayLike`.
@partial(jax.jit, static_argnames=("num_classes", "dtype", "axis"))
def _one_hot(x: Any, num_classes: int, *,
def _one_hot(x: Array, num_classes: int, *,
dtype: Any, axis: int | AxisName) -> Array:
num_classes = core.concrete_dim_or_error(
num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
dtype = dtypes.canonicalize_dtype(dtype)
x_arr = jnp.asarray(x)
try:
output_pos_axis = util.canonicalize_axis(axis, x_arr.ndim + 1)
output_pos_axis = util.canonicalize_axis(axis, x.ndim + 1)
except TypeError:
axis_size = lax.psum(1, axis)
if num_classes != axis_size:
raise ValueError(f"Expected num_classes to match the size of axis {axis}, "
f"but {num_classes} != {axis_size}") from None
axis_idx = lax.axis_index(axis)
return jnp.asarray(x_arr == axis_idx, dtype=dtype)
return jnp.asarray(_dot_product_attention_xla == axis_idx, dtype=dtype)
axis = operator.index(axis) # type: ignore[arg-type]
lhs = lax.expand_dims(x_arr, (axis,))
rhs_shape = [1] * x_arr.ndim
lhs = lax.expand_dims(x, (axis,))
rhs_shape = [1] * x.ndim
rhs_shape.insert(output_pos_axis, num_classes)
if config.sharding_in_types.value:
# TODO(yashkatariya): Maybe expose `out_sharding` on `one_hot` too?
rhs_sharding = NamedSharding(x_arr.sharding.mesh, P(*[None] * len(rhs_shape)))
rhs_sharding = NamedSharding(x.sharding.mesh, P(*[None] * len(rhs_shape)))
else:
rhs_sharding = None
rhs = lax.broadcasted_iota(x_arr.dtype, rhs_shape, output_pos_axis,
rhs = lax.broadcasted_iota(x.dtype, rhs_shape, output_pos_axis,
_sharding=rhs_sharding)
return jnp.asarray(lhs == rhs, dtype=dtype)
return (lhs == rhs).astype(dtype)

# TODO(slebedev): Change the type of `x` to `ArrayLike`.
def one_hot(x: Any, num_classes: int, *,
Expand Down Expand Up @@ -703,7 +703,15 @@ def one_hot(x: Any, num_classes: int, *,
num_classes = core.concrete_dim_or_error(
num_classes,
"The error arose in jax.nn.one_hot argument `num_classes`.")
return _one_hot(x, num_classes, dtype=dtype, axis=axis)
x_arr = jnp.asarray(x)
if not jnp.isdtype(x_arr.dtype, "integral"):
# Deprecated 2024-12-18
deprecations.warn(
'jax-nn-one-hot-float-input',
f"jax.nn.one_hot input should be integer-typed; got dtype={x_arr.dtype}",
stacklevel=1)
x_arr = x_arr.astype('int32')
return _one_hot(x_arr, num_classes, dtype=dtype, axis=axis)


@jax.custom_jvp
Expand Down
12 changes: 11 additions & 1 deletion tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@

import scipy.stats

from jax._src import ad_checkpoint
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import test_util as jtu
from jax._src import ad_checkpoint
from jax._src.interpreters import mlir
from jax._src.lib import cuda_versions
from jax.test_util import check_grads
Expand Down Expand Up @@ -530,6 +531,15 @@ def testOneHotAxis(self):
actual = nn.one_hot(jnp.array([1, 2, 0]), 3, axis=-2)
self.assertAllClose(actual, expected, check_dtypes=False)

def testOneHotNonInteger(self):
def assert_warns_or_errors(msg):
if deprecations.is_accelerated("jax-nn-one-hot-float-input"):
return self.assertRaisesRegex(ValueError, msg)
else:
return self.assertWarnsRegex(DeprecationWarning, msg)
with assert_warns_or_errors("jax.nn.one_hot input should be integer-typed"):
nn.one_hot(jnp.array([1.0]), 3)

def testTanhExists(self):
nn.tanh # doesn't crash

Expand Down

0 comments on commit d5768df

Please sign in to comment.