diff --git a/ci/envs/default.env b/ci/envs/default.env index f50b7549b823..e3bf1a5ab47a 100644 --- a/ci/envs/default.env +++ b/ci/envs/default.env @@ -42,4 +42,18 @@ export JAXCI_OUTPUT_DIR="$(pwd)/dist" # When enabled, artifacts will be built with RBE. Requires gcloud authentication # and only certain platforms support RBE. Therefore, this flag is enabled only # for CI builds where RBE is supported. -export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} \ No newline at end of file +export JAXCI_BUILD_ARTIFACT_WITH_RBE=${JAXCI_BUILD_ARTIFACT_WITH_RBE:-0} + +# ############################################################################# +# Test script specific environment variables. +# ############################################################################# +# The maximum number of tests to run per GPU when running single accelerator +# tests with parallel execution with Bazel. The GPU limit is set because we +# need to allow about 2GB of GPU RAM per test. Default is set to 12 because we +# use L4 machines which have 24GB of RAM but can be overriden if we use a +# different GPU type. +export JAXCI_MAX_TESTS_PER_GPU=${JAXCI_MAX_TESTS_PER_GPU:-12} + +# Sets the value of `JAX_ENABLE_X64` in the test scripts. CI builds override +# this value in the Github action workflow files. +export JAXCI_ENABLE_X64=${JAXCI_ENABLE_X64:-0} diff --git a/ci/run_bazel_test_cpu_rbe.sh b/ci/run_bazel_test_cpu_rbe.sh index 6ba9f6dce239..248111e0247a 100755 --- a/ci/run_bazel_test_cpu_rbe.sh +++ b/ci/run_bazel_test_cpu_rbe.sh @@ -50,7 +50,7 @@ if [[ $os == "darwin" ]] || ( [[ $os == "linux" ]] && [[ $arch == "aarch64" ]] ) --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ --test_env=JAX_NUM_GENERATED_CASES=25 \ --test_env=JAX_SKIP_SLOW_TESTS=true \ - --action_env=JAX_ENABLE_X64=0 \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ //tests:cpu_tests //tests:backend_independent_tests @@ -61,7 +61,7 @@ else --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \ --test_env=JAX_NUM_GENERATED_CASES=25 \ --test_env=JAX_SKIP_SLOW_TESTS=true \ - --action_env=JAX_ENABLE_X64=0 \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --test_output=errors \ --color=yes \ //tests:cpu_tests //tests:backend_independent_tests diff --git a/ci/run_bazel_test_gpu_non_rbe.sh b/ci/run_bazel_test_gpu_non_rbe.sh new file mode 100755 index 000000000000..7828cf41c60e --- /dev/null +++ b/ci/run_bazel_test_gpu_non_rbe.sh @@ -0,0 +1,87 @@ +#!/bin/bash +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +# Run Bazel GPU tests without RBE. This runs two commands: single accelerator +# tests with one GPU a piece, multiaccelerator tests with all GPUS. +# Requires that jaxlib, jax-cuda-plugin, and jax-cuda-pjrt wheels are stored +# inside the ../dist folder +# +# -e: abort script if one command fails +# -u: error if undefined variable used +# -x: log all commands +# -o history: record shell history +# -o allexport: export all functions and variables to be available to subscripts +set -exu -o history -o allexport + +# Source default JAXCI environment variables. +source ci/envs/default.env + +# Set up the build environment. +source "ci/utilities/setup_build_environment.sh" + +# Run Bazel GPU tests (single accelerator and multiaccelerator tests) directly +# on the VM without RBE. +nvidia-smi +echo "Running single accelerator tests (without RBE)..." + +# Set up test environment variables. +export gpu_count=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l) +export num_test_jobs=$((gpu_count * JAXCI_MAX_TESTS_PER_GPU)) +export num_cpu_cores=$(nproc) + +# tests_jobs = max(gpu_count * max_tests_per_gpu, num_cpu_cores) +if [[ $num_test_jobs -gt $num_cpu_cores ]]; then + num_test_jobs=$num_cpu_cores +fi +# End of test environment variables setup. + +# Runs single accelerator tests with one GPU apiece. +# It appears --run_under needs an absolute path. +# The product of the `JAX_ACCELERATOR_COUNT`` and `JAX_TESTS_PER_ACCELERATOR` +# should match the VM's CPU core count (set in `--local_test_jobs`). +bazel test --config=ci_linux_x86_64_cuda \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --//jax:build_jaxlib=false \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --run_under "$(pwd)/build/parallel_accelerator_execute.sh" \ + --test_output=errors \ + --test_env=JAX_ACCELERATOR_COUNT=$gpu_count \ + --test_env=JAX_TESTS_PER_ACCELERATOR=$JAXCI_MAX_TESTS_PER_GPU \ + --local_test_jobs=$num_test_jobs \ + --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ + --test_tag_filters=-multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --action_env=NCCL_DEBUG=WARN \ + --color=yes \ + //tests:gpu_tests //tests:backend_independent_tests \ + //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests + +echo "Running multi-accelerator tests (without RBE)..." +# Runs multiaccelerator tests with all GPUs directly on the VM without RBE.. +bazel test --config=ci_linux_x86_64_cuda \ + --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \ + --//jax:build_jaxlib=false \ + --test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform \ + --test_output=errors \ + --jobs=8 \ + --test_tag_filters=multiaccelerator \ + --test_env=TF_CPP_MIN_LOG_LEVEL=0 \ + --test_env=JAX_SKIP_SLOW_TESTS=true \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ + --action_env=NCCL_DEBUG=WARN \ + --color=yes \ + //tests:gpu_tests //tests/pallas:gpu_tests \ No newline at end of file diff --git a/ci/run_bazel_test_gpu_rbe.sh b/ci/run_bazel_test_gpu_rbe.sh index 0c004c584300..17bd8d9db4f8 100755 --- a/ci/run_bazel_test_gpu_rbe.sh +++ b/ci/run_bazel_test_gpu_rbe.sh @@ -46,6 +46,6 @@ bazel test --config=rbe_linux_x86_64_cuda \ --test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow \ --test_tag_filters=-multiaccelerator \ --test_env=JAX_SKIP_SLOW_TESTS=true \ - --action_env=JAX_ENABLE_X64=0 \ + --action_env=JAX_ENABLE_X64="$JAXCI_ENABLE_X64" \ --color=yes \ //tests:gpu_tests //tests:backend_independent_tests //tests/pallas:gpu_tests //tests/pallas:backend_independent_tests \ No newline at end of file