Skip to content

Commit

Permalink
Minor bug fix: changing train_step in examples code to take a mean of…
Browse files Browse the repository at this point in the history
… the stats instead of taking from the first device. Because the optimizer syncs its own stats (like loss), this didn't matter except for stats returned from the kfac_jax optimizer (or Optax optimizers using OptaxWrapper). However, the Polyak averaged loss wasn't actually synced across devices (as its not part of the optimizer anymore), so "loss_polyak" was being reported only for the first device.

PiperOrigin-RevId: 700710272
  • Loading branch information
james-martens authored and KfacJaxDev committed Nov 27, 2024
1 parent 4de99f5 commit ac44bc6
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions examples/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,7 @@ def train_step(self, global_step: Array, rng: PRNGKey) -> dict[str, Numeric]:
for i in range(gathered_stat.shape[0]):
stats[f"{name}_{i}"] = jnp.array([gathered_stat[i]])

return kfac_jax.utils.get_first(stats)
return jax.tree_util.tree_map(functools.partial(jnp.mean, axis=0), stats)

# _
# _____ ____ _| |
Expand Down Expand Up @@ -774,26 +774,35 @@ def run_evaluation(
global_step, self._params, self._state, self._opt_state, key, batch)

if params_polyak is not None:

stats_no_polyak = stats

stats = self.eval_batch_pmap(
global_step, params_polyak, func_state_polyak, self._opt_state,
key, batch)

stats.update(
{k + "_no_polyak": v for k, v in stats_no_polyak.items()
if k != "data_seen"})

if params_schedule_free is not None:

stats_no_sf = stats

stats = self.eval_batch_pmap(
global_step, params_schedule_free, func_state_schedule_free,
self._opt_state, key, batch)

stats.update(
{k + "_no_sf": v for k, v in stats_no_sf.items()
if k != "data_seen"})

averaged_stats.add(stats, 1)

# Extract all stats
# Extract all stats.
# Note that MultiChunkAccumulator.value will perform a pmean
# automatically, so it's fine to call "get_first" here instead of taking
# the mean.
for k, v in averaged_stats.value.items(): # pytype: disable=attribute-error
all_stats[f"{name}_{k}"] = kfac_jax.utils.get_first(v)

Expand Down

0 comments on commit ac44bc6

Please sign in to comment.