Skip to content

Commit

Permalink
kfac performance optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
sangkeun00 committed Nov 26, 2023
1 parent bfddd80 commit ad9ecdb
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 15 deletions.
2 changes: 1 addition & 1 deletion analog/analog.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,4 +435,4 @@ def print_tracked_modules(self) -> None:
for k, v in self.logging_handler.modules_to_name.items():
get_logger().info(f"{v}: {k}")
repr_dim += k.weight.data.numel()
get_logger().info(f"Total number of parameters: {repr_dim}\n")
get_logger().info(f"Total number of parameters: {repr_dim:,}\n")
23 changes: 17 additions & 6 deletions analog/hessian/kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def parse_config(self) -> None:

@torch.no_grad()
def on_exit(self, current_log=None, update_hessian=True) -> None:
torch.cuda.current_stream().synchronize()
if update_hessian:
if self.reduce:
raise NotImplementedError
Expand All @@ -48,16 +49,26 @@ def update_hessian(
if self.reduce or self.ekfac:
return
# extract activations
activation = self.extract_activations(module, mode, data)

# compute covariance
covariance = torch.matmul(torch.t(activation), activation).cpu().detach()
activation = self.extract_activations(module, mode, data).detach()

# update covariance
if deep_get(self.hessian_state, [module_name, mode]) is None:
self.hessian_state[module_name][mode] = torch.zeros_like(covariance)
self.hessian_state[module_name][mode] = torch.zeros(
(activation.shape[-1], activation.shape[-1])
).pin_memory()
self.sample_counter[module_name][mode] = 0
self.hessian_state[module_name][mode].add_(covariance)

# move to gpu
if activation.is_cuda:
hessian_state_gpu = self.hessian_state[module_name][mode].to(
device=activation.device
)
hessian_state_gpu.addmm_(activation.t(), activation)
self.hessian_state[module_name][mode] = hessian_state_gpu.to(
device="cpu", non_blocking=True
)
else:
self.hessian_state[module_name][mode].addmm_(activation.t(), activation)
self.sample_counter[module_name][mode] += self.get_sample_size(data, mask)

@torch.no_grad()
Expand Down
8 changes: 5 additions & 3 deletions analog/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,11 @@ def _forward_hook_fn(
if self.mask is not None:
if len(self.mask.shape) != len(activations.shape):
assert len(self.mask.shape) == len(activations.shape) - 1
assert self.mask.shape == activations.shape[:-1]
self.mask = self.mask.unsqueeze(-1)
activations = activations * self.mask
if self.mask.shape[-1] == activations.shape[-2]:
activations = activations * self.mask.unsqueeze(-1)
else:
if self.mask.shape[-1] == activations.shape[-1]:
activations = activations * self.mask

if self.hessian and self.hessian_type == "kfac":
self.hessian_handler.update_hessian(
Expand Down
10 changes: 5 additions & 5 deletions examples/bert_influence/compute_influence.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import time
import argparse

from tqdm import tqdm

import torch
import torch.nn.functional as F
from analog import AnaLog
from analog.analysis import InfluenceFunction
from analog.utils import DataIDGenerator
from tqdm import tqdm

from pipeline import construct_model, get_loaders
from utils import set_seed
from utils import construct_model, get_loaders, set_seed

parser = argparse.ArgumentParser("GLUE Influence Analysis")
parser.add_argument("--data_name", type=str, default="sst2")
Expand All @@ -20,17 +21,16 @@
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
set_seed(0)


# model
model = construct_model(args.data_name, ckpt_path)
model = construct_model(args.data_name)
model.load_state_dict(
torch.load(f"files/checkpoints/0/{args.data_name}_epoch_3.pt", map_location="cpu")
)
model.to(DEVICE)
model.eval()

# data
_, eval_train_loader, test_loader = get_loaders(data_name=data_name)
_, eval_train_loader, test_loader = get_loaders(data_name=args.data_name)

# Set-up
analog = AnaLog(project="test")
Expand Down

0 comments on commit ad9ecdb

Please sign in to comment.