Skip to content

Commit

Permalink
Merge pull request #25374 from traversaro:patch-1
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 704673954
  • Loading branch information
Google-ML-Automation committed Dec 10, 2024
2 parents 90de28c + 09309e6 commit 8e7aaa7
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,18 +253,14 @@ simply run:
conda install jax -c conda-forge
```

To install it on a machine with an NVIDIA GPU, run:
If you run this command on machine with an NVIDIA GPU, this should install a CUDA-enabled package of `jaxlib`.

To ensure that the jax version you are installing is indeed CUDA-enabled, run:

```bash
conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
conda install "jaxlib=*=*cuda*" jax -c conda-forge
```

Note the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
JAX requires. You must therefore either install the `cuda-nvcc` package from
the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
is in your path. The channel order above is important (`conda-forge` before
`nvidia`).

If you would like to override which release of CUDA is used by JAX, or to
install the CUDA build on a machine without GPUs, follow the instructions in the
[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch)
Expand Down

0 comments on commit 8e7aaa7

Please sign in to comment.