Skip to content

Commit

Permalink
profiler
Browse files Browse the repository at this point in the history
  • Loading branch information
eatpk committed Feb 25, 2024
1 parent b1d86d3 commit a9a3853
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 8 deletions.
2 changes: 1 addition & 1 deletion analog/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from analog.lora import LoRAHandler
from analog.lora.utils import is_lora
from analog.state import AnaLogState
from analog.timer.timer import DeviceFunctionTimer
from analog.monitor_util.timer import DeviceFunctionTimer
from analog.utils import (
get_logger,
get_rank,
Expand Down
2 changes: 1 addition & 1 deletion analog/analysis/influence_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from einops import einsum, rearrange, reduce
from analog.config import InfluenceConfig
from analog.state import AnaLogState
from analog.timer.timer import DeviceFunctionTimer
from analog.monitor_util.timer import DeviceFunctionTimer
from analog.utils import get_logger, nested_dict
from analog.analysis.utils import synchronize_device

Expand Down
2 changes: 1 addition & 1 deletion analog/logging/log_saver.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from concurrent.futures import ThreadPoolExecutor
import torch

from analog.timer.timer import HostFunctionTimer, DeviceFunctionTimer
from analog.monitor_util.timer import DeviceFunctionTimer
from analog.utils import nested_dict, to_numpy, get_rank
from analog.logging.mmap import MemoryMapHandler

Expand Down
2 changes: 1 addition & 1 deletion analog/logging/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from analog.logging.option import LogOption
from analog.logging.log_saver import LogSaver
from analog.logging.utils import compute_per_sample_gradient
from analog.timer.timer import DeviceFunctionTimer
from analog.monitor_util.timer import DeviceFunctionTimer
from analog.utils import get_logger


Expand Down
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .timer import FunctionTimer, Timer
from .profiler import memory_profiler
26 changes: 26 additions & 0 deletions analog/monitor_util/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import torch
import functools
from torch.profiler import profile, ProfilerActivity


def memory_profiler(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
activities = [ProfilerActivity.CPU]
if device.type == "cuda":
activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, profile_memory=True) as prof:
result = func(*args, **kwargs)

print(
prof.key_averages().table(
sort_by="self_cuda_memory_usage"
if device.type == "cuda"
else "self_cpu_memory_usage"
)
)
return result

return wrapper
8 changes: 5 additions & 3 deletions analog/timer/timer.py → analog/monitor_util/timer.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,9 @@ class DeviceFunctionTimer(FunctionTimer):
if torch.cuda.is_available():
host_timer = False
else:
logging.warning("CUDA is not set, setting the timer is set to host timer.")
logging.warning(
"CUDA is not set, setting the monitor_util is set to host monitor_util."
)
host_timer = True


Expand All @@ -135,13 +137,13 @@ def __init__(self):
def start_timer(self, name, host_timer=False):
if host_timer:
if name in self.timers["cpu"]:
logging.warning(f"timer for {name} already exist")
logging.warning(f"monitor_util for {name} already exist")
return
start_time = time.time()
self.timers["cpu"][name] = [start_time]
else:
if name in self.timers["gpu"]:
logging.warning(f"timer for {name} already exist")
logging.warning(f"monitor_util for {name} already exist")
return
self.is_synchronized = False
start_event = torch.cuda.Event(enable_timing=True)
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist_influence/compute_influences.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
construct_mlp,
)

from analog.timer import FunctionTimer
from analog.monitor_util import FunctionTimer

parser = argparse.ArgumentParser("MNIST Influence Analysis")
parser.add_argument("--data", type=str, default="mnist", help="mnist or fmnist")
Expand Down

0 comments on commit a9a3853

Please sign in to comment.