fast_flops.py
from functools import wraps

import nvtx


def is_jax_tensor(tensor):
    """
    Check if the given tensor is a JAX tensor.
    """
    # String-compare the type name so jax is not imported just for the check.
    return str(type(tensor)) in (
        "<class 'jaxlib.xla_extension.DeviceArray'>",
        "<class 'jaxlib.xla_extension.ArrayImpl'>",
    )


def is_jax_dynamic_tensor(tensor):
    """
    Check if the given tensor is a traced (dynamic) JAX tensor, for which
    block_until_ready should be skipped.
    """
    return str(type(tensor)) == "<class 'jax._src.interpreters.partial_eval.DynamicJaxprTracer'>"


def is_torch_tensor(tensor):
    """
    Check if the given tensor is a PyTorch tensor.
    """
    return str(type(tensor)).startswith("<class 'torch.Tensor")


def synchronize(result):
    """
    Block until the device work behind `result` has finished.
    """
    if is_jax_tensor(result):
        import jax
        jax.tree_util.tree_map(lambda x: x.block_until_ready(), result)
    elif is_torch_tensor(result):
        import torch
        torch.cuda.synchronize()
    elif is_jax_dynamic_tensor(result):
        # It doesn't make sense to block on traced tensors.
        pass
    else:
        raise ValueError(f"{type(result)} not supported")


def flops_counter(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # A few warmup laps to cleanse the JIT.
        for _ in range(10):
            result = func(*args, **kwargs)
            synchronize(result)
        # Wrap the measured call in an NVTX range so an external profiler
        # (e.g. Nsight Systems) can attribute its kernels to it.
        nvtx_range = nvtx.start_range("profile")
        result = func(*args, **kwargs)
        # Sync again so the range spans execution, not just kernel launch.
        synchronize(result)
        nvtx.end_range(nvtx_range)
        return result
    return wrapper
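

# A minimal usage sketch (not part of the original file; assumes PyTorch and a
# CUDA device are available). Running this under an external profiler, e.g.
# `nsys profile python fast_flops.py`, lets the "profile" NVTX range isolate
# the kernels of the measured call so their FLOPs can be counted.
if __name__ == "__main__":
    import torch

    @flops_counter
    def matmul(a, b):
        return a @ b

    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    matmul(a, b)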