diff --git a/.github/container/nsys_jax/nsys_jax/data_loaders.py b/.github/container/nsys_jax/nsys_jax/data_loaders.py
index dabafd497..195989b13 100644
--- a/.github/container/nsys_jax/nsys_jax/data_loaders.py
+++ b/.github/container/nsys_jax/nsys_jax/data_loaders.py
@@ -592,15 +592,20 @@ def _drop_non_tsl(compile_df: pd.DataFrame) -> pd.DataFrame:
 
 
 def _read_nvtx_pushpop_trace_file(file: pathlib.Path) -> pd.DataFrame:
-    def keep_column(name):
-        return name not in {"PID", "Lvl", "NameTree"}
-
-    return pd.read_csv(
-        lzma.open(file, "rt", newline=""),
-        dtype={"RangeId": np.int32},
-        index_col="RangeId",
-        usecols=keep_column,
-    )
+    if file.suffix == ".parquet":
+        # New mode; the .csv to .parquet conversion is done in nsys-jax
+        return pd.read_parquet(file)
+    else:
+
+        def keep_column(name):
+            return name not in {"PID", "Lvl", "NameTree"}
+
+        return pd.read_csv(
+            lzma.open(file, "rt", newline=""),
+            dtype={"RangeId": np.int32},
+            index_col="RangeId",
+            usecols=keep_column,
+        )
 
 
 def _load_nvtx_pushpop_trace_single(name: pathlib.Path) -> pd.DataFrame:
@@ -639,7 +644,9 @@ def remove_program_id_and_name(row):
 
 
 def _load_nvtx_pushpop_trace(prefix: pathlib.Path, frames: set[str]) -> pd.DataFrame:
-    path = prefix / "report_nvtx_pushpop_trace.csv.xz"
+    new_path = prefix / "report_nvtx_pushpop_trace.parquet"
+    legacy_path = prefix / "report_nvtx_pushpop_trace.csv.xz"
+    path = new_path if new_path.exists() else legacy_path
     if path.is_dir():
         # We're looking at the output of nsys-jax-combine
         filenames = sorted(path.iterdir())
diff --git a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
index 522e636f1..c0c3dfc95 100644
--- a/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
+++ b/.github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
@@ -3,6 +3,7 @@
 from contextlib import contextmanager
 from glob import glob, iglob
 import lzma
+import numpy as np
 import os
 import os.path as osp
 import pandas as pd  # type: ignore
@@ -369,7 +370,9 @@ def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
         if osp.isdir(full_path) or not osp.exists(full_path):
             continue
         output_queue.put((ofile, full_path, COMPRESS_NONE))
-    print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
+    print(
+        f"{archive_name}: recipe post-processing finished in {time.time()-start:.2f}s"
+    )
 
 def compress_and_archive(prefix, file, output_queue):
     """
@@ -401,9 +404,25 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue):
         ],
         check=True,
     )
-    for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir):
-        compress_and_archive(tmp_dir, ofile, output_queue)
-    print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
+    output_path = osp.join(tmp_dir, f"report_{report}.csv")
+
+    # TODO: avoid the .csv indirection
+    def keep_column(name):
+        return name not in {"PID", "Lvl", "NameTree"}
+
+    df = pd.read_csv(
+        output_path,
+        dtype={"RangeId": np.int32},
+        index_col="RangeId",
+        usecols=keep_column,
+    )
+    parquet_name = f"report_{report}.parquet"
+    parquet_path = osp.join(tmp_dir, parquet_name)
+    df.to_parquet(parquet_path)
+    output_queue.put((parquet_name, parquet_path, COMPRESS_NONE))
+    print(
+        f"{archive_name}: stats post-processing finished in {time.time()-start:.2f}s"
+    )
 
 def save_device_stream_thread_names(tmp_dir, report, output_queue):
     """
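
For reference, a minimal standalone sketch (not part of the patch) of the parquet-first, CSV-fallback loading behaviour this change introduces. The `load_trace` name and `prefix` argument are illustrative only, not part of the nsys_jax API:

    import lzma
    import pathlib

    import numpy as np
    import pandas as pd


    def load_trace(prefix: pathlib.Path) -> pd.DataFrame:
        # Hypothetical helper: prefer the .parquet file written by newer
        # nsys-jax; fall back to the legacy xz-compressed CSV otherwise.
        new_path = prefix / "report_nvtx_pushpop_trace.parquet"
        if new_path.exists():
            return pd.read_parquet(new_path)
        return pd.read_csv(
            lzma.open(prefix / "report_nvtx_pushpop_trace.csv.xz", "rt", newline=""),
            dtype={"RangeId": np.int32},
            index_col="RangeId",
            usecols=lambda name: name not in {"PID", "Lvl", "NameTree"},
        )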