nsys-jax: optimise data loading and .zip creation #1193

Open: olupton wants to merge 12 commits into main

@olupton olupton commented Dec 10, 2024

Some rough measurements on vanilla jax-nccl-test and 8xH100:

  • Profile collection, whole execution: 52s (nsys), 58s (nsys-jax with this PR), 1m5s (nsys-jax without this PR)
  • Profile collection, restricted range: 46s (nsys), 50s (nsys-jax with this PR), 55s (nsys-jax without this PR)
  • Communication analysis, whole execution: 1.1s (with this PR), 2.1s (without this PR)
  • Communication analysis, restricted range: 1.0s (with this PR), 1.7s (without this PR)

The differences are more pronounced on larger workloads with more activity.

The two bigger changes are:

  • Convert .csv to .parquet as part of nsys-jax, which avoids compressing .csv with Python's lzma module (slow and single-threaded). This speeds up both nsys-jax itself and subsequent data loading.
  • Use a new algorithm for calculating the hidden/exposed time of communication kernels when loading profile data. It adds a fast, pandas-friendly pass that identifies most non-overlapping kernels, so the relatively slow, pandas-unfriendly overlap calculation can be skipped for them. This also removes the assumption that there is no compute-compute overlap.
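
To illustrate the second point, here is a minimal sketch of the two-pass idea. It is not the code in this PR: the function name `exposed_comm_time`, the `start`/`end` column names and the toy data are all hypothetical, and the frames are assumed to have a unique index. The vectorised pass uses a running maximum of compute end times, so overlapping compute kernels are tolerated, and the Python loop runs only for communication kernels that the fast pass could not rule out.

```python
import numpy as np
import pandas as pd


def exposed_comm_time(compute: pd.DataFrame, comm: pd.DataFrame) -> pd.Series:
    """Exposed (not hidden by compute) time of each communication kernel.

    Hypothetical helper: both frames need numeric `start` and `end` columns.
    """
    exposed = (comm["end"] - comm["start"]).astype(float)
    if compute.empty:
        return exposed

    compute = compute.sort_values("start")
    starts = compute["start"].to_numpy()
    ends = compute["end"].to_numpy()
    # Running maximum of end times; valid even if compute kernels overlap.
    max_end = np.maximum.accumulate(ends)

    # Fast pass: count the compute kernels that start before each
    # communication kernel ends, then check whether any of those is still
    # running when the communication kernel starts.
    n_before = np.searchsorted(starts, comm["end"].to_numpy(), side="left")
    latest_end = np.where(
        n_before > 0, max_end[np.maximum(n_before, 1) - 1], -np.inf
    )
    may_overlap = latest_end > comm["start"].to_numpy()

    # Slow pass, only for kernels the fast pass could not rule out.
    for idx in comm.index[may_overlap]:
        s, e = comm.at[idx, "start"], comm.at[idx, "end"]
        # Clip compute kernels to [s, e); lo stays sorted because starts is.
        lo = np.clip(starts, s, e)
        hi = np.clip(ends, s, e)
        hidden, cursor = 0.0, s
        for a, b in zip(lo, hi):
            if b > cursor:  # add only the part not already covered
                hidden += b - max(a, cursor)
                cursor = b
        exposed.at[idx] = (e - s) - hidden
    return exposed


# Toy data: the first two compute kernels overlap each other.
compute = pd.DataFrame({"start": [0, 5, 30], "end": [10, 20, 40]})
comm = pd.DataFrame({"start": [12, 20, 35], "end": [18, 30, 50]})
exposed = exposed_comm_time(compute, comm)
```

In the toy data the first communication kernel is fully hidden, the second is skipped by the fast pass because no compute kernel is running during it, and the third is partly exposed.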

The remaining changes are tweaks to pandas usage, minor reorganisations that make Python profiles more informative, and minor bugfixes in the example Jupyter notebook.

@olupton olupton force-pushed the olupton/nsys-jax-python-opt branch 2 times, most recently from 4f1b629 to 608e45c Compare December 10, 2024 11:28
@olupton olupton force-pushed the olupton/nsys-jax-python-opt branch from 608e45c to 385c7f4 Compare December 10, 2024 14:10
@olupton olupton force-pushed the olupton/nsys-jax-python-opt branch 2 times, most recently from f3189d1 to 74a3d94 Compare December 10, 2024 15:43
@olupton olupton force-pushed the olupton/nsys-jax-python-opt branch from 74a3d94 to d8056e0 Compare December 10, 2024 15:58
@olupton olupton requested a review from gspschmid December 11, 2024 09:56