Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Minor bug fix: changing train_step in examples code to take a mean of…
… the stats instead of taking from the first device. Because the optimizer syncs its own stats (like loss), this didn't matter except for stats returned from the kfac_jax optimizer (or Optax optimizers using OptaxWrapper). However, the Polyak averaged loss wasn't actually synced across devices (as its not part of the optimizer anymore), so "loss_polyak" was being reported only for the first device. PiperOrigin-RevId: 700710272
- Loading branch information