Skip to content

Add JetStream to MaxText container #321

Add JetStream to MaxText container

Add JetStream to MaxText container #321

Workflow file for this run

# Static analysis (mypy, ruff) and smoke tests for the pure-Python nsys-jax tooling
# that lives under .github/container.
name: nsys-jax pure-Python CI

# Pull-request runs share one group per head branch; any other event falls back to
# the unique run ID and therefore never shares a group. In-progress runs are
# cancelled by a newer run in the same group everywhere except on main.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

on:
  pull_request:
    types:
      - opened
      - reopened
      - ready_for_review
      - synchronize
    # Documentation-only changes do not need to trigger this workflow
    paths-ignore:
      - '**.md'
  push:
    branches:
      - main

# Least privilege for the automatic GITHUB_TOKEN: the jobs only need to read the
# repository. Gist uploads authenticate with a separate secret instead.
permissions:
  contents: read

env:
  # Files and directories handed to mypy and ruff, one path per line. The paths are
  # relative to the workspace and assume the repository is checked out under
  # JAX-Toolbox/. The variable is expanded unquoted so the shell splits it on
  # whitespace into separate arguments.
  NSYS_JAX_PYTHON_FILES: |
    JAX-Toolbox/.github/container/nsys-jax
    JAX-Toolbox/.github/container/nsys-jax-combine
    JAX-Toolbox/.github/container/jax_nsys
    JAX-Toolbox/.github/container/jax-nccl-test
jobs:
  # Type-check everything listed in NSYS_JAX_PYTHON_FILES, including the notebooks
  # (converted to scripts first) and using stubs generated from XLA's .proto files.
  mypy:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          path: JAX-Toolbox
          sparse-checkout: |
            .github/container
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      # jax is just a CPU-only build of the latest release for type-checking purposes
      - name: "Install jax / jax_nsys / mypy"
        run: pip install jax -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys matplotlib mypy nbconvert types-protobuf
      - name: "Install protoc"
        run: ./JAX-Toolbox/.github/container/jax_nsys/install-protoc local_protoc
      # Non-cone sparse checkout so that the *.proto glob is honoured
      - name: "Fetch XLA .proto files"
        uses: actions/checkout@v4
        with:
          path: xla
          repository: openxla/xla
          sparse-checkout: |
            *.proto
          sparse-checkout-cone-mode: false
      - name: "Compile .proto files"
        shell: bash -x -e {0}
        run: |
          mkdir compiled_protos compiled_stubs protos
          mv -v xla/third_party/tsl/tsl protos/
          mv -v xla/xla protos/
          PATH=${PWD}/local_protoc/bin:$PATH python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')"
          touch compiled_stubs/py.typed
      - name: "Convert .ipynb to .py"
        shell: bash -x -e {0}
        run: |
          # NSYS_JAX_PYTHON_FILES is deliberately unquoted: one argument per path
          for notebook in $(find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb'); do
            jupyter nbconvert --to script "${notebook}"
          done
      - name: "Run mypy checks"
        shell: bash -x -e {0}
        run: |
          # Let mypy find the stubs generated by the "Compile .proto files" step
          export MYPYPATH="${PWD}/compiled_stubs"
          mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES}

  # Test nsys-jax-combine and notebook execution; in future perhaps upload the rendered
  # notebook from here too. These input files were generated with something like
  #   srun -n 4 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
  #     env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
  #     --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-4proc-proc%q{SLURM_PROCID}
  #     -- test-pax.sh --steps=5 --fsdp=4 --multiprocess
  # with newer nsys-jax components bind-mounted in.
  combine:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: "Setup Python 3.12"
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'
      # TODO: a modern nsys-jax-combine with old .zip input should probably produce a
      # .zip with a modern jax_nsys/
      - name: Add modern jax_nsys/ files to static .zip inputs
        run: |
          cd .github/container/jax_nsys
          for zip in ../../workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip; do
            zip -ur "${zip}" .
            zipinfo "${zip}"
          done
      - name: Use nsys-jax-combine to merge profiles from multiple nsys processes
        shell: bash -x -e {0}
        run: |
          pip install -e .github/container/jax_nsys/python/jax_nsys
          python .github/container/jax_nsys/install-protoc local_protoc
          PATH=${PWD}/local_protoc/bin:$PATH .github/container/nsys-jax-combine \
            --analysis summary \
            --analysis communication \
            -o pax_fsdp4_4proc.zip \
            .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip
      - name: Extract the output .zip file
        run: |
          mkdir combined/
          unzip -d combined/ pax_fsdp4_4proc.zip
      - name: Run the install script, but skip launching Jupyter Lab
        shell: bash -x -e {0}
        run: |
          pip install virtualenv
          NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./combined/install.sh
      - name: Test the Jupyter Lab installation and execute the notebook
        shell: bash -x -e {0}
        run: |
          pushd combined/
          ./nsys_jax_venv/bin/python -m jupyterlab --version
          # Run with ipython for the sake of getting a clear error message
          ./nsys_jax_venv/bin/ipython Analysis.ipynb

  # This input file was generated with something like
  #   srun -n 1 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
  #     env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
  #     --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-1proc -- test-pax.sh
  #     --steps=5 --fsdp=4
  notebook:
    env:
      # TODO: these could/should be saved in the repository settings instead
      # Destination Gist for runs on main, and the stand-in used for every other ref
      RENDERED_NOTEBOOK_GIST_ID: e2cd3520201caab6b67385ed36fad3c1
      MOCK_RENDERED_NOTEBOOK_GIST_ID: 16698d9e9e52320243165d61b5bb3975
      # JavaScript regex selecting which rendered files (notebooks and SVG images)
      # are copied from the freshly created Gist into the well-known one
      PUBLISH_NOTEBOOK_FILES: '(.*\.ipynb|.*\.svg)'
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: Mock up the structure of an extracted .zip file
        shell: bash -x -e {0}
        run: |
          # Get the actual test data from a real archive, minus the .nsys-rep file
          unzip -d .github/container/jax_nsys/ .github/workflows/nsys-jax/test_data/pax_fsdp4_1proc.zip
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Run the install script, but skip launching Jupyter Lab
        shell: bash -x -e {0}
        run: |
          pip install virtualenv
          NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh
      - name: Test the Jupyter Lab installation and execute the notebook
        shell: bash -x -e {0}
        run: |
          pushd .github/container/jax_nsys
          ./nsys_jax_venv/bin/python -m jupyterlab --version
          # Run with ipython for the sake of getting a clear error message
          ./nsys_jax_venv/bin/ipython Analysis.ipynb
      - name: Render the notebook
        id: render
        shell: bash -x -e {0}
        run: |
          pushd .github/container/jax_nsys
          workdir=$(mktemp -d)
          ./nsys_jax_venv/bin/jupyter nbconvert --execute --inplace Analysis.ipynb
          cp *.ipynb *.svg "${workdir}"
          # Hand the staging directory to the upload step
          echo "WORKDIR=${workdir}" >> "${GITHUB_OUTPUT}"
      - name: Upload rendered notebook to Gist
        id: upload
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
          script: |
            const currentDateTime = new Date().toISOString();
            const gistDescription =
              `Rendered IPython notebook from workflow: ${{ github.workflow }}, ` +
              `Run ID: ${{ github.run_id }}, ` +
              `Repository: ${{ github.repository }}, ` +
              `Event: ${{ github.event_name }}, ` +
              `Created: ${currentDateTime}`;
            const fs = require('fs').promises;
            const workdir = '${{ steps.render.outputs.WORKDIR }}'
            const files = await fs.readdir(workdir);
            // Create a fresh secret Gist holding every file from the staging directory
            const gist = await github.rest.gists.create({
              description: gistDescription,
              public: false,
              files: Object.fromEntries(
                await Promise.all(
                  files.map(
                    async filename => {
                      const content = await fs.readFile(`${workdir}/${filename}`, 'utf8');
                      return [filename, { content }];
                    }
                  )
                )
              )
            });
            console.log(gist)
            // The returned ID is exposed to later steps as steps.upload.outputs.result
            return gist.data.id;
      - name: Copy rendered notebook to Gist with well-known ID
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
          script: |
            const srcId = ${{ steps.upload.outputs.result }};
            const dstId = "${{ github.ref == 'refs/heads/main' && env.RENDERED_NOTEBOOK_GIST_ID || env.MOCK_RENDERED_NOTEBOOK_GIST_ID }}";
            const { PUBLISH_NOTEBOOK_FILES } = process.env;
            // Fetch existing files from destination gist
            const { data: dstData } = await github.rest.gists.get({
              gist_id: dstId
            });
            // Mark existing files in destination gist for deletion
            let filesToUpdate = {};
            for (const filename of Object.keys(dstData.files)) {
              filesToUpdate[filename] = null;
            }
            // Fetch files from source gist
            const { data: srcData } = await github.rest.gists.get({
              gist_id: srcId
            });
            // Add or update files based on the pattern
            const pattern = new RegExp(`${PUBLISH_NOTEBOOK_FILES}`);
            for (const [filename, fileObj] of Object.entries(srcData.files)) {
              if (filename.match(pattern)) {
                // If the total gist size is too large, not all the content will have
                // been returned and we need some extra requests.
                if (fileObj.truncated) {
                  const { data } = await github.request(fileObj.raw_url)
                  filesToUpdate[filename] = {
                    content: new TextDecoder().decode(data)
                  };
                } else {
                  filesToUpdate[filename] = {
                    content: fileObj.content
                  };
                }
              }
            }
            // Update files in destination gist
            await github.rest.gists.update({
              gist_id: dstId,
              files: filesToUpdate
            });
            console.log("Files copied successfully.");
            console.log(Object.keys(filesToUpdate));

  # Lint and format-check everything listed in NSYS_JAX_PYTHON_FILES.
  ruff:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          path: JAX-Toolbox
          sparse-checkout: |
            .github/container
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: "Install ruff"
        run: pip install ruff
      - name: "Run ruff checks"
        # No -e here: both commands must run so that lint and formatting problems
        # are reported together; the step fails afterwards if either one failed.
        shell: bash -x {0}
        run: |
          ruff check ${NSYS_JAX_PYTHON_FILES}
          check_status=$?
          ruff format --check ${NSYS_JAX_PYTHON_FILES}
          format_status=$?
          if [[ $format_status != 0 || $check_status != 0 ]]; then
            exit 1
          fi