From 879a5994feb9134638a3c5169a201f86eeb33349 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Wed, 11 Dec 2024 14:42:27 -0500 Subject: [PATCH 1/2] Simplify implementation of nn.relu. --- jax/_src/nn/functions.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 0bba9d8c6f63..758d9c6b3d29 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -25,7 +25,6 @@ import jax import jax.numpy as jnp -from jax import custom_jvp from jax import lax from jax._src import config from jax._src import core @@ -50,7 +49,6 @@ def __repr__(self): # activations -@custom_jvp @jax.jit def relu(x: ArrayLike) -> Array: r"""Rectified linear unit activation function. @@ -83,9 +81,7 @@ def relu(x: ArrayLike) -> Array: :func:`relu6` """ - return jnp.maximum(x, 0) -# For behavior at 0, see https://dl.acm.org/doi/10.5555/3540261.3540297 -relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0))) + return jnp.where(jnp.greater(x, 0), x, lax.zeros_like_array(x)) @jax.jit def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array: From 777e57758d5d48cb37d399e3991bb2bccede2085 Mon Sep 17 00:00:00 2001 From: carlosgmartin Date: Wed, 11 Dec 2024 15:08:21 -0500 Subject: [PATCH 2/2] Use lax.select and lax.ge instead of jnp.where and jnp.greater in nn.relu. --- jax/_src/nn/functions.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/_src/nn/functions.py b/jax/_src/nn/functions.py index 758d9c6b3d29..e3257c7c6688 100644 --- a/jax/_src/nn/functions.py +++ b/jax/_src/nn/functions.py @@ -81,7 +81,8 @@ def relu(x: ArrayLike) -> Array: :func:`relu6` """ - return jnp.where(jnp.greater(x, 0), x, lax.zeros_like_array(x)) + z = lax.zeros_like_array(x) + return lax.select(lax.ge(x, z), x, z) @jax.jit def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array: