Skip to content

Commit

Permalink
Simplify implementation of nn.relu. Update paper link.
Browse files Browse the repository at this point in the history
  • Loading branch information
carlosgmartin committed Dec 11, 2024
1 parent ad00ee1 commit 107911e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
10 changes: 3 additions & 7 deletions jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@

import jax
import jax.numpy as jnp
from jax import custom_jvp
from jax import lax
from jax._src import config
from jax._src import core
Expand All @@ -50,7 +49,6 @@ def __repr__(self):

# activations

@custom_jvp
@jax.jit
def relu(x: ArrayLike) -> Array:
r"""Rectified linear unit activation function.
Expand All @@ -66,8 +64,8 @@ def relu(x: ArrayLike) -> Array:
\nabla \mathrm{relu}(0) = 0
For more information see
`Numerical influence of ReLU(0) on backpropagation
<https://openreview.net/forum?id=urrcVI-_jRm>`_.
`Numerical influence of ReLU'(0) on backpropagation
<https://dl.acm.org/doi/10.5555/3540261.3540297>`_.
Args:
x : input array
Expand All @@ -83,9 +81,7 @@ def relu(x: ArrayLike) -> Array:
:func:`relu6`
"""
return jnp.maximum(x, 0)
# For behavior at 0, see https://openreview.net/forum?id=urrcVI-_jRm
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
return jnp.where(jnp.greater(x, 0), x, 0)

@jax.jit
def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array:
Expand Down
5 changes: 5 additions & 0 deletions tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,11 @@ def testReluGrad(self):
jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)

def testReluGradAtZero(self):
# https://dl.acm.org/doi/10.5555/3540261.3540297
grad = jax.grad(nn.relu)(0.)
self.assertEqual(grad, 0.)

def testRelu6Grad(self):
rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
check_grads(nn.relu6, (1.,), order=3, rtol=rtol)
Expand Down

0 comments on commit 107911e

Please sign in to comment.