From d8056e01b4aa087935ee9c9576ab8f13d7e54359 Mon Sep 17 00:00:00 2001
From: Olli Lupton
Date: Wed, 4 Dec 2024 16:41:39 +0000
Subject: [PATCH] Convert .csv to .parquet in nsys-jax to avoid compressing a
 large .csv with Python's lzma.

---
 .../nsys_jax/nsys_jax/data_loaders.py      | 32 +++++++++++++------
 .../nsys_jax/nsys_jax/scripts/nsys_jax.py  | 31 +++++++++++++++---
 .github/workflows/_ci.yaml                 |  5 ++-
 3 files changed, 53 insertions(+), 15 deletions(-)

diff --git a/.github/container/nsys_jax/nsys_jax/data_loaders.py b/.github/container/nsys_jax/nsys_jax/data_loaders.py
index 46dff6c5a..57e52403d 100644
--- a/.github/container/nsys_jax/nsys_jax/data_loaders.py
+++ b/.github/container/nsys_jax/nsys_jax/data_loaders.py
@@ -593,15 +593,25 @@ def _drop_non_tsl(compile_df: pd.DataFrame) -> pd.DataFrame:
 
 
 def _read_nvtx_pushpop_trace_file(file: pathlib.Path) -> pd.DataFrame:
-    def keep_column(name):
-        return name not in {"PID", "Lvl", "NameTree"}
-
-    return pd.read_csv(
-        lzma.open(file, "rt", newline=""),
-        dtype={"RangeId": np.int32},
-        index_col="RangeId",
-        usecols=keep_column,
-    )
+    # `file` follows one of two patterns, depending on whether we are loading the
+    # results from a single profile or from multiple merged profiles:
+    # - nsys-jax: /path/to/report_nvtx_pushpop_trace.parquet
+    # - nsys-jax-combine: /path/to/report_nvtx_pushpop_trace.parquet/rank5
+    new_name = "report_nvtx_pushpop_trace.parquet"
+    if file.name == new_name or file.parent.name == new_name:
+        # New mode; the .csv to .parquet conversion is done in nsys-jax
+        return pd.read_parquet(file)
+    else:
+
+        def keep_column(name):
+            return name not in {"PID", "Lvl", "NameTree"}
+
+        return pd.read_csv(
+            lzma.open(file, "rt", newline=""),
+            dtype={"RangeId": np.int32},
+            index_col="RangeId",
+            usecols=keep_column,
+        )
 
 
 def _load_nvtx_pushpop_trace_single(name: pathlib.Path) -> pd.DataFrame:
@@ -640,7 +650,9 @@ def remove_program_id_and_name(row):
 
 
 def _load_nvtx_pushpop_trace(prefix: pathlib.Path, frames: set[str]) -> pd.DataFrame:
-    path = prefix / "report_nvtx_pushpop_trace.csv.xz"
+    new_path = prefix / "report_nvtx_pushpop_trace.parquet"
+    legacy_path = prefix / "report_nvtx_pushpop_trace.csv.xz"
+    path = new_path if new_path.exists() else legacy_path
     if path.is_dir():
         # We're looking at the output of nsys-jax-combine
         filenames = sorted(path.iterdir())
diff --git a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
index 522e636f1..291d3298f 100644
--- a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
+++ b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
@@ -3,6 +3,7 @@
 from contextlib import contextmanager
 from glob import glob, iglob
 import lzma
+import numpy as np
 import os
 import os.path as osp
 import pandas as pd  # type: ignore
@@ -369,7 +370,9 @@ def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
         if osp.isdir(full_path) or not osp.exists(full_path):
             continue
         output_queue.put((ofile, full_path, COMPRESS_NONE))
-    print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
+    print(
+        f"{archive_name}: recipe post-processing finished in {time.time()-start:.2f}s"
+    )
 
 def compress_and_archive(prefix, file, output_queue):
     """
@@ -401,9 +404,29 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue):
         ],
         check=True,
     )
-    for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir):
-        compress_and_archive(tmp_dir, ofile, output_queue)
-    print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
+    output_path = osp.join(tmp_dir, f"report_{report}.csv")
+
+    # TODO: avoid the .csv indirection
+    def keep_column(name):
+        return name not in {"PID", "Lvl", "NameTree"}
+
+    try:
+        df = pd.read_csv(
+            output_path,
+            dtype={"RangeId": np.int32},
+            index_col="RangeId",
+            usecols=keep_column,
+        )
+        parquet_name = f"report_{report}.parquet"
+        parquet_path = osp.join(tmp_dir, parquet_name)
+        df.to_parquet(parquet_path)
+        output_queue.put((parquet_name, parquet_path, COMPRESS_NONE))
+    except pd.errors.EmptyDataError:
+        # If there's no data, don't write a file to the output at all
+        pass
+    print(
+        f"{archive_name}: stats post-processing finished in {time.time()-start:.2f}s"
+    )
 
 def save_device_stream_thread_names(tmp_dir, report, output_queue):
     """
diff --git a/.github/workflows/_ci.yaml b/.github/workflows/_ci.yaml
index 0848a6e11..5027e3802 100644
--- a/.github/workflows/_ci.yaml
+++ b/.github/workflows/_ci.yaml
@@ -322,7 +322,9 @@ jobs:
           set -o pipefail
           num_tests=0
           num_failures=0
-          # Run the pytest-driven tests
+          # Run the pytest-driven tests; failure is explicitly handled below so set +e to
+          # avoid an early abort here.
+          set +e
          docker run -i --shm-size=1g --gpus all \
             -v $PWD:/opt/output \
             ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \
@@ -333,6 +335,7 @@ jobs:
             test_path=$(python -c 'import importlib.resources; print(importlib.resources.files("nsys_jax").joinpath("..", "tests").resolve())')
             pytest --report-log=/opt/output/pytest-report.jsonl "${test_path}"
           EOF
+          set -e
           GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU')
           for mode in 1-process 2-process process-per-gpu; do
             DOCKER="docker run --shm-size=1g --gpus all --env XLA_FLAGS=--xla_gpu_enable_command_buffer= --env XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 -v ${PWD}:/opt/output ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}"
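
For reference, the conversion that run_nsys_stats_report now performs can be
reproduced standalone. The sketch below is illustrative, not part of the patch:
it assumes pandas plus a parquet engine (pyarrow or fastparquet) is installed,
and the input file name is a made-up example; keep_column and the RangeId
dtype/index are taken from the hunks above.

    import numpy as np
    import pandas as pd

    # Illustrative input: an uncompressed .csv as emitted by `nsys stats`.
    csv_path = "report_nvtx_pushpop_trace.csv"

    def keep_column(name):
        # Drop columns that the later analysis never reads, as the patch does.
        return name not in {"PID", "Lvl", "NameTree"}

    try:
        df = pd.read_csv(
            csv_path,
            dtype={"RangeId": np.int32},
            index_col="RangeId",
            usecols=keep_column,
        )
        # Parquet is columnar and compressed by default (snappy), so writing it
        # is much cheaper than piping the whole .csv through Python's lzma.
        df.to_parquet("report_nvtx_pushpop_trace.parquet")
    except pd.errors.EmptyDataError:
        # An empty report produces no output file, matching the patch behaviour.
        pass

On the read side, pd.read_parquet("report_nvtx_pushpop_trace.parquet") restores
the frame with its RangeId index intact, which is what the new fast path in
_read_nvtx_pushpop_trace_file relies on.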