Add JetStream to MaxText container #321
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CI for the pure-Python nsys-jax tooling kept under .github/container: static
# type checks (mypy), lint/format checks (ruff), and end-to-end runs of
# nsys-jax-combine and the analysis notebook against stored profile archives.
name: nsys-jax pure-Python CI

# Runs are grouped per workflow and PR head branch. Events without a head_ref
# (pushes) fall back to the unique run_id, so they never share a group. Runs
# whose ref is main are never cancelled.
concurrency:
  group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

on:
  pull_request:
    types:
      - opened
      - reopened
      - ready_for_review
      - synchronize
    paths-ignore:
      - '**.md'
  push:
    branches:
      - main

# Least privilege for the automatic GITHUB_TOKEN: the jobs only check out
# repositories; gist uploads authenticate with a dedicated secret instead.
permissions:
  contents: read

env:
  # Paths handed to mypy and ruff, one per line. Shell steps expand this variable
  # unquoted on purpose so that every line becomes a separate argument; the paths
  # must therefore not contain whitespace.
  NSYS_JAX_PYTHON_FILES: |
    JAX-Toolbox/.github/container/nsys-jax
    JAX-Toolbox/.github/container/nsys-jax-combine
    JAX-Toolbox/.github/container/jax_nsys
    JAX-Toolbox/.github/container/jax-nccl-test
jobs:
  # Type-check the nsys-jax Python sources, including the notebooks (converted to
  # plain scripts first) against stubs generated from the XLA .proto files.
  mypy:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          path: JAX-Toolbox
          sparse-checkout: |
            .github/container
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      # jax is just a CPU-only build of the latest release for type-checking purposes
      - name: "Install jax / jax_nsys / mypy"
        run: pip install jax -e JAX-Toolbox/.github/container/jax_nsys/python/jax_nsys matplotlib mypy nbconvert types-protobuf
      - name: "Install protoc"
        run: ./JAX-Toolbox/.github/container/jax_nsys/install-protoc local_protoc
      # Only the .proto files are needed from XLA, hence the non-cone sparse
      # checkout with a bare glob pattern.
      - name: "Fetch XLA .proto files"
        uses: actions/checkout@v4
        with:
          path: xla
          repository: openxla/xla
          sparse-checkout: |
            *.proto
          sparse-checkout-cone-mode: false
      - name: "Compile .proto files"
        shell: bash -x -e {0}
        run: |
          mkdir compiled_protos compiled_stubs protos
          mv -v xla/third_party/tsl/tsl protos/
          mv -v xla/xla protos/
          # Put the locally installed protoc first on PATH for compile_protos
          PATH=${PWD}/local_protoc/bin:$PATH python -c "from jax_nsys import compile_protos; compile_protos(proto_dir='protos', output_dir='compiled_protos', output_stub_dir='compiled_stubs')"
          # Mark the generated stubs as a typed package so mypy will use them
          touch compiled_stubs/py.typed
      - name: "Convert .ipynb to .py"
        shell: bash -x -e {0}
        run: |
          # NSYS_JAX_PYTHON_FILES is deliberately unquoted so that each listed path
          # becomes its own search root. The notebooks that find locates are handed
          # to nbconvert as intact arguments (no shell re-splitting of the paths),
          # and with "-exec ... +" find exits non-zero if nbconvert fails.
          find ${NSYS_JAX_PYTHON_FILES} -name '*.ipynb' -exec jupyter nbconvert --to script {} +
      - name: "Run mypy checks"
        shell: bash -x -e {0}
        run: |
          # Let mypy resolve the generated protobuf stubs
          export MYPYPATH="${PWD}/compiled_stubs"
          mypy --scripts-are-modules ${NSYS_JAX_PYTHON_FILES}
