
Commit 879a599
Simplify implementation of nn.relu.
carlosgmartin committed Dec 11, 2024
1 parent 6c45d31 commit 879a599
Showing 1 changed file with 1 addition and 5 deletions.
6 changes: 1 addition & 5 deletions jax/_src/nn/functions.py
@@ -25,7 +25,6 @@

 import jax
 import jax.numpy as jnp
-from jax import custom_jvp
 from jax import lax
 from jax._src import config
 from jax._src import core
@@ -50,7 +49,6 @@ def __repr__(self):

 # activations

-@custom_jvp
 @jax.jit
 def relu(x: ArrayLike) -> Array:
   r"""Rectified linear unit activation function.
@@ -83,9 +81,7 @@ def relu(x: ArrayLike) -> Array:
   :func:`relu6`
   """
-  return jnp.maximum(x, 0)
-# For behavior at 0, see https://dl.acm.org/doi/10.5555/3540261.3540297
-relu.defjvps(lambda g, ans, x: lax.select(x > 0, g, lax.full_like(g, 0)))
+  return jnp.where(jnp.greater(x, 0), x, lax.zeros_like_array(x))

 @jax.jit
 def squareplus(x: ArrayLike, b: ArrayLike = 4) -> Array:
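
For reference, a minimal sketch (not part of the commit) of why the custom_jvp rule can be dropped: differentiating a where-based relu already produces a gradient of 0 at x = 0, which is the behavior the removed defjvps rule enforced. The sketch uses the public jnp.zeros_like in place of the internal lax.zeros_like_array from the committed code.

import jax
import jax.numpy as jnp

def relu_simplified(x):
  # Same structure as the new implementation; jnp.zeros_like stands in for
  # the internal lax.zeros_like_array used in the committed code.
  return jnp.where(jnp.greater(x, 0), x, jnp.zeros_like(x))

# The gradient of jnp.where flows only through the selected branch, so the
# derivative at the kink (x = 0) comes out as 0 without a custom JVP rule.
print(jax.grad(relu_simplified)(-1.0))  # 0.0
print(jax.grad(relu_simplified)(0.0))   # 0.0
print(jax.grad(relu_simplified)(2.0))   # 1.0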
