Skip to content

Commit

Permalink
Change CenterNet reg and heatmap loss to be compatible with both pmap…
Browse files Browse the repository at this point in the history
… and jit sharding.

PiperOrigin-RevId: 701813674
  • Loading branch information
Anthony Sherbondy authored and Scenic Authors committed Dec 2, 2024
1 parent 0340172 commit 96db8b0
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions scenic/projects/baselines/centernet/modeling/centernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,15 +214,17 @@ def heatmap_focal_loss(self, heatmaps, gt_heatmaps):
pos_w = (gt_heatmaps == 1.).astype(jnp.float32) # B x m x C
pos_loss = jnp.log(pred) * jnp.power(1 - pred, self.focal_gamma) * pos_w
neg_loss = jnp.log(1. - pred) * jnp.power(pred, self.focal_gamma) * neg_w
norm = jnp.maximum(pos_w.sum(), 1.) # scalar
bs = pos_w.shape[0]
norm = jnp.maximum(pos_w.reshape((bs, -1)).sum(1), 1.0)
norm = jnp.mean(norm) # scalar
if self.sync_device_norm: # sync across GPUs. Helpful for small batch size.
norm = jax.lax.pmean(norm, axis_name='batch')
pos_loss = pos_loss.sum() / norm # scalar
neg_loss = neg_loss.sum() / norm # scalar
norm = jax.lax.pmean(norm, axis_name='batch') # scalar
pos_loss = jnp.mean(pos_loss.reshape((bs, -1)).sum(1)) / norm # scalar
neg_loss = jnp.mean(neg_loss.reshape((bs, -1)).sum(1)) / norm # scalar
if self.focal_alpha >= 0:
pos_loss = self.focal_alpha * pos_loss
neg_loss = (1. - self.focal_alpha) * neg_loss
return - pos_loss, - neg_loss, norm / heatmaps.shape[0]
return -pos_loss, -neg_loss, norm

def reg_loss(self, box_regs, gt_regs):
"""Compute regression loss.
Expand All @@ -242,11 +244,12 @@ def reg_loss(self, box_regs, gt_regs):
"""
reg_inds = gt_regs.max(axis=2) >= 0 # B x m: find valid pixels.
gious = centernet_utils.giou_loss(box_regs, gt_regs) # B x m
norm = jnp.maximum(reg_inds.sum(), 1.) # scalar
norm = jnp.maximum(reg_inds.sum(1), 1.0)
norm = jnp.mean(norm) # scalar
if self.sync_device_norm:
norm = jax.lax.pmean(norm, axis_name='batch')
reg_loss = (gious * reg_inds).sum() / norm # scalar
return reg_loss, gious, norm / reg_inds.shape[0]
reg_loss = jnp.mean((gious * reg_inds).sum(1)) / norm # scalar
return reg_loss, gious, norm

def _get_bbox_ltrb(self, grids, boxes, m, n):
"""generate FCOS style regression targets.
Expand Down

0 comments on commit 96db8b0

Please sign in to comment.