diff --git a/.github/container/nsys_jax/nsys_jax/__init__.py b/.github/container/nsys_jax/nsys_jax/__init__.py index e89395d8a..93fde24ee 100644 --- a/.github/container/nsys_jax/nsys_jax/__init__.py +++ b/.github/container/nsys_jax/nsys_jax/__init__.py @@ -7,7 +7,7 @@ from .data_loaders import load_profiler_data from .protobuf import xla_module_metadata from .protobuf_utils import compile_protos, ensure_compiled_protos_are_importable -from .utils import remove_autotuning_detail, remove_child_ranges +from .utils import default_data_prefix, remove_autotuning_detail, remove_child_ranges from .visualization import create_flamegraph, display_flamegraph __all__ = [ @@ -16,6 +16,7 @@ "calculate_collective_metrics", "compile_protos", "create_flamegraph", + "default_data_prefix", "display_flamegraph", "ensure_compiled_protos_are_importable", "generate_compilation_statistics", diff --git a/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb b/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb index d8e8c6248..ed2954c12 100644 --- a/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb +++ b/.github/container/nsys_jax/nsys_jax/analyses/Analysis.ipynb @@ -12,6 +12,7 @@ "from nsys_jax import (\n", " align_profiler_data_timestamps,\n", " apply_warmup_heuristics,\n", + " default_data_prefix,\n", " display_flamegraph,\n", " ensure_compiled_protos_are_importable,\n", " generate_compilation_statistics,\n", @@ -23,6 +24,18 @@ "import numpy as np" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "7a91f0e7-17da-4534-8ea9-29bcf3742567", + "metadata": {}, + "outputs": [], + "source": [ + "# Set the input data to use. default_data_prefix() checks the NSYS_JAX_DEFAULT_PREFIX environment variable, and if that is\n", + "# not set then the current working directory is used. Use pathlib.Path if setting this explicitly.\n", + "prefix = default_data_prefix()" + ] + }, { "cell_type": "code", "execution_count": null, @@ -32,7 +45,7 @@ "source": [ "# Make sure that the .proto files under protos/ have been compiled to .py, and\n", "# that those generated .py files are importable.]\n", - "compiled_dir = ensure_compiled_protos_are_importable()" + "compiled_dir = ensure_compiled_protos_are_importable(prefix=prefix)" ] }, { @@ -43,7 +56,7 @@ "outputs": [], "source": [ "# Load the runtime profile data\n", - "all_data = load_profiler_data()\n", + "all_data = load_profiler_data(prefix)\n", "# Remove some detail from the autotuner\n", "all_data = remove_autotuning_detail(all_data)\n", "# Align GPU timestamps across profiles collected by different Nsight Systems processes\n", @@ -82,7 +95,7 @@ "source": [ "This data frame has a three-level index:\n", "- `ProgramId` is an integer ID that uniquely identifies the XLA module\n", - "- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 1, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n", + "- This is the `ProgramExecution`-th execution of the module within the profiles. You may see this starting from 2, not 0, because of the `warmup_removal_heuristics` option passed to `load_profiler_data`.\n", "- `Device` is the global (across multiple nodes and processes) index of the GPU on which the module execution took place\n", "\n", "The columns are as follows:\n", @@ -90,8 +103,6 @@ "- `NumThunks`: the number of thunks executed inside this module execution\n", "- `ProjStartMs`: the timestamp of the start of the module execution on the GPU, in milliseconds\n", "- `ProjDurMs`: the duration of the module execution on the GPU, in milliseconds\n", - "- `OrigStartMs`: the timestamp of the start of the module launch **on the host**, in milliseconds. *i.e.* `ProjStartMs-OrigStartMs` is something like the launch latency of the first kernel\n", - "- `OrigDurMs`: the duration of the module launch **on the host**, in milliseconds\n", "- `LocalDevice`: the index within the node/slice of the GPU on which the module execution took place\n", "- `Process`: the global (across multiple nodes) index of the process\n", "- `Slice`: the global index of the node/slice; devices within the same node/slice should have faster interconnects than to devices in different slices\n", @@ -117,13 +128,13 @@ "id": "7727d800-13d3-4505-89e8-80a5fed63512", "metadata": {}, "source": [ - "Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `module_df`.\n", + "Here the index has four levels. `ProgramId`, `ProgramExecution` and `Device` have the same meanings as in `steady_state.module`.\n", "The fourth level (in the 3rd position) shows that this row is the `ThunkIndex`-th thunk within the `ProgramExecution`-th execution of XLA module `ProgramId`.\n", "Note that a given thunk can be executed multiple times within the same module, so indexing on the thunk name would not be unique.\n", "\n", "The columns are as follows:\n", "- `Name`: the name of the thunk; this should be unique within a given `ProgramId` and can be used as a key to look up XLA metadata\n", - "- `ProjStartMs`, `OrigStartMs`, `OrigDurMs`: see above, same meaning as in `module_df`.\n", + "- `ProjStartMs`: see above, same meaning as in `steady_state.module`.\n", "- `Communication`: does this thunk represent communication between GPUs (*i.e.* a NCCL collective)? XLA overlaps communication and computation kernels, and `load_profiler_data` triggers an overlap calculation. `ProjDurMs` for a communication kernel shows only the duration that was **not** overlapped with computation kernels, while `ProjDurHiddenMs` shows the duration that **was** overlapped.\n", "- This is the `ThunkExecution`-th execution of this thunk for this `(ProgramId, ProgramExecution, Device)`\n", "\n", @@ -299,7 +310,7 @@ "# Print out the largest entries adding up to at least this fraction of the total\n", "threshold = 0.97\n", "compile_summary[\"FracNonChild\"] = compile_summary[\"DurNonChildMs\"] / total_compile_time\n", - "print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-9:.2f}s compilation time\")\n", + "print(f\"Top {threshold:.0%}+ of {total_compile_time*1e-3:.2f}s compilation time\")\n", "for row in compile_summary[\n", " compile_summary[\"FracNonChild\"].cumsum() <= threshold\n", "].itertuples():\n", @@ -378,7 +389,7 @@ " program_id, thunk_name = thunk_row.Index\n", " # policy=\"all\" means we may get a set of HloProto instead of a single one, if\n", " # nsys-jax-combine was used and the dumped metadata were not bitwise identical\n", - " hlo_modules = xla_module_metadata(program_id, policy=\"all\")\n", + " hlo_modules = xla_module_metadata(program_id, policy=\"all\", prefix=prefix)\n", " thunk_opcode, inst_metadata, inst_frames = hlo_modules.unique_result(\n", " lambda proto: instructions_and_frames(proto, thunk_name)\n", " )\n",