Update references to JAX's GitHub repo
JAX has moved from https://github.com/google/jax to https://github.com/jax-ml/jax

PiperOrigin-RevId: 702886640
jakeharmon8 authored and KfacJaxDev committed Dec 5, 2024
1 parent aaf3064 commit c21505d
Showing 3 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -19,7 +19,7 @@ implementation of the [K-FAC] optimizer and curvature estimator.

KFAC-JAX is written in pure Python, but depends on C++ code via JAX.

-First, follow [these instructions](https://github.com/google/jax#installation)
+First, follow [these instructions](https://github.com/jax-ml/jax#installation)
to install JAX with the relevant accelerator support.

Then, install KFAC-JAX using pip:
@@ -219,6 +219,6 @@ and the year corresponds to the project's open-source release.


[K-FAC]: https://arxiv.org/abs/1503.05671
-[JAX]: https://github.com/google/jax
+[JAX]: https://github.com/jax-ml/jax
[Haiku]: https://github.com/google-deepmind/dm-haiku
[documentation]: https://kfac-jax.readthedocs.io/
4 changes: 2 additions & 2 deletions docs/index.rst
@@ -3,7 +3,7 @@
KFAC-JAX Documentation
======================

-KFAC-JAX is a library built on top of `JAX <https://github.com/google/jax>`_ for
+KFAC-JAX is a library built on top of `JAX <https://github.com/jax-ml/jax>`_ for
second-order optimization of neural networks and for computing scalable
curvature approximations.
The main goal of the library is to provide researchers with an easy-to-use
@@ -16,7 +16,7 @@ Installation

KFAC-JAX is written in pure Python, but depends on C++ code via JAX.

-First, follow `these instructions <https://github.com/google/jax#installation>`_
+First, follow `these instructions <https://github.com/jax-ml/jax#installation>`_
to install JAX with the relevant accelerator support.

Then, install KFAC-JAX using pip::
2 changes: 1 addition & 1 deletion examples/losses.py
@@ -104,7 +104,7 @@ def softmax_cross_entropy(
max_logits = jnp.max(logits, keepdims=True, axis=-1)

# It's unclear whether this stop_gradient is a good idea.
-  # See https://github.com/google/jax/issues/13529
+  # See https://github.com/jax-ml/jax/issues/13529
max_logits = lax.stop_gradient(max_logits)

logits = logits - max_logits
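The `stop_gradient` in this hunk is part of the standard log-sum-exp shift for numerically stable softmax cross-entropy. A minimal sketch of the idea in JAX (the function name, the plain log-softmax, and the sample values below are illustrative, not the actual code from `examples/losses.py`):

```python
import jax.numpy as jnp
from jax import lax


def stable_softmax_cross_entropy(logits, labels):
  # Subtract the per-row max before exponentiation so that jnp.exp
  # cannot overflow (the log-sum-exp trick). The shift cancels inside
  # the softmax, so the result is mathematically unchanged.
  max_logits = jnp.max(logits, keepdims=True, axis=-1)
  # Treat the max as a constant in the backward pass; this is the
  # stop_gradient whose benefit the linked issue discusses.
  max_logits = lax.stop_gradient(max_logits)
  logits = logits - max_logits
  log_probs = logits - jnp.log(
      jnp.sum(jnp.exp(logits), axis=-1, keepdims=True))
  return -jnp.sum(labels * log_probs, axis=-1)


# Without the shift, jnp.exp(1002.0) overflows to inf in float32.
logits = jnp.array([[1000.0, 1001.0, 1002.0]])
labels = jnp.array([[0.0, 0.0, 1.0]])
loss = stable_softmax_cross_entropy(logits, labels)
```

Whether stopping the gradient through the max helps or hurts is, as the comment above notes, an open question; the forward value is identical either way.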
Expand Down
