Skip to content

Commit

Permalink
Convert .csv to .parquet in nsys-jax to avoid compressing a large .csv with Python's lzma.
Browse files Browse the repository at this point in the history
  • Loading branch information
olupton committed Dec 10, 2024
1 parent 9c56b12 commit d8056e0
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 15 deletions.
32 changes: 22 additions & 10 deletions .github/container/nsys_jax/nsys_jax/data_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,15 +593,25 @@ def _drop_non_tsl(compile_df: pd.DataFrame) -> pd.DataFrame:


def _read_nvtx_pushpop_trace_file(file: pathlib.Path) -> pd.DataFrame:
def keep_column(name):
return name not in {"PID", "Lvl", "NameTree"}

return pd.read_csv(
lzma.open(file, "rt", newline=""),
dtype={"RangeId": np.int32},
index_col="RangeId",
usecols=keep_column,
)
# `file` follows one of two patterns, depending on whether we are loading the
# results from a single profile or from multiple merged profiles:
# - nsys-jax: /path/to/report_nvtx_pushpop_trace.parquet
# - nsys-jax-combine: /path/to/report_nvtx_pushpop_trace.parquet/rank5
new_name = "report_nvtx_pushpop_trace.parquet"
if file.name == new_name or file.parent.name == new_name:
# New mode; the .csv to .parquet conversion is done in nsys-jax
return pd.read_parquet(file)
else:

def keep_column(name):
return name not in {"PID", "Lvl", "NameTree"}

return pd.read_csv(
lzma.open(file, "rt", newline=""),
dtype={"RangeId": np.int32},
index_col="RangeId",
usecols=keep_column,
)


def _load_nvtx_pushpop_trace_single(name: pathlib.Path) -> pd.DataFrame:
Expand Down Expand Up @@ -640,7 +650,9 @@ def remove_program_id_and_name(row):


def _load_nvtx_pushpop_trace(prefix: pathlib.Path, frames: set[str]) -> pd.DataFrame:
path = prefix / "report_nvtx_pushpop_trace.csv.xz"
new_path = prefix / "report_nvtx_pushpop_trace.parquet"
legacy_path = prefix / "report_nvtx_pushpop_trace.csv.xz"
path = new_path if new_path.exists() else legacy_path
if path.is_dir():
# We're looking at the output of nsys-jax-combine
filenames = sorted(path.iterdir())
Expand Down
31 changes: 27 additions & 4 deletions .github/container/nsys_jax/nsys_jax/scripts/nsys_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import contextmanager
from glob import glob, iglob
import lzma
import numpy as np
import os
import os.path as osp
import pandas as pd # type: ignore
Expand Down Expand Up @@ -369,7 +370,9 @@ def run_nsys_recipe(recipe, report_file, tmp_dir, output_queue):
if osp.isdir(full_path) or not osp.exists(full_path):
continue
output_queue.put((ofile, full_path, COMPRESS_NONE))
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
print(
f"{archive_name}: recipe post-processing finished in {time.time()-start:.2f}s"
)

def compress_and_archive(prefix, file, output_queue):
"""
Expand Down Expand Up @@ -401,9 +404,29 @@ def run_nsys_stats_report(report, report_file, tmp_dir, output_queue):
],
check=True,
)
for ofile in iglob("report_" + report + ".csv", root_dir=tmp_dir):
compress_and_archive(tmp_dir, ofile, output_queue)
print(f"{archive_name}: post-processing finished in {time.time()-start:.2f}s")
output_path = osp.join(tmp_dir, f"report_{report}.csv")

# TODO: avoid the .csv indirection
def keep_column(name):
return name not in {"PID", "Lvl", "NameTree"}

try:
df = pd.read_csv(
output_path,
dtype={"RangeId": np.int32},
index_col="RangeId",
usecols=keep_column,
)
parquet_name = f"report_{report}.parquet"
parquet_path = osp.join(tmp_dir, parquet_name)
df.to_parquet(parquet_path)
output_queue.put((parquet_name, parquet_path, COMPRESS_NONE))
except pd.errors.EmptyDataError:
# If there's no data, don't write a file to the output at all
pass
print(
f"{archive_name}: stats post-processing finished in {time.time()-start:.2f}s"
)

def save_device_stream_thread_names(tmp_dir, report, output_queue):
"""
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/_ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,9 @@ jobs:
set -o pipefail
num_tests=0
num_failures=0
# Run the pytest-driven tests
# Run the pytest-driven tests; failure is explicitly handled below so set +e to
# avoid an early abort here.
set +e
docker run -i --shm-size=1g --gpus all \
-v $PWD:/opt/output \
${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }} \
Expand All @@ -333,6 +335,7 @@ jobs:
test_path=$(python -c 'import importlib.resources; print(importlib.resources.files("nsys_jax").joinpath("..", "tests").resolve())')
pytest --report-log=/opt/output/pytest-report.jsonl "${test_path}"
EOF
set -e
GPUS_PER_NODE=$(nvidia-smi -L | grep -c '^GPU')
for mode in 1-process 2-process process-per-gpu; do
DOCKER="docker run --shm-size=1g --gpus all --env XLA_FLAGS=--xla_gpu_enable_command_buffer= --env XLA_PYTHON_CLIENT_MEM_FRACTION=0.9 -v ${PWD}:/opt/output ${{ needs.build-jax.outputs.DOCKER_TAG_FINAL }}"
Expand Down

0 comments on commit d8056e0

Please sign in to comment.