Skip to content

Commit

Permalink
Update documentation and add setup.py pypi bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
Ruturaj4 authored and Github Actions committed Dec 12, 2024
1 parent e6d6c4e commit 99776f9
Show file tree
Hide file tree
Showing 6 changed files with 216 additions and 63 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ Some standouts:
| CPU | `pip install -U jax` |
| NVIDIA GPU | `pip install -U "jax[cuda12]"` |
| Google TPU | `pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html` |
| AMD GPU (Linux) | Use [Docker](https://hub.docker.com/r/rocm/jax-community/tags), [pre-built wheels](https://github.com/ROCm/jax/releases), or [build from source](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus). |
| AMD GPU (Linux) | Follow [AMD's instructions](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md). |
| Mac GPU | Follow [Apple's instructions](https://developer.apple.com/metal/jax/). |
| Intel GPU | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md). |

Expand Down
213 changes: 192 additions & 21 deletions build/rocm/README.md
Original file line number Diff line number Diff line change
@@ -1,38 +1,209 @@
# JAX Builds on ROCm
This directory contains files and setup instructions to build and test JAX for ROCm in Docker environment (runtime and CI). You can build, test and run JAX on ROCm yourself!
***
### Build JAX-ROCm in docker for the runtime
# JAX on ROCm
This directory provides setup instructions and necessary files to build, test, and run JAX with ROCm support in a Docker environment, suitable for both runtime and CI workflows. Explore the following methods to use or build JAX on ROCm!

1. Install Docker: Follow the [instructions on the docker website](https://docs.docker.com/engine/installation/).
## 1. Using Prebuilt Docker Images

2. Build a runtime JAX-ROCm docker container and keep this image by running the following command. Note: must pass in appropriate
options. The example below builds Python 3.12 container.
The ROCm JAX team provides prebuilt Docker images, which the simplest way to use JAX on ROCm. These images are available on Docker Hub and come with JAX configured for ROCm.

To pull the latest ROCm JAX Docker image, run:

```Bash
> docker pull rocm/jax-community:latest
```

Once the image is downloaded, launch a container using the following command:

```Bash
> docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir --name rocm_jax rocm/jax-community:latest /bin/bash

> docker attach rocm_jax
```

### Notes:
1. The `--shm-size` parameter allocates shared memory for the container. Adjust it based on your system's resources if needed.
2. Replace `$(pwd)` with the absolute path to the directory you want to mount inside the container.

***For older versions please review the periodically pushed docker images at:
[ROCm JAX Community DockerHub](https://hub.docker.com/r/rocm/jax-community/tags).***

### Testing your ROCm environment with JAX:

After launching the container, test whether JAX detects ROCm devices as expected:

```Bash
> python -c "import jax; print(jax.devices())"
[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
```

If the setup is successful, the output should list all available ROCm devices.

## 2. Using a ROCm Docker Image and Installing JAX

If you prefer to use the ROCm Ubuntu image or already have a ROCm Ubuntu container, follow these steps to install JAX in the container.

### Step 1: Pull the ROCm Ubuntu Docker Image

For example, use the following command to pull the ROCm Ubuntu image:

```Bash
> docker pull rocm/dev-ubuntu-22.04:6.3-complete
```

### Step 2: Launch the Docker Container

After pulling the image, launch a container using this command:

```Bash
> docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v $(pwd):/jax_dir --name rocm_jax rocm/dev-ubuntu-22.04:6.3-complete /bin/bash
> docker attach rocm_jax
```

### Step 3: Install the Latest Version of JAX

Inside the running container, install the required version of JAX with ROCm support using pip:

```Bash
> pip3 install jax[rocm]
```

### Step 4: Verify the Installed JAX Version

Check whether the correct version of JAX and its ROCm plugins are installed:

```Bash
> pip3 freeze | grep jax
jax==0.4.35
jax-rocm60-pjrt==0.4.35
jax-rocm60-plugin==0.4.35
jaxlib==0.4.35
```

### Step 5: Set the `LLVM_PATH` Environment Variable

Explicitly set the `LLVM_PATH` environment variable (This helps XLA find `ld.lld` in the PATH during runtime):

```Bash
> export LLVM_PATH=/opt/rocm/llvm
```

### Step 6: Verify the Installation of ROCm JAX

Run the following command to verify that ROCm JAX is installed correctly:

```Bash
./build/rocm/ci_build.sh --py_version 3.12
> python3 -c "import jax; print(jax.devices())"
[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]

> python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
[0 1 2 3 4]
```

3. To launch a JAX-ROCm container: If the build was successful, there should be a docker image with name "jax-rocm:latest" in list of docker images (use "docker images" command to list them).
## 3. Install JAX On Bare-metal or A Custom Container

Follow these steps if you prefer to install ROCm manually on your host system or in a custom container.

### Installing ROCm Libraries Manually

### Step 1: Install ROCm

Please follow [ROCm installation guide](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html) to install ROCm on your system.

Once installed, verify ROCm installation using:

```Bash
docker run -it -d --network=host --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size 64G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -v ./:/jax --name rocm_jax jax-rocm:latest /bin/bash
> rocm-smi

========================================== ROCm System Management Interface ==========================================
==================================================== Concise Info ====================================================
Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%
Name (20 chars) (Junction) (Socket) (Mem, Compute)
======================================================================================================================
0 [0x74a1 : 0x00] 50.0°C 170.0W NPS1, SPX 131Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
1 [0x74a1 : 0x00] 51.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
2 [0x74a1 : 0x00] 50.0°C 177.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
3 [0x74a1 : 0x00] 53.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
======================================================================================================================
================================================ End of ROCm SMI Log =================================================
```

***
### JAX ROCm Releases
We aim to push all ROCm-related changes to the OpenXLA repository. However, there may be times when certain JAX/jaxlib updates for
ROCm are not yet reflected in the upstream JAX repository. To address this, we maintain ROCm-specific JAX/jaxlib branches tied to JAX
releases. These branches are available in the ROCm fork of JAX at https://github.com/ROCm/jax. Look for branches named in the format
rocm-jaxlib-[jaxlib-version]. You can also find corresponding branches in https://github.com/ROCm/xla. For example, for JAX version
0.4.33, the branch is named rocm-jaxlib-v0.4.33, which can be accessed at https://github.com/ROCm/jax/tree/rocm-jaxlib-v0.4.33.
### Step 2: Install the Latest Version of JAX

JAX source-code and related wheels for ROCm are available here
Install the required version of JAX with ROCm support using pip:

```Bash
https://github.com/ROCm/jax/releases
> pip3 install jax[rocm]
```

***Note:*** Some earlier jaxlib versions on ROCm were released on ***PyPi***.
### Step 3: Verify the Installed JAX Version

Check whether the correct version of JAX and its ROCm plugins are installed:

```Bash
> pip3 freeze | grep jax
jax==0.4.35
jax-rocm60-pjrt==0.4.35
jax-rocm60-plugin==0.4.35
jaxlib==0.4.35
```
https://pypi.org/project/jaxlib-rocm/#history

### Step 4: Set the `LLVM_PATH` Environment Variable

Explicitly set the `LLVM_PATH` environment variable (This helps XLA find `ld.lld` in the PATH during runtime):

```Bash
> export LLVM_PATH=/opt/rocm/llvm
```

### Step 5: Verify the Installation of ROCm JAX

Run the following command to verify that ROCm JAX is installed correctly:

```Bash
> python3 -c "import jax; print(jax.devices())"
[RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]

> python3 -c "import jax.numpy as jnp; x = jnp.arange(5); print(x)"
[0 1 2 3 4]
```

## 4. Build ROCm JAX from Source

Follow these steps to build JAX with ROCm support from source:

### Step 1: Clone the Repository

Clone the ROCm-specific fork of JAX for the desired branch:

```Bash
> git clone https://github.com/ROCm/jax -b <branch_name>
> cd jax
```

### Step 2: Build the Wheels

Run the following command to build the necessary wheels:

```Bash
> python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt \
--rocm_version=60 --rocm_path=/opt/rocm-[version]
```

This will generate three wheels in the `dist/` directory:

* jaxlib (generic, device agnostic library)
* jax-rocm-plugin (ROCm-specific plugin)
* jax-rocm-pjrt (ROCm-specific runtime)

### Step 3: Then install custom JAX using:

```Bash
> python3 setup.py develop --user && pip3 -m pip install dist/*.whl
```

### Simplified Build Script

For a streamlined process, consider using the `jax/build/rocm/dev_build_rocm.py` script.
15 changes: 13 additions & 2 deletions build/rocm/tools/build_wheels.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,22 @@ def build_jaxlib_wheel(
jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"
):
use_clang = "true" if compiler == "clang" else "false"

# Avoid git warning by setting safe.directory.
try:
subprocess.run(
["git", "config", "--global", "--add", "safe.directory", "*"],
check=True,
)
except subprocess.CalledProcessError as e:
print(f"Failed to configure Git safe directory: {e}")
raise

cmd = [
"python",
"build/build.py",
"build"
"--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt"
"build",
"--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt",
"--rocm_path=%s" % rocm_path,
"--rocm_version=60",
"--use_clang=%s" % use_clang,
Expand Down
39 changes: 2 additions & 37 deletions docs/developer.md
Original file line number Diff line number Diff line change
Expand Up @@ -195,43 +195,8 @@ To build with debug information, add the flag `--bazel_options='--copt=/Z7'`.

### Additional notes for building a ROCM `jaxlib` for AMD GPUs

You need several ROCM/HIP libraries installed to build for ROCM. For
example, on a Ubuntu machine with
[AMD's `apt` repositories available](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html),
you need a number of packages installed:

```
sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs
```

The recommended way to install these dependencies is by running our script, `jax/build/rocm/tools/get_rocm.py`,
and selecting the appropriate options.

To build jaxlib with ROCM support, you can run the following build commands,
suitably adjusted for your paths and ROCM version.

```
python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt --rocm_version=60 --rocm_path=/opt/rocm-6.2.3
```
to generate three wheels (jaxlib without rocm, jax-rocm-plugin, and
jax-rocm-pjrt)

AMD's fork of the XLA repository may include fixes not present in the upstream
XLA repository. If you experience problems with the upstream repository, you can
try AMD's fork, by cloning their repository:

```
git clone https://github.com/ROCm/xla.git
```

and override the XLA repository with which JAX is built:

```
python3 ./build/build.py build --wheels=jax-rocm-plugin --rocm_version=60 --rocm_path=/opt/rocm-6.2.3 --local_xla_path=/rel/xla/
```

For a simplified installation process, we also recommend checking out the `jax/build/rocm/dev_build_rocm.py script`.
For detailed instructions on building `jaxlib` with ROCm support, refer to the official guide:
[Build ROCm JAX from Source](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md)

## Managing hermetic Python

Expand Down
4 changes: 2 additions & 2 deletions docs/installation.md
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,8 @@ refer to

JAX has experimental ROCm support. There are two ways to install JAX:

* Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax); or
* Build from source (refer to {ref}`building-from-source` — a section called _Additional notes for building a ROCM `jaxlib` for AMD GPUs_).
* Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax-community/tags); or
* Build from source. Refer to the section [Additional notes for building a ROCm jaxlib for AMD GPUs](https://jax.readthedocs.io/en/latest/developer.html#additional-notes-for-building-a-rocm-jaxlib-for-amd-gpus).

(install-intel-gpu)=
## Intel GPU
Expand Down
6 changes: 6 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,12 @@ def load_version_module(pkg_path):
f"jax-cuda12-plugin=={_current_jaxlib_version}",
],

# ROCm support for ROCm 6.0 and above.
'rocm': [
f"jaxlib=={_current_jaxlib_version}",
f"jax-rocm60-plugin>={_current_jaxlib_version},<={_jax_version}",
],

# For automatic bootstrapping distributed jobs in Kubernetes
'k8s': [
'kubernetes',
Expand Down

0 comments on commit 99776f9

Please sign in to comment.