Update front-page readme with link to XLA flag doc #684

We will update this table as new models become available, so stay tuned.

## Environment Variables

The [JAX images](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) are embedded with the following XLA flags and environment variables for performance tuning:

| XLA Flag | Value | Explanation |
| -------- | ----- | ----------- |
| `--xla_gpu_enable_latency_hiding_scheduler` | `true` | Allows XLA to move communication collectives to increase overlap with compute kernels. |
| `--xla_gpu_enable_async_all_gather` | `true` | Allows XLA to run NCCL [AllGather](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html#allgather) kernels on a separate CUDA stream to allow overlap with compute kernels. |
| `--xla_gpu_enable_async_reduce_scatter` | `true` | Allows XLA to run NCCL [ReduceScatter](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/operations.html#reducescatter) kernels on a separate CUDA stream to allow overlap with compute kernels. |
| `--xla_gpu_enable_triton_gemm` | `false` | Uses cuBLAS instead of Triton GEMM kernels. |
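
XLA reads these flags from the single space-separated `XLA_FLAGS` environment variable, so they can be adjusted per run from the shell. A minimal sketch (the flag values shown are illustrative, and `train.py` is a placeholder for your own JAX program):

```sh
# All flags live in one space-separated XLA_FLAGS variable.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false"

# Placeholder for your own JAX program; XLA picks the flags up at startup.
python train.py
```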

The images also set the following environment variables:

| Environment Variable | Value | Explanation |
| -------------------- | ----- | ----------- |
| `CUDA_DEVICE_MAX_CONNECTIONS` | `1` | Uses a single queue for work sent to the GPU, which lowers the latency of stream operations; this is safe because XLA already orders kernel launches. |
| `NCCL_NVLS_ENABLE` | `0` | Disables NVLink SHARP ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. |
| `CUDA_MODULE_LOADING` | `EAGER` | Disables lazy module loading ([1](https://docs.nvidia.com/cuda/cuda-c-programming-guide/#cuda-environment-variables)), which uses slightly more GPU memory. |
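
Inside the published containers, both the XLA flags and these variables are ordinary image environment variables, so they can be overridden for a single run with `docker run -e`. Note that `-e XLA_FLAGS=...` replaces the embedded value wholesale rather than appending to it. A sketch (the `NCCL_NVLS_ENABLE=1` value is purely illustrative, not a recommendation):

```sh
# -e replaces the image's ENV value entirely; restate any defaults you
# still want if you override XLA_FLAGS this way.
docker run --rm --gpus all \
  -e NCCL_NVLS_ENABLE=1 \
  ghcr.io/nvidia/jax:jax \
  python -c "import jax; print(jax.devices())"
```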

To view the XLA flags currently set in a given image, you can inspect the container's environment variables:
```sh
# Update IMAGE to inspect a container of your choosing
IMAGE=ghcr.io/nvidia/jax:jax

docker run --rm quay.io/skopeo/stable inspect docker://$IMAGE | jq -r '.Env[]' | grep '^XLA_FLAGS='

# which returns

XLA_FLAGS= --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_async_all_gather=true --xla_gpu_enable_async_reduce_scatter=true --xla_gpu_enable_triton_gemm=false
```
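
If the image is already present locally, `docker image inspect` gives the same answer without skopeo (a sketch, equivalent to the command above):

```sh
# Read the embedded environment from the local image metadata;
# reuses the IMAGE variable set above.
docker image inspect $IMAGE \
  --format '{{range .Config.Env}}{{println .}}{{end}}' | grep '^XLA_FLAGS='
```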

See [GPU performance](./rosetta/docs/GPU_performance.md) for details about these and other XLA flags that enable high performance for LLMs on NVIDIA GPUs.

## Profiling JAX programs on GPU
See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.
