Minor bug fix: change `train_step` in the examples code to take the mean of the stats across devices instead of taking them from the first device. Because the optimizer syncs its own stats (such as the loss), this made no difference for stats returned from the kfac_jax optimizer (or from Optax optimizers wrapped with OptaxWrapper). However, the Polyak-averaged loss was not actually synced across devices (since it is no longer part of the optimizer), so "loss_polyak" was being reported for the first device only. #1526
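The difference between the old and new behaviour can be sketched in JAX. This is a minimal illustration, not the actual examples code: it assumes the stats come back from a `jax.pmap`-ed `train_step` as a dict whose leaves carry a leading device axis. The numbers and the helper names `first_device` and `mean_over_devices` are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical per-device stats, shaped like the output of a pmapped
# train_step: every leaf has a leading axis of size num_devices (4 here).
# "loss" is synced by the optimizer, so it is identical on every device.
# "loss_polyak" is computed outside the optimizer, so it differs per device.
per_device_stats = {
    "loss": jnp.array([0.5, 0.5, 0.5, 0.5]),
    "loss_polyak": jnp.array([0.4, 0.6, 0.5, 0.7]),
}


def first_device(stats):
    """Old behaviour: keep only the value reported by device 0."""
    return jax.tree_util.tree_map(lambda x: x[0], stats)


def mean_over_devices(stats):
    """New behaviour: average every stat over the leading device axis."""
    return jax.tree_util.tree_map(lambda x: jnp.mean(x, axis=0), stats)


old = first_device(per_device_stats)
new = mean_over_devices(per_device_stats)
print({k: float(v) for k, v in old.items()})
print({k: float(v) for k, v in new.items()})
```

Because `loss` is already identical on every device, both reductions agree on it. Only the unsynced `loss_polyak` changes, which is exactly the case this fix targets.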
Workflow file for this run
```yaml
name: ci

on:
  push:
    branches: ["main"]
  pull_request:
    branches: ["main"]

jobs:
  build-and-test:
    name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}"
    runs-on: "${{ matrix.os }}"
    strategy:
      matrix:
        python-version: ["3.10", "3.11"]
        os: [ubuntu-latest]
    steps:
      - uses: "actions/checkout@v3"
      - uses: "actions/setup-python@v4"
        with:
          python-version: "${{ matrix.python-version }}"
          cache: "pip"
          cache-dependency-path: 'pyproject.toml'
      - name: Run CI tests
        run: bash test.sh
        shell: bash
```
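The `strategy.matrix` block makes GitHub Actions run one job per combination of matrix values. The Python sketch below mimics that expansion for this workflow (two Python versions times one OS, so two jobs). It is only an illustration of the cartesian-product rule under the assumption that no `include` or `exclude` keys are used; it is not how GitHub Actions itself implements matrices, and `expand_matrix` is a hypothetical helper.

```python
import itertools

# Matrix values copied from the workflow's strategy.matrix block.
matrix = {
    "python-version": ["3.10", "3.11"],
    "os": ["ubuntu-latest"],
}


def expand_matrix(matrix):
    """Return one dict per job: the cartesian product of all matrix values.

    Hypothetical helper; ignores `include`/`exclude`, which this workflow
    does not use.
    """
    keys = list(matrix)
    return [dict(zip(keys, combo)) for combo in itertools.product(*matrix.values())]


jobs = expand_matrix(matrix)
for job in jobs:
    # Mirrors the job's `name:` template in the workflow.
    print(f"Python {job['python-version']} on {job['os']}")
```

Adding another entry to `os` or `python-version` multiplies the number of jobs accordingly, which is why the job `name` includes both matrix values.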