Skip to content

Commit

Permalink
Convert .csv to .parquet in nsys-jax to avoid compressing a large .csv with Python's lzma.
Browse files Browse the repository at this point in the history
  • Loading branch information
olupton committed Dec 10, 2024
1 parent 9c56b12 commit 74a3d94
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 14 deletions.
32 changes: 22 additions & 10 deletions .github/container/nsys_jax/nsys_jax/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,15 +593,25 @@ def _drop_non_tsl(compile_df: pd.DataFrame) -> pd.DataFrame:


def _read_nvtx_pushpop_trace_file(file: pathlib.Path) -> pd.DataFrame:
    """
    Load a single NVTX push/pop trace table from ``file``.

    Two on-disk formats are supported:

    - New mode: a Parquet file, recognised by ``file`` (or its parent
      directory) being named ``report_nvtx_pushpop_trace.parquet``. The
      .csv to .parquet conversion is done in nsys-jax, so the file is read
      as-is.
    - Legacy mode: an xz-compressed CSV. The ``PID``, ``Lvl`` and ``NameTree``
      columns are dropped while reading and ``RangeId`` (parsed as
      ``np.int32``) becomes the index.

    Args:
        file: path of the trace file to read.

    Returns:
        The trace data as a ``pandas.DataFrame``.
    """
    # `file` follows one of two patterns, depending on whether we are loading the
    # results from a single profile or from multiple merged profiles:
    # - nsys-jax: /path/to/report_nvtx_pushpop_trace.parquet
    # - nsys-jax-combine: /path/to/report_nvtx_pushpop_trace.parquet/rank5
    new_name = "report_nvtx_pushpop_trace.parquet"
    if file.name == new_name or file.parent.name == new_name:
        # New mode; the .csv to .parquet conversion is done in nsys-jax
        return pd.read_parquet(file)

    # Legacy mode: xz-compressed CSV that still needs filtering on load.
    def keep_column(name):
        return name not in {"PID", "Lvl", "NameTree"}

    # Use a context manager so the decompression stream is closed even if
    # parsing raises; read_csv consumes the whole stream before returning.
    with lzma.open(file, "rt", newline="") as handle:
        return pd.read_csv(
            handle,
            dtype={"RangeId": np.int32},
            index_col="RangeId",
            usecols=keep_column,
        )


def _load_nvtx_pushpop_trace_single(name: pathlib.Path) -> pd.DataFrame:
Expand Down Expand Up @@ -640,7 +650,9 @@ def remove_program_id_and_name(row):


def _load_nvtx_pushpop_trace(prefix: pathlib.Path, frames: set[str]) -> pd.DataFrame:
path = prefix / "report_nvtx_pushpop_trace.csv.xz"
new_path = prefix / "report_nvtx_pushpop_trace.parquet"
legacy_path = prefix / "report_nvtx_pushpop_trace.csv.xz"
path = new_path if new_path.exists() else legacy_path
if path.is_dir():
# We're looking at the output of nsys-jax-combine
filenames = sorted(path.iterdir())
Expand Down
27 changes: 23 additions & 4 deletions .github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import contextmanager
from glob import glob, iglob
import lzma
import numpy as np
import os
import os.path as osp
import pandas as pd # type: ignore
Expand Down Expand Up @@ -369,7 +370,9 @@ def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
if osp.isdir(full_path) or not osp.exists(full_path):
continue
output_queue.put((ofile, full_path, COMPRESS_NONE))
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
print(
f"{archive_name}: recipe post-processing finished in {time.time()-start:.2f}s"
)

def compress_and_archive(prefix, file, output_queue):
"""
Expand Down Expand Up @@ -401,9 +404,25 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue):
],
check=True,
)
for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir):
compress_and_archive(tmp_dir, ofile, output_queue)
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
output_path = osp.join(tmp_dir, f"report_{report}.csv")

# TODO: avoid the .csv indirection
def keep_column(name):
return name not in {"PID", "Lvl", "NameTree"}

df = pd.read_csv(
output_path,
dtype={"RangeId": np.int32},
index_col="RangeId",
usecols=keep_column,
)
parquet_name = f"report_{report}.parquet"
parquet_path = osp.join(tmp_dir, parquet_name)
df.to_parquet(parquet_path)
output_queue.put((parquet_name, parquet_path, COMPRESS_NONE))
print(
f"{archive_name}: stats post-processing finished in {time.time()-start:.2f}s"
)

def save_device_stream_thread_names(tmp_dir, report, output_queue):
"""
Expand Down

0 comments on commit 74a3d94

Please sign in to comment.