diff --git a/.github/workflows/bazel_cpu_rbe.yml b/.github/workflows/bazel_cpu_rbe.yml
new file mode 100644
index 000000000000..8862ddf05a19
--- /dev/null
+++ b/.github/workflows/bazel_cpu_rbe.yml
@@ -0,0 +1,54 @@
+name: CI - Bazel CPU tests (RBE)
+
+on:
+  pull_request:
+    branches:
+      - main
+  workflow_dispatch:
+    inputs:
+      # Forwarded to the "Wait For Connection" step below.
+      halt-for-connection:
+        description: 'Should this workflow run wait for a remote connection?'
+        type: choice
+        required: true
+        default: 'no'
+        options:
+          - 'yes'
+          - 'no'
+
+# The job only reads the repository, so restrict GITHUB_TOKEN to read access.
+permissions:
+  contents: read
+
+# One active run per workflow and PR branch (or ref); newer runs cancel older.
+concurrency:
+  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
+  cancel-in-progress: true
+
+jobs:
+  run_tests:
+    # Only run when the repository is not a fork.
+    if: github.event.repository.fork == false
+    strategy:
+      matrix:
+        runner: ["linux-x86-n2-64", "linux-arm64-t2a-48"]
+
+    runs-on: ${{ matrix.runner }}
+    # Select the container image from the runner name (x86 vs. arm64).
+    # TODO(b/369382309): Replace Linux Arm64 container with the ml-build container once it is available
+    container: ${{ (contains(matrix.runner, 'linux-x86') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/ml-build:latest') ||
+                   (contains(matrix.runner, 'linux-arm64') && 'us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/linux-arm64-arc-container:latest') }}
+
+    env:
+      JAXCI_HERMETIC_PYTHON_VERSION: "3.12"
+
+    steps:
+      # checkout@v3 runs on the deprecated Node 16 runtime; v4 uses Node 20.
+      - uses: actions/checkout@v4
+      - name: Wait For Connection
+        # NOTE(review): tracks a moving branch; consider pinning to a commit SHA.
+        uses: google-ml-infra/actions/ci_connection@main
+        with:
+          halt-dispatch-input: ${{ inputs.halt-for-connection }}
+      - name: Run Bazel CPU Tests with RBE
+        run: ./ci/run_bazel_test.sh "ci/envs/run_tests/bazel_cpu_rbe.env"
diff --git a/ci/README.md b/ci/README.md
new file mode 100644
index 000000000000..ea867df52f97
--- /dev/null
+++ b/ci/README.md
@@ -0,0 +1,10 @@
+# JAX continuous integration
+
+> [!WARNING]
+> This folder is still under construction. It is part of an ongoing
+> effort to improve the structure of CI and build related files within the
+> JAX repo. 
This warning will be removed when the contents of this
+> directory are stable and appropriate documentation around its usage is in
+> place.
+
+********************************************************************************
\ No newline at end of file
diff --git a/ci/envs/default.env b/ci/envs/default.env
new file mode 100644
index 000000000000..c89eefccbc55
--- /dev/null
+++ b/ci/envs/default.env
@@ -0,0 +1,43 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Default values for every "JAXCI_" environment variable read by the CI
+# scripts. The variables steer those scripts: which Python version to use,
+# where the JAX and XLA repositories live, whether XLA is cloned at HEAD,
+# and so on.
+
+# Root of the JAX git repository; taken from the current working directory.
+export JAXCI_JAX_GIT_DIR="$(pwd)"
+
+# Hermetic Python version ("major.minor"). When the caller has not set it,
+# fall back to the version reported by the system `python3`.
+export JAXCI_HERMETIC_PYTHON_VERSION=${JAXCI_HERMETIC_PYTHON_VERSION:-$(python3 -V | awk '{split($2, v, "."); print v[1] "." v[2]}')}
+
+# Point JAXCI_XLA_GIT_DIR at the root of a local XLA git checkout to build
+# against it rather than the version pinned in the WORKSPACE. It is filled in
+# automatically when JAXCI_CLONE_MAIN_XLA=1. 
+export JAXCI_XLA_GIT_DIR="${JAXCI_XLA_GIT_DIR:-}"
+
+# When 1, the builds clone the XLA repository at HEAD and store the path of
+# that clone in JAXCI_XLA_GIT_DIR.
+export JAXCI_CLONE_MAIN_XLA="${JAXCI_CLONE_MAIN_XLA:-0}"
+
+# Optional override for the XLA commit to use.
+export JAXCI_XLA_COMMIT="${JAXCI_XLA_COMMIT:-}"
+
+# #############################################################################
+# Switches that select which kinds of tests the test scripts run. Off (0)
+# here; the env files under ci/envs/run_tests turn them on.
+# #############################################################################
+export JAXCI_RUN_BAZEL_TEST_CPU_RBE=0
\ No newline at end of file
diff --git a/ci/envs/run_tests/bazel_cpu_rbe.env b/ci/envs/run_tests/bazel_cpu_rbe.env
new file mode 100644
index 000000000000..79847364b495
--- /dev/null
+++ b/ci/envs/run_tests/bazel_cpu_rbe.env
@@ -0,0 +1,22 @@
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Start from the default JAXCI environment variables.
+. ci/envs/default.env
+
+# Use XLA cloned at HEAD.
+export JAXCI_CLONE_MAIN_XLA=1
+
+# Turn on the Bazel CPU RBE tests.
+export JAXCI_RUN_BAZEL_TEST_CPU_RBE=1
\ No newline at end of file
diff --git a/ci/run_bazel_test.sh b/ci/run_bazel_test.sh
new file mode 100755
index 000000000000..3390baaf2924
--- /dev/null
+++ b/ci/run_bazel_test.sh
@@ -0,0 +1,53 @@
+#!/bin/bash
+# Copyright 2024 The JAX Authors. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Source "JAXCI_" environment variables.
+source "ci/utilities/source_jaxci_envs.sh" "$1"
+# Set up the build environment.
+source "ci/utilities/setup_build_environment.sh"
+
+# Run Bazel CPU tests with RBE.
+if [[ $JAXCI_RUN_BAZEL_TEST_CPU_RBE == 1 ]]; then
+  os=$(uname -s | awk '{print tolower($0)}')
+  arch=$(uname -m)
+
+  # When running on Mac or Linux Aarch64, we only build the test targets and
+  # not run them. These platforms do not have native RBE support so we
+  # RBE cross-compile them on remote Linux x86 machines. As the tests still
+  # need to be run on the host machine and because running the tests on a
+  # single machine can take a long time, we skip running them on these
+  # platforms.
+  if [[ $os == "darwin" || ( $os == "linux" && $arch == "aarch64" ) ]]; then
+    echo "Building RBE CPU tests..."
+    bazel build --config="rbe_cross_compile_${os}_${arch}" \
+      --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
+      --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
+      --test_env=JAX_NUM_GENERATED_CASES=25 \
+      --test_env=JAX_SKIP_SLOW_TESTS=true \
+      --action_env=JAX_ENABLE_X64=0 \
+      --test_output=errors \
+      //tests:cpu_tests //tests:backend_independent_tests
+  else
+    echo "Running RBE CPU tests..."
+    bazel test --config="rbe_${os}_${arch}" \
+      --repo_env=HERMETIC_PYTHON_VERSION="$JAXCI_HERMETIC_PYTHON_VERSION" \
+      --override_repository=xla="${JAXCI_XLA_GIT_DIR}" \
+      --test_env=JAX_NUM_GENERATED_CASES=25 \
+      --test_env=JAX_SKIP_SLOW_TESTS=true \
+      --action_env=JAX_ENABLE_X64=0 \
+      --test_output=errors \
+      //tests:cpu_tests //tests:backend_independent_tests
+  fi
+fi
diff --git a/ci/utilities/setup_build_environment.sh b/ci/utilities/setup_build_environment.sh
new file mode 100644
index 000000000000..6ade4c7b4893
--- /dev/null
+++ b/ci/utilities/setup_build_environment.sh
@@ -0,0 +1,82 @@
+#!/bin/bash
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Set up the build environment for JAX CI jobs. This script depends on the
+# environment variables sourced in `source_jaxci_envs.sh`.
+# -e: abort script if one command fails
+# -u: error if undefined variable used
+# -x: log all commands
+# -o pipefail: entire command fails if pipe fails. watch out for yes | ...
+# -o history: record shell history
+# -o allexport: export all functions and variables to be available to subscripts
+set -exuo pipefail -o history -o allexport
+
+# Pre-emptively mark the JAX git directory as safe. This is necessary for JAX CI
+# jobs running on Linux runners in GitHub Actions. Without this, git complains
+# that the directory has dubious ownership and refuses to run any commands.
+# Avoid running on Windows runners as git runs into issues with not being able
+# to lock the config file. Other git commands seem to work on the Windows
+# runners so we can skip this step for Windows.
+# TODO(b/375073267): Remove this once we understand why git repositories are
+# being marked as unsafe inside the self-hosted runners.
+if [[ ! $(uname -s) =~ "MSYS_NT" ]]; then
+  git config --global --add safe.directory "$JAXCI_JAX_GIT_DIR"
+fi
+
+# Makes a checkout of XLA at HEAD available in $(pwd)/xla and points
+# JAXCI_XLA_GIT_DIR at it. An existing $(pwd)/xla directory is reused instead
+# of being cloned again, so re-runs in the same workspace neither fail on
+# `git clone` nor leave JAXCI_XLA_GIT_DIR empty.
+function clone_main_xla() {
+  local xla_dir
+  xla_dir="$(pwd)/xla"
+  if [[ -d "$xla_dir" ]]; then
+    echo "Local XLA folder already exists at $xla_dir; using it instead of cloning."
+  else
+    echo "Cloning XLA at HEAD to $xla_dir"
+    git clone --depth=1 https://github.com/openxla/xla.git "$xla_dir"
+  fi
+  export JAXCI_XLA_GIT_DIR="$xla_dir"
+}
+
+# Clone XLA at HEAD if required.
+if [[ "$JAXCI_CLONE_MAIN_XLA" == 1 ]]; then
+  clone_main_xla
+fi
+
+# If a XLA commit is provided, check out XLA at that commit.
+if [[ -n "$JAXCI_XLA_COMMIT" ]]; then
+  # Clone XLA at HEAD if a path to local XLA is not provided.
+  if [[ -z "$JAXCI_XLA_GIT_DIR" ]]; then
+    clone_main_xla
+  fi
+  pushd "$JAXCI_XLA_GIT_DIR"
+
+  git fetch --depth=1 origin "$JAXCI_XLA_COMMIT"
+  echo "JAXCI_XLA_COMMIT is set. Checking out XLA at $JAXCI_XLA_COMMIT"
+  git checkout "$JAXCI_XLA_COMMIT"
+
+  popd
+fi
+
+if [[ -n "${JAXCI_XLA_GIT_DIR}" ]]; then
+  echo "INFO: Overriding XLA to be read from $JAXCI_XLA_GIT_DIR instead of the"
+  echo "pinned version in the WORKSPACE."
+  echo "If you would like to revert this behavior, unset JAXCI_CLONE_MAIN_XLA"
+  echo "and JAXCI_XLA_COMMIT in your environment. Note that the Bazel RBE test"
+  echo "commands overrides the XLA repository and thus require a local copy of"
+  echo "XLA to run."
+fi
diff --git a/ci/utilities/source_jaxci_envs.sh b/ci/utilities/source_jaxci_envs.sh
new file mode 100644
index 000000000000..51acc601eacd
--- /dev/null
+++ b/ci/utilities/source_jaxci_envs.sh
@@ -0,0 +1,38 @@
+#!/bin/bash
+# Copyright 2024 The JAX Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+#
+# Source "JAXCI_" environment variables.
+#
+# Usage: source ci/utilities/source_jaxci_envs.sh <path/to/file.env>
+
+# If a JAX CI env file has not been passed, exit. `${1:-}` keeps the check
+# working when the calling shell already has `set -u` enabled. Because this
+# file is sourced, `exit` ends the calling script as well.
+if [[ -z "${1:-}" ]]; then
+  echo "ERROR: No JAX CI env file passed."
+  echo "This script requires a path to a JAX CI env file as an argument."
+  echo "Please provide an env file from the ci/envs directory."
+  exit 1
+fi
+
+# -e: abort script if one command fails
+# -u: error if undefined variable used
+# -x: log all commands
+# -o pipefail: entire command fails if pipe fails. watch out for yes | ...
+# -o history: record shell history
+# -o allexport: export all functions and variables to be available to subscripts
+set -exuo pipefail -o history -o allexport
+source "$1"