Skip to content

Commit

Permalink
Minor bug fixes in example notebook
Browse files Browse the repository at this point in the history
Plumb through `prefix` so it's more convenient to explicitly set the
input data path.
  • Loading branch information
olupton committed Dec 10, 2024
1 parent 66715ec commit 9c56b12
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 10 deletions.
3 changes: 2 additions & 1 deletion .github/container/nsys_jax/nsys_jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .data_loaders import load_profiler_data
from .protobuf import xla_module_metadata
from .protobuf_utils import compile_protos, ensure_compiled_protos_are_importable
from .utils import remove_autotuning_detail, remove_child_ranges
from .utils import default_data_prefix, remove_autotuning_detail, remove_child_ranges
from .visualization import create_flamegraph, display_flamegraph

__all__ = [
Expand All @@ -16,6 +16,7 @@
"calculate_collective_metrics",
"compile_protos",
"create_flamegraph",
"default_data_prefix",
"display_flamegraph",
"ensure_compiled_protos_are_importable",
"generate_compilation_statistics",
Expand Down
29 changes: 20 additions & 9 deletions .github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"from nsys_jax import (\n",
" align_profiler_data_timestamps,\n",
" apply_warmup_heuristics,\n",
" default_data_prefix,\n",
" display_flamegraph,\n",
" ensure_compiled_protos_are_importable,\n",
" generate_compilation_statistics,\n",
Expand All @@ -23,6 +24,18 @@
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7a91f0e7-17da-4534-8ea9-29bcf3742567",
"metadata": {},
"outputs": [],
"source": [
"# Set the input data to use. default_data_prefix() checks the NSYS_JAX_DEFAULT_PREFIX environment variable, and if that is\n",
"# not set then the current working directory is used. Use pathlib.Path if setting this explicitly.\n",
"prefix = default_data_prefix()"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -32,7 +45,7 @@
"source": [
"# Make sure that the .proto files under protos/ have been compiled to .py, and\n",
"# that those generated .py files are importable.\n",
"compiled_dir = ensure_compiled_protos_are_importable()"
"compiled_dir = ensure_compiled_protos_are_importable(prefix=prefix)"
]
},
{
Expand All @@ -43,7 +56,7 @@
"outputs": [],
"source": [
"# Load the runtime profile data\n",
"all_data = load_profiler_data()\n",
"all_data = load_profiler_data(prefix)\n",
"# Remove some detail from the autotuner\n",
"all_data = remove_autotuning_detail(all_data)\n",
"# Align GPU timestamps across profiles collected by different Nsight Systems processes\n",
Expand Down Expand Up @@ -82,16 +95,14 @@
"source": [
"This data frame has a three-level index:\n",
"- `ProgramId` is an integer ID that uniquely identifies the XLA module\n",
"- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 1, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
"- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 2, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n",
"- `Device` is the global (across multiple nodes and processes) index of the GPU on which the module execution took place\n",
"\n",
"The columns are as follows:\n",
"- `Name`: the name of the XLA module; this should always be the same for a given `ProgramId`\n",
"- `NumThunks`: the number of thunks executed inside this module execution\n",
"- `ProjStartMs`: the timestamp of the start of the module execution on the GPU, in milliseconds\n",
"- `ProjDurMs`: the duration of the module execution on the GPU, in milliseconds\n",
"- `OrigStartMs`: the timestamp of the start of the module launch **on the host**, in milliseconds. *i.e.* `ProjStartMs-OrigStartMs` is something like the launch latency of the first kernel\n",
"- `OrigDurMs`: the duration of the module launch **on the host**, in milliseconds\n",
"- `LocalDevice`: the index within the node/slice of the GPU on which the module execution took place\n",
"- `Process`: the global (across multiple nodes) index of the process\n",
"- `Slice`: the global index of the node/slice; devices within the same node/slice should have faster interconnects than to devices in different slices\n",
Expand All @@ -117,13 +128,13 @@
"id": "7727d800-13d3-4505-89e8-80a5fed63512",
"metadata": {},
"source": [
"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `module_df`.\n",
"Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n",
"The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n",
"Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n",
"\n",
"The columns are as follows:\n",
"- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n",
"- `ProjStartMs`, `OrigStartMs`, `OrigDurMs`: see above, same meaning as in `module_df`.\n",
"- `ProjStartMs`: see above, same meaning as in `steady_state.module`.\n",
"- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.\n",
"- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`\n",
"\n",
Expand Down Expand Up @@ -299,7 +310,7 @@
"# Print out the largest entries adding up to at least this fraction of the total\n",
"threshold = 0.97\n",
"compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n",
"print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-9:.2f}s compilation time\")\n",
"print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-3:.2f}s compilation time\")\n",
"for row in compile_summary[\n",
" compile_summary[\"FracNonChild\"].cumsum() <= threshold\n",
"].itertuples():\n",
Expand Down Expand Up @@ -378,7 +389,7 @@
" program_id, thunk_name = thunk_row.Index\n",
" # policy=\"all\" means we may get a set of HloProto instead of a single one, if\n",
" # nsys-jax-combine was used and the dumped metadata were not bitwise identical\n",
" hlo_modules = xla_module_metadata(program_id, policy=\"all\")\n",
" hlo_modules = xla_module_metadata(program_id, policy=\"all\", prefix=prefix)\n",
" thunk_opcode, inst_metadata, inst_frames = hlo_modules.unique_result(\n",
" lambda proto: instructions_and_frames(proto, thunk_name)\n",
" )\n",
Expand Down

0 comments on commit 9c56b12

Please sign in to comment.