Adding the capability to pass custom arguments to the registration functions, and to call them via a custom module, for the standard losses in the example code. #175

Merged: 1 commit, Sep 19, 2023
Changes from all commits
examples/losses.py: 42 changes (35 additions, 7 deletions)
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Utility functions for computing and automatically registering losses."""
-from typing import Optional, Sequence, Tuple, Dict
+import types
+
+from typing import Any, Optional, Sequence, Tuple, Dict
 
 import haiku as hk
 import jax
@@ -50,15 +52,24 @@ def sigmoid_cross_entropy(
     labels: Array,
     weight: float = 1.0,
     register_loss: bool = True,
+    extra_registration_kwargs: Optional[Dict[str, Any]] = None,
+    registration_module: types.ModuleType = kfac_jax,
 ) -> Array:
   """Sigmoid cross-entropy loss."""
 
+  if extra_registration_kwargs is None:
+    extra_registration_kwargs = {}
+
   if register_loss:
-    kfac_jax.register_sigmoid_cross_entropy_loss(logits, labels, weight)
-  # Code is copied from Tensorflow.
+    registration_module.register_sigmoid_cross_entropy_loss(
+        logits, labels, weight, **extra_registration_kwargs)
+
+  # Code below is copied from Tensorflow:
 
   zeros = jnp.zeros_like(logits)
+
   relu_logits = jnp.where(logits >= zeros, logits, zeros)
   neg_abs_logits = jnp.where(logits >= zeros, -logits, logits)
+
   log_1p = jnp.log1p(jnp.exp(neg_abs_logits))
+
   return weight * jnp.add(relu_logits - logits * labels, log_1p)


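Usage sketch (not part of the diff): the point of the two new arguments is that registration can be redirected to any object exposing the same function names, with extra keyword arguments forwarded to it. Below, `types.SimpleNamespace` stands in for a custom module, and the wrapper with its `verbose` keyword is hypothetical, shown only to illustrate the forwarding; it assumes `sigmoid_cross_entropy` from this file is in scope.

import types

import jax.numpy as jnp
import kfac_jax


def verbose_sigmoid_registration(logits, labels, weight, verbose=False):
  # Hypothetical wrapper: optionally log, then delegate to kfac_jax.
  if verbose:
    print("Registering sigmoid cross-entropy loss with weight", weight)
  kfac_jax.register_sigmoid_cross_entropy_loss(logits, labels, weight)


# Duck-typed stand-in for a module: the loss only looks the function name
# up as an attribute on `registration_module`.
my_registrations = types.SimpleNamespace(
    register_sigmoid_cross_entropy_loss=verbose_sigmoid_registration)

logits = jnp.zeros([4])
labels = jnp.ones([4])

loss = sigmoid_cross_entropy(
    logits, labels,
    registration_module=my_registrations,
    extra_registration_kwargs={"verbose": True})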
@@ -68,21 +79,28 @@ def softmax_cross_entropy(
     weight: Numeric = 1.0,
     register_loss: bool = True,
     mask: Optional[Array] = None,
+    extra_registration_kwargs: Optional[Dict[str, Any]] = None,
+    registration_module: types.ModuleType = kfac_jax,
 ) -> Array:
   """Softmax cross entropy loss."""
 
+  if extra_registration_kwargs is None:
+    extra_registration_kwargs = {}
+
   if register_loss:
 
     if not isinstance(weight, float):
       raise NotImplementedError("Non-constant loss weights are not currently "
                                 "supported.")
 
     # Currently the registration functions only support 2D array inputs for
     # `logits`, so we need the reshapes below.
-    kfac_jax.register_softmax_cross_entropy_loss(
+    registration_module.register_softmax_cross_entropy_loss(
         logits.reshape([-1, logits.shape[-1]]),
         targets=labels.reshape([-1]),
         mask=mask.reshape([-1]) if mask is not None else None,
-        weight=weight)
+        weight=weight,
+        **extra_registration_kwargs)
 
   max_logits = jnp.max(logits, keepdims=True, axis=-1)

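A brief illustration of the reshape noted in the comment above (shapes are made up for the example): registration only accepts 2D `logits`, so e.g. sequence-model outputs of shape [batch, time, classes] are flattened to [batch * time, classes], with targets flattened to match.

import jax.numpy as jnp

logits_3d = jnp.zeros([8, 16, 10])                # [batch, time, classes]
targets_2d = jnp.zeros([8, 16], dtype=jnp.int32)  # [batch, time]

logits_2d = logits_3d.reshape([-1, logits_3d.shape[-1]])  # [128, 10]
targets_1d = targets_2d.reshape([-1])                     # [128]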
@@ -128,14 +146,20 @@ def squared_error(
     targets: Array,
     weight: float = 1.0,
     register_loss: bool = True,
+    extra_registration_kwargs: Optional[Dict[str, Any]] = None,
+    registration_module: types.ModuleType = kfac_jax,
 ) -> Array:
   """Squared error loss."""
 
+  if extra_registration_kwargs is None:
+    extra_registration_kwargs = {}
+
   if prediction.shape != targets.shape:
     raise ValueError("prediction and targets should have the same shape.")
 
   if register_loss:
-    kfac_jax.register_squared_error_loss(prediction, targets, weight)
+    registration_module.register_squared_error_loss(
+        prediction, targets, weight, **extra_registration_kwargs)
 
   return weight * jnp.sum(jnp.square(prediction - targets), axis=-1)

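Since the two new arguments default to `None` and `kfac_jax` respectively, existing call sites behave exactly as before; a minimal sketch, assuming `squared_error` from this file is in scope:

import jax.numpy as jnp

prediction = jnp.zeros([4, 3])
targets = jnp.ones([4, 3])

# Registration still goes through kfac_jax.register_squared_error_loss,
# with no extra keyword arguments forwarded.
loss = squared_error(prediction, targets)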
@@ -206,6 +230,8 @@ def classifier_loss_and_stats(
     register_loss: bool = True,
     mask: Optional[Array] = None,
     normalization_mode: str = "batch_size_only",
+    extra_registration_kwargs: Optional[Dict[str, Any]] = None,
+    registration_module: types.ModuleType = kfac_jax,
 ) -> Tuple[Array, Dict[str, Array]]:
   """Softmax cross-entropy with regularizer and accuracy statistics."""

@@ -238,7 +264,9 @@
   labels = add_label_smoothing(labels_as_int, label_smoothing, logits.shape[-1])
 
   softmax_loss = softmax_cross_entropy(
-      logits, labels, weight=weight, register_loss=register_loss, mask=mask)
+      logits, labels, weight=weight, register_loss=register_loss, mask=mask,
+      extra_registration_kwargs=extra_registration_kwargs,
+      registration_module=registration_module)
 
   averaged_raw_loss = jnp.sum(softmax_loss, axis=0) / batch_size

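Finally, the same pattern for the softmax path: `classifier_loss_and_stats` merely threads the two arguments through to `softmax_cross_entropy`, whose registration call passes `targets=`, `mask=`, and `weight=` plus any extra kwargs, so a custom registration function has to accept those. The wrapper and its `note` keyword are hypothetical, and the integer `labels` follow the 1D reshape the registration performs.

import types

import jax.numpy as jnp
import kfac_jax


def tagged_softmax_registration(logits, targets, mask, weight, note=""):
  # Hypothetical wrapper: `note` could be logged or stored; here we just
  # drop it and delegate to kfac_jax.
  del note
  kfac_jax.register_softmax_cross_entropy_loss(
      logits, targets=targets, mask=mask, weight=weight)


my_registrations = types.SimpleNamespace(
    register_softmax_cross_entropy_loss=tagged_softmax_registration)

logits = jnp.zeros([4, 10])
labels = jnp.zeros([4], dtype=jnp.int32)

loss = softmax_cross_entropy(
    logits, labels,
    registration_module=my_registrations,
    extra_registration_kwargs={"note": "train_loss"})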