
Minor bug fix: change train_step in the examples code to take the mean of the stats across devices instead of taking them from the first device. Because the optimizer syncs its own stats (like loss), this made no difference for stats returned from the kfac_jax optimizer (or Optax optimizers using OptaxWrapper). However, the Polyak-averaged loss wasn't actually synced across devices (as it's no longer part of the optimizer), so "loss_polyak" was being reported for the first device only. #1525
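The difference between the two reporting strategies can be sketched without any accelerator. This is a minimal, dependency-free illustration, not the actual train_step code: the names `per_device_stats`, `first_device`, and `device_mean` are hypothetical, the numbers are made up, and plain Python lists stand in for arrays with a leading device axis.

```python
from statistics import fmean

# Hypothetical per-device stats: each value holds one entry per device.
# "loss" is assumed to be synced already, so every device agrees;
# "loss_polyak" is assumed to be unsynced, so the devices differ.
per_device_stats = {
    "loss": [0.5, 0.5, 0.5, 0.5],
    "loss_polyak": [0.25, 0.5, 0.75, 1.0],
}


def first_device(stats):
    """Old behaviour: report only the first device's value."""
    return {name: values[0] for name, values in stats.items()}


def device_mean(stats):
    """Fixed behaviour: average each stat over all devices."""
    return {name: fmean(values) for name, values in stats.items()}


old_report = first_device(per_device_stats)
new_report = device_mean(per_device_stats)
```

For a stat that is already synced, both strategies return the same value, which is why the issue only showed up for an unsynced stat such as "loss_polyak".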


Workflow file for this run

name: ci
on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]
jobs:
  build-and-test:
    name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
    runs-on: "${{ matrix.os }}"
    strategy:
      matrix:
        python-version: ["3.10", "3.11"]
        os: [ubuntu-latest]
    steps:
      - uses: "actions/checkout@v3"
      - uses: "actions/setup-python@v4"
        with:
          python-version: "${{ matrix.python-version }}"
          cache: "pip"
          cache-dependency-path: 'pyproject.toml'
      - name: Run CI tests
        run: bash test.sh
        shell: bash