# Test nsys-jax-combine and notebook execution; in future perhaps upload the
# rendered notebook from here too. The input archives were produced by a command
# along the lines of
#   srun -n 4 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
#     env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
#     --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-4proc-proc%q{SLURM_PROCID}
#     -- test-pax.sh --steps=5 --fsdp=4 --multiprocess
# with newer nsys-jax components bind-mounted in.
  combine:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: Setup Python 3.12
        uses: actions/setup-python@v5
        with:
          python-version: "3.12"
      # TODO: a modern nsys-jax-combine with old .zip input should probably
      # produce a .zip with a modern jax_nsys/
      - name: Add modern jax_nsys/ files to static .zip inputs
        run: |
          cd .github/container/jax_nsys
          # Refresh each stored per-process archive with the current directory's
          # files, then list the archive contents in the log.
          for archive in ../../workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip; do
            zip -ur "${archive}" .
            zipinfo "${archive}"
          done
      - name: Use nsys-jax-combine to merge profiles from multiple nsys processes
        shell: bash -x -e {0}
        run: |
          pip install -e .github/container/jax_nsys/python/jax_nsys
          python .github/container/jax_nsys/install-protoc local_protoc
          # The locally installed protoc goes first on PATH for this command only
          PATH=${PWD}/local_protoc/bin:$PATH .github/container/nsys-jax-combine \
            --analysis summary --analysis communication \
            -o pax_fsdp4_4proc.zip \
            .github/workflows/nsys-jax/test_data/pax_fsdp4_4proc_proc*.zip
      - name: Extract the output .zip file
        run: |
          mkdir combined/
          unzip -d combined/ pax_fsdp4_4proc.zip
      - name: Run the install script, but skip launching Jupyter Lab
        shell: bash -x -e {0}
        run: |
          pip install virtualenv
          NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./combined/install.sh
      - name: Test the Jupyter Lab installation and execute the notebook
        shell: bash -x -e {0}
        run: |
          pushd combined/
          ./nsys_jax_venv/bin/python -m jupyterlab --version
          # ipython is used to run the notebook because it reports failures clearly
          ./nsys_jax_venv/bin/ipython Analysis.ipynb
# Execute and render the analysis notebook against a stored single-process
# profile, then publish the rendered files to a gist. The input archive was
# generated with something like
#   srun -n 1 --container-name=XXX --container-image=ghcr.io/nvidia/jax:pax-2024-07-06
#     env NPROC=4 XLA_FLAGS=--xla_gpu_enable_latency_hiding_scheduler\
#     --xla_gpu_enable_command_buffer= nsys-jax -o ...-fsdp4-1proc -- test-pax.sh
#     --steps=5 --fsdp=4
  notebook:
    env:
      # TODO: these could/should be saved in the repository settings instead
      # Destination gist for runs on main, and the stand-in used by all other
      # refs. Quoted so the hex-like IDs can never be re-typed by a YAML parser.
      RENDERED_NOTEBOOK_GIST_ID: 'e2cd3520201caab6b67385ed36fad3c1'
      MOCK_RENDERED_NOTEBOOK_GIST_ID: '16698d9e9e52320243165d61b5bb3975'
      # JavaScript regular expression selecting which files of the freshly
      # uploaded gist are copied to the well-known destination gist
      PUBLISH_NOTEBOOK_FILES: '(.*\.ipynb|.*\.svg)'
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
      - name: Mock up the structure of an extracted .zip file
        shell: bash -x -e {0}
        run: |
          # Get the actual test data from a real archive, minus the .nsys-rep file
          unzip -d .github/container/jax_nsys/ .github/workflows/nsys-jax/test_data/pax_fsdp4_1proc.zip
      - name: "Setup Python 3.10"
        uses: actions/setup-python@v5
        with:
          python-version: '3.10'
      - name: Run the install script, but skip launching Jupyter Lab
        shell: bash -x -e {0}
        run: |
          pip install virtualenv
          NSYS_JAX_INSTALL_SKIP_LAUNCH=1 ./.github/container/jax_nsys/install.sh
      - name: Test the Jupyter Lab installation and execute the notebook
        shell: bash -x -e {0}
        run: |
          pushd .github/container/jax_nsys
          ./nsys_jax_venv/bin/python -m jupyterlab --version
          # Run with ipython for the sake of getting a clear error message
          ./nsys_jax_venv/bin/ipython Analysis.ipynb
      - name: Render the notebook
        id: render
        shell: bash -x -e {0}
        run: |
          pushd .github/container/jax_nsys
          workdir=$(mktemp -d)
          ./nsys_jax_venv/bin/jupyter nbconvert --execute --inplace Analysis.ipynb
          # Stage the executed notebook and the SVG files for the upload step
          cp *.ipynb *.svg "${workdir}"
          echo "WORKDIR=${workdir}" >> "${GITHUB_OUTPUT}"
      # NOTE(review): both gist steps need secrets.NVJAX_GIST_TOKEN. Secrets are
      # not passed to runs triggered by pull requests from forks, so these steps
      # presumably fail there; confirm whether that is intended.
      - name: Upload rendered notebook to Gist
        id: upload
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
          script: |
            const currentDateTime = new Date().toISOString();
            const gistDescription =
              `Rendered IPython notebook from workflow: ${{ github.workflow }}, ` +
              `Run ID: ${{ github.run_id }}, ` +
              `Repository: ${{ github.repository }}, ` +
              `Event: ${{ github.event_name }}, ` +
              `Created: ${currentDateTime}`;
            const fs = require('fs').promises;
            const workdir = '${{ steps.render.outputs.WORKDIR }}';
            const files = await fs.readdir(workdir);
            // Create a new secret gist holding every file staged by the render
            // step. Declared with const so no implicit global is created.
            const gist = await github.rest.gists.create({
              description: gistDescription,
              public: false,
              files: Object.fromEntries(
                await Promise.all(
                  files.map(
                    async filename => {
                      const content = await fs.readFile(`${workdir}/${filename}`, 'utf8');
                      return [filename, { content }];
                    }
                  )
                )
              )
            });
            console.log(gist);
            // The returned ID becomes this step's `result` output
            return gist.data.id;
      - name: Copy rendered notebook to Gist with well-known ID
        uses: actions/github-script@v7
        with:
          github-token: ${{ secrets.NVJAX_GIST_TOKEN }}
          script: |
            // github-script JSON-encodes `result` by default, so the expression
            // below expands to a quoted JavaScript string literal.
            const srcId = ${{ steps.upload.outputs.result }};
            // Only runs on main publish to the real gist; everything else goes
            // to the mock gist.
            const dstId = "${{ github.ref == 'refs/heads/main' && env.RENDERED_NOTEBOOK_GIST_ID || env.MOCK_RENDERED_NOTEBOOK_GIST_ID }}";
            const { PUBLISH_NOTEBOOK_FILES } = process.env;
            // Fetch existing files from destination gist
            const { data: dstData } = await github.rest.gists.get({
              gist_id: dstId
            });
            // Mark existing files in destination gist for deletion; entries that
            // are re-added below are overwritten instead of deleted.
            let filesToUpdate = {};
            for (const filename of Object.keys(dstData.files)) {
              filesToUpdate[filename] = null;
            }
            // Fetch files from source gist
            const { data: srcData } = await github.rest.gists.get({
              gist_id: srcId
            });
            // Add or update files based on the pattern
            const pattern = new RegExp(`${PUBLISH_NOTEBOOK_FILES}`);
            for (const [filename, fileObj] of Object.entries(srcData.files)) {
              if (pattern.test(filename)) {
                // If the total gist size is too large, not all the content will have
                // been returned and we need some extra requests.
                if (fileObj.truncated) {
                  const { data } = await github.request(fileObj.raw_url);
                  filesToUpdate[filename] = {
                    content: new TextDecoder().decode(data)
                  };
                } else {
                  filesToUpdate[filename] = {
                    content: fileObj.content
                  };
                }
              }
            }
            // Update files in destination gist
            await github.rest.gists.update({
              gist_id: dstId,
              files: filesToUpdate
            });
            console.log("Files copied successfully.");
            console.log(Object.keys(filesToUpdate));
# Lint and format checks over the paths listed in NSYS_JAX_PYTHON_FILES.
  ruff:
    runs-on: ubuntu-24.04
    steps:
      - name: Check out the repository under ${GITHUB_WORKSPACE}
        uses: actions/checkout@v4
        with:
          path: JAX-Toolbox
          sparse-checkout: |
            .github/container
      - name: Setup Python 3.10
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
      - name: Install ruff
        run: pip install ruff
      - name: Run ruff checks
        # The shell runs without -e so that a failing lint check does not stop
        # the format check from running; the step fails at the end if either did.
        shell: bash -x {0}
        run: |
          status=0
          ruff check ${NSYS_JAX_PYTHON_FILES} || status=1
          ruff format --check ${NSYS_JAX_PYTHON_FILES} || status=1
          exit "${status}"