-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add commands to run Bazel GPU (non-RBE) jobs
This commit adds the commands needed for running Bazel GPU (non-RBE) tests. These run two Bazel commands: Single accelerator tests with one GPU a piece and multi-accelerator tests with all GPUs PiperOrigin-RevId: 692383915
- Loading branch information
1 parent
6a124ac
commit f752ddb
Showing
10 changed files
with
436 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
name: CI - Bazel CPU tests (RBE) | ||
|
||
# TODO(srnitin): Do not submit without removing pull_request event. | ||
on: | ||
pull_request: | ||
branches: | ||
- main | ||
workflow_dispatch: | ||
inputs: | ||
halt-for-connection: | ||
description: 'Should this workflow run wait for a remote connection?' | ||
type: choice | ||
required: true | ||
default: 'no' | ||
options: | ||
- 'yes' | ||
- 'no' | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
run_tests: | ||
if: github.event.repository.fork == false | ||
strategy: | ||
matrix: | ||
runner: ["linux-x86-n2-16", "linux-arm64-t2a-16"] | ||
|
||
runs-on: ${{ matrix.runner }} | ||
# TODO(b/369382309): Replace Linux Arm64 container with the ml-build container once it is available | ||
container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') || | ||
(contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') }} | ||
|
||
env: | ||
JAXCI_HERMETIC_PYTHON_VERSION: "3.12" | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Wait For Connection | ||
uses: google-ml-infra/actions/ci_connection@main | ||
with: | ||
halt-dispatch-input: ${{ inputs.halt-for-connection }} | ||
- name: Run Bazel CPU Tests with RBE | ||
run: ./ci/run_bazel_test.sh "ci/envs/run_tests/bazel_cpu_rbe.env" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
name: CI - Bazel GPU tests (RBE) | ||
|
||
on: | ||
pull_request: | ||
branches: | ||
- main | ||
workflow_dispatch: | ||
inputs: | ||
halt-for-connection: | ||
description: 'Should this workflow run wait for a remote connection?' | ||
type: choice | ||
required: true | ||
default: 'no' | ||
options: | ||
- 'yes' | ||
- 'no' | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.head_ref || github.ref }} | ||
cancel-in-progress: true | ||
|
||
jobs: | ||
run_tests: | ||
if: github.event.repository.fork == false | ||
strategy: | ||
matrix: | ||
runner: ["linux-x86-g2-16-l4-1gpu"] | ||
|
||
runs-on: ${{ matrix.runner }} | ||
container: 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest' | ||
|
||
env: | ||
JAXCI_HERMETIC_PYTHON_VERSION: "3.12" | ||
|
||
steps: | ||
- uses: actions/checkout@v3 | ||
- name: Wait For Connection | ||
uses: google-ml-infra/actions/ci_connection@main | ||
with: | ||
halt-dispatch-input: ${{ inputs.halt-for-connection }} | ||
- name: Run Bazel GPU Tests with RBE | ||
run: ./ci/run_bazel_test.sh "ci/envs/run_tests/bazel_gpu_rbe.env" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# JAX continuous integration | ||
|
||
> [!WARNING] | ||
> This folder is still under construction. It is part of an ongoing | ||
> effort to improve the structure of CI and build related files within the | ||
> JAX repo. This warning will be removed when the contents of this | ||
> directory are stable and appropriate documentation around its usage is in | ||
> place. | ||
******************************************************************************** |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# Copyright 2024 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
# This file contains all the default values for the "JAXCI_" environment | ||
# variables used in the CI scripts. These variables are used to control the | ||
# behavior of the CI scripts such as the Python version used, path to JAX/XLA | ||
# repo, if to clone XLA repo, etc. | ||
|
||
# The path to the JAX git repository. | ||
export JAXCI_JAX_GIT_DIR=$(pwd) | ||
|
||
# Controls the version of Hermetic Python to use. Use system default if not | ||
# set. | ||
export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{print $2}' | awk -F. '{print $1"."$2}')} | ||
|
||
# Set JAXCI_XLA_GIT_DIR to the root of the XLA git repository to use a local | ||
# copy of XLA instead of the pinned version in the WORKSPACE. When | ||
# JAXCI_CLONE_MAIN_XLA=1, this gets set automatically. | ||
export JAXCI_XLA_GIT_DIR=${JAXCI_XLA_GIT_DIR:-} | ||
|
||
# If set to 1, the builds will clone the XLA repository at HEAD and set its | ||
# path in JAXCI_XLA_GIT_DIR. | ||
export JAXCI_CLONE_MAIN_XLA=${JAXCI_CLONE_MAIN_XLA:-0} | ||
|
||
# Allows overriding the XLA commit that is used. | ||
export JAXCI_XLA_COMMIT=${JAXCI_XLA_COMMIT:-} | ||
|
||
# ############################################################################# | ||
# Environment variables that control the type of tests that are run by the test | ||
# scripts. | ||
# ############################################################################# | ||
export JAXCI_RUN_BAZEL_TEST_CPU_RBE=0 | ||
export JAXCI_RUN_BAZEL_TEST_GPU_RBE=0 | ||
export JAXCI_RUN_BAZEL_TEST_GPU_NON_RBE=0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright 2024 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
# Inherit default JAXCI environment variables. | ||
source ci/envs/default.env | ||
|
||
# Enable Bazel CPU tests. | ||
export JAXCI_RUN_BAZEL_TEST_CPU_RBE=1 | ||
|
||
# Clone XLA at HEAD. | ||
export JAXCI_CLONE_MAIN_XLA=1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
# Copyright 2024 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
# Inherit default JAXCI environment variables. | ||
source ci/envs/default.env | ||
|
||
# Enable non-RBE Bazel GPU tests (single accelerator and multi-accelerator tests) | ||
export JAXCI_RUN_BAZEL_TEST_GPU_NON_RBE=1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Copyright 2024 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
# Inherit default JAXCI environment variables. | ||
source ci/envs/default.env | ||
|
||
# Enable Bazel GPU tests with RBE that runs single accelerator tests with | ||
# one GPU a piece. | ||
export JAXCI_RUN_BAZEL_TEST_GPU_RBE=1 | ||
|
||
# Clone XLA at HEAD. | ||
export JAXCI_CLONE_MAIN_XLA=1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
#!/bin/bash | ||
# Copyright 2024 The JAX Authors. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
# Source "JAXCI_" environment variables. | ||
source "ci/utilities/source_jaxci_envs.sh" "$1" | ||
# Set up the build environment. | ||
source "ci/utilities/setup_build_environment.sh" | ||
|
||
# Run Bazel CPU tests with RBE. | ||
if [[ $JAXCI_RUN_BAZEL_TEST_CPU_RBE == 1 ]]; then | ||
os=$(uname -s | awk '{print tolower($0)}') | ||
arch=$(uname -m) | ||
|
||
# When running on Mac or Linux Aarch64, we only build the test targets and | ||
# not run them. These platforms do not have native RBE support so we | ||
# RBE cross-compile them on remote Linux x86 machines. As the tests still | ||
# need to be run on the host machine and because running the tests on a | ||
# single machine can take a long time, we skip running them on these | ||
# platforms. | ||
if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ); then | ||
echo "Building RBE CPU tests..." | ||
bazel build --config=rbe_cross_compile_${os}_${arch} \ | ||
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ | ||
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ | ||
--test_env=JAX_NUM_GENERATED_CASES=25 \ | ||
--test_env=JAX_SKIP_SLOW_TESTS=true \ | ||
--action_env=JAX_ENABLE_X64=0 \ | ||
--test_output=errors \ | ||
--color=yes \ | ||
//tests:cpu_tests //tests:backend_independent_tests | ||
else | ||
echo "Running RBE CPU tests..." | ||
bazel test --config=rbe_${os}_${arch} \ | ||
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ | ||
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ | ||
--test_env=JAX_NUM_GENERATED_CASES=25 \ | ||
--test_env=JAX_SKIP_SLOW_TESTS=true \ | ||
--action_env=JAX_ENABLE_X64=0 \ | ||
--test_output=errors \ | ||
--color=yes \ | ||
//tests:cpu_tests //tests:backend_independent_tests | ||
fi | ||
fi | ||
|
||
# Run Bazel GPU tests with RBE. | ||
if [[ $JAXCI_RUN_BAZEL_TEST_GPU_RBE == 1 ]]; then | ||
nvidia-smi | ||
echo "Running RBE GPU tests..." | ||
|
||
# Only Linux x86 builds run GPU tests | ||
# Runs single accelerator tests with one GPU apiece. | ||
bazel test --config=rbe_linux_x86_64_cuda \ | ||
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ | ||
--override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ | ||
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ | ||
--test_output=errors \ | ||
--test_env=TF_CPP_MIN_LOG_LEVEL=0 \ | ||
--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ | ||
--test_tag_filters=-multiaccelerator \ | ||
--test_env=JAX_SKIP_SLOW_TESTS=true \ | ||
--action_env=JAX_ENABLE_X64=0 \ | ||
--color=yes \ | ||
//tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests | ||
fi | ||
|
||
# Run Non-RBE Bazel GPU tests (single accelerator and multiaccelerator tests). | ||
if [[ $JAXCI_RUN_BAZEL_TEST_GPU_NON_RBE == 1 ]]; then | ||
nvidia-smi | ||
echo "Running single accelerator tests (no RBE)..." | ||
|
||
# Runs single accelerator tests with one GPU apiece. | ||
# It appears --run_under needs an absolute path. | ||
# The product of the `JAX_ACCELERATOR_COUNT`` and `JAX_TESTS_PER_ACCELERATOR` | ||
# should match the VM's CPU core count (set in `--local_test_jobs`). | ||
bazel test --config=ci_linux_x86_64_cuda \ | ||
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ | ||
--//jax:build_jaxlib=false \ | ||
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ | ||
--run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ | ||
--test_output=errors \ | ||
--test_env=JAX_ACCELERATOR_COUNT=4 \ | ||
--test_env=JAX_TESTS_PER_ACCELERATOR=12 \ | ||
--local_test_jobs=48 \ | ||
--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ | ||
--test_tag_filters=-multiaccelerator \ | ||
--test_env=TF_CPP_MIN_LOG_LEVEL=0 \ | ||
--test_env=JAX_SKIP_SLOW_TESTS=true \ | ||
--action_env=JAX_ENABLE_X64=0 \ | ||
--action_env=NCCL_DEBUG=WARN \ | ||
--color=yes \ | ||
//tests:gpu_tests //tests:backend_independent_tests \ | ||
//tests/pallas:gpu_tests //tests/pallas:backend_independent_tests | ||
|
||
echo "Running multi-accelerator tests (no RBE)..." | ||
# Runs multiaccelerator tests with all GPUs. | ||
bazel test --config=ci_linux_x86_64_cuda \ | ||
--repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ | ||
--//jax:build_jaxlib=false \ | ||
--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ | ||
--test_output=errors \ | ||
--jobs=8 \ | ||
--test_tag_filters=multiaccelerator \ | ||
--test_env=TF_CPP_MIN_LOG_LEVEL=0 \ | ||
--test_env=JAX_SKIP_SLOW_TESTS=true \ | ||
--action_env=JAX_ENABLE_X64=0 \ | ||
--action_env=NCCL_DEBUG=WARN \ | ||
--color=yes \ | ||
//tests:gpu_tests //tests/pallas:gpu_tests | ||
fi |
Oops, something went wrong.