Skip to content

Latest commit

 

History

History
78 lines (57 loc) · 2.21 KB

README.md

File metadata and controls

78 lines (57 loc) · 2.21 KB

Fast FLOPS

... and where to find them 🐇

Ever needed to report FLOPs for your Torch/JAX code. We got you covered ! (More importantly, here's a survey on why bunnies flop)

Borrowed from https://gitlab.com/NERSC/roofline-on-nvidia-gpus/-/tree/roofline-hackathon-2020

Nsight Compute installation needed. Make sure the GPU counters are enabled

bunny

(Image Credit: Tuong Phung)

Workflow (Warning: Extremely clunky right now)

  • Define your function that you wanna profile

    • JAX

      SIZE = 500
      x = jax.random.normal(jax.random.PRNGKey(0), (SIZE, SIZE))
      y = jax.random.normal(jax.random.PRNGKey(1), (SIZE, SIZE))
      
      def func(x, y):
          return jnp.einsum("ij,jk->ik", x, y)
    • Torch

      SIZE = 500
      x = torch.randn(SIZE, SIZE).to(device='cuda')
      y = torch.randn(SIZE, SIZE).to(device='cuda')
      
      def func(x, y):
          return torch.einsum("ij,jk->ik", x, y)
  • Wrap the function in the flops_counter decorator

    from fast_flops import func_flops
    
    @flops_counter
    def func_flops(func, x, y):
        return func(x, y)
  • Let it run through JIT (Don't worry we have warmups to cleanse the JIT overhead) and execute !

    • JAX

      func = jax.jit(func)
      func_flops(func, x, y)
    • Torch

      func = torch.compile(func, fullgraph=True, mode='max-autotune')
      func_flops(func, x, y)
  • The pipeline can be executed using

    bash launch_profiler.sh examples/matmul/test_matmul_torch.py

    with the output looking something like

    Measured Time: 3.196823494576654e-05
    Measured GFLOP/s: 2353.1630879907675
    Measured FLOPS: 67108864.0

TODOs

  • Only turn on ncu during the hot loop in addition to nvtx
  • Axe run_profiler.py and postprocess.py. Ideally should work with python decorated_file.py
  • Plotting utilieis and csv plumbing