Skip to content

Commit

Permalink
fix: drop psutil-based host memory tracking from FunctionTimer and remove the psutil dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
eatpk committed Feb 20, 2024
1 parent 81ce797 commit b1d86d3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 27 deletions.
42 changes: 16 additions & 26 deletions analog/timer/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import functools

import torch
import psutil
import os


def get_gpu_memory(device_index=None):
Expand All @@ -15,15 +13,6 @@ def get_gpu_max_memory(device_index=None):
return torch.cuda.max_memory_allocated(device_index)


def get_host_memory():
    """Return the resident set size (RSS) of the current process, in bytes."""
    current = psutil.Process(os.getpid())
    return current.memory_info().rss


def get_cpu_swap_memory():
    """Return the system-wide swap memory currently in use, in bytes."""
    swap = psutil.swap_memory()
    return swap.used


class FunctionTimer:
log = {}

Expand All @@ -40,52 +29,44 @@ def wrapper(*args, **kwargs):

@classmethod
def _host_timer_wrapper(cls, func, label, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` and record host-side time/memory usage.

    Measures wall-clock duration and the change in process resident memory
    around the call, then appends an entry
    ``{"time_delta": seconds, "memory_delta": MiB}`` to ``cls.log[label]``.

    Args:
        func: The callable being timed.
        label: Key under which measurements are accumulated in ``cls.log``.
        *args, **kwargs: Forwarded to ``func``.

    Returns:
        Whatever ``func`` returns, unchanged.
    """
    before_memory = get_host_memory()
    start_time = time.time()
    result = func(*args, **kwargs)
    end_time = time.time()
    after_memory = get_host_memory()
    entry = {
        "time_delta": end_time - start_time,
        # after - before so that an allocation shows up as a *positive*
        # delta (the original `before - after` was inverted).
        # `>> 20` converts bytes to MiB.
        "memory_delta": (after_memory - before_memory) >> 20,
    }
    # setdefault replaces the duplicated if/else append logic.
    cls.log.setdefault(label, []).append(entry)
    return result

@classmethod
def _device_timer_wrapper(cls, func, label, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` and record GPU time/memory usage.

    Uses CUDA events to time the call on the device and the CUDA caching
    allocator to measure the change in allocated GPU memory, then appends
    ``{"time_delta": seconds, "memory_delta": MiB}`` to ``cls.log[label]``.

    Args:
        func: The callable being timed (assumed to launch CUDA work).
        label: Key under which measurements are accumulated in ``cls.log``.
        *args, **kwargs: Forwarded to ``func``.

    Returns:
        Whatever ``func`` returns, unchanged.
    """
    before_memory = get_gpu_memory()
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    start_event.record()
    result = func(*args, **kwargs)
    end_event.record()
    # elapsed_time() and the allocator statistics are only meaningful once
    # the recorded work has actually completed on the device, so sync
    # *before* reading either (the original read memory pre-sync).
    torch.cuda.synchronize()
    after_memory = get_gpu_memory()
    entry = {
        "time_delta": start_event.elapsed_time(end_event)
        / 1000,  # milliseconds -> seconds
        # after - before so growth is positive (original sign was
        # inverted); `>> 20` converts bytes to MiB.
        "memory_delta": (after_memory - before_memory) >> 20,
    }
    cls.log.setdefault(label, []).append(entry)
    return result
Expand All @@ -110,15 +91,24 @@ def get_log(cls):

@classmethod
def print_log(cls):
    """Pretty-print the accumulated timer log as a fixed-width table.

    Prints a banner, a two-column header, and one row per label showing
    the label (truncated to 50 characters) and the total time in seconds
    summed over every recorded call for that label.

    Note: the diff-merged original both summed ``time_delta`` twice (once
    per leftover loop) and printed each label in two formats; this emits
    each label exactly once with a single accumulation pass.
    """
    print(
        "###########################################################################"
    )
    print(
        "################################ TIMER LOG ################################"
    )
    header = f"{'Label':<50} | {'Total Time (sec)':>20}"
    print(header)
    print("-" * len(header))
    for label, details in cls.log.items():
        # Total time across all recorded calls; entries missing
        # "time_delta" contribute 0.
        sum_time = sum(entry.get("time_delta", 0) for entry in details)
        # Truncate to 47 characters + "..." when the label exceeds 50.
        display_label = (label[:47] + "...") if len(label) > 50 else label
        row = f"{display_label:<50} | {sum_time:>20.4f}"
        print(row)


class HostFunctionTimer(FunctionTimer):
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ numpy
pandas
torch
einops
psutil

pyyaml

0 comments on commit b1d86d3

Please sign in to comment.