Skip to content

Commit

Permalink
Merge pull request #25392 from carlosgmartin:add_nn_relu_grad_at_zero…
Browse files Browse the repository at this point in the history
…_test_update_paper_link

PiperOrigin-RevId: 704947960
  • Loading branch information
Google-ML-Automation committed Dec 11, 2024
2 parents 41f490a + 0880114 commit cfdac00
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/_src/nn/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def relu(x: ArrayLike) -> Array:
For more information see
`Numerical influence of ReLU’(0) on backpropagation
<https://openreview.net/forum?id=urrcVI-_jRm>`_.
<https://dl.acm.org/doi/10.5555/3540261.3540297>`_.
Args:
x : input array
Expand All @@ -84,7 +84,7 @@ def relu(x: ArrayLike) -> Array:
"""
return jnp.maximum(x, 0)
# For behavior at 0, see https://openreview.net/forum?id=urrcVI-_jRm
# For behavior at 0, see https://dl.acm.org/doi/10.5555/3540261.3540297
relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))

@jax.jit
Expand Down
5 changes: 5 additions & 0 deletions tests/nn_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,11 @@ def testReluGrad(self):
jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)

def testReluGradAtZero(self):
# https://dl.acm.org/doi/10.5555/3540261.3540297
grad = jax.grad(nn.relu)(0.)
self.assertEqual(grad, 0.)

def testRelu6Grad(self):
rtol = 1e-2 if jtu.test_device_matches(["tpu"]) else None
check_grads(nn.relu6, (1.,), order=3, rtol=rtol)
Expand Down

0 comments on commit cfdac00

Please sign in to comment.