Commit 8021ad2: clean up examples
sangkeun00 committed May 14, 2024 (1 parent: 9298db7)
Showing 10 changed files with 75 additions and 428 deletions.
examples/bert_influence/compute_influence.py (2 changes: 1 addition & 1 deletion)

@@ -14,7 +14,7 @@ def main():
     parser.add_argument("--project", type=str, default="sst2")
     parser.add_argument("--config_path", type=str, default="./config.yaml")
     parser.add_argument("--data_name", type=str, default="sst2")
-    parser.add_argument("--damping", type=float, default=1e-5)
+    parser.add_argument("--damping", type=float, default=None)
     args = parser.parse_args()

     # prepare model & data loader
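The only change swaps the hardcoded damping default (1e-5) for None, deferring the choice to LogIX. As a rough sketch of the kind of fallback this enables (an assumption about common influence-function practice, not LogIX's documented behavior):

    # Hypothetical helper, not part of LogIX: resolve an unset damping term
    # relative to the Hessian spectrum instead of a fixed constant.
    def resolve_damping(damping, eigenvalues):
        if damping is not None:
            return damping
        # Fall back to a small fraction of the mean eigenvalue.
        return 0.1 * sum(eigenvalues) / len(eigenvalues)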
examples/bert_influence/extract_log.py (7 changes: 6 additions & 1 deletion)

@@ -15,6 +15,9 @@ def main():
     parser.add_argument("--config_path", type=str, default="./config.yaml")
     parser.add_argument("--batch_size", type=int, default=32)
     parser.add_argument("--data_name", type=str, default="sst2")
+    parser.add_argument("--lora", type=str, default="random")
+    parser.add_argument("--hessian", type=str, default="raw")
+    parser.add_argument("--save", type=str, default="grad")
     args = parser.parse_args()

     set_seed(0)

@@ -31,7 +34,9 @@ def main():

     # LogIX
     run = logix.init(args.project, config=args.config_path)
-    scheduler = logix.LogIXScheduler(run, lora=True)
+    scheduler = logix.LogIXScheduler(
+        run, lora=args.lora, hessian=args.hessian, save=args.save
+    )

     logix.watch(model)
     for _ in scheduler:
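Assembled from the added lines, the new extraction flow in this script looks roughly like this (a sketch: model and the SST-2 data loader are built earlier in the file and elided here, and the inner loop follows the CIFAR/MNIST examples below):

    import logix

    run = logix.init("sst2", config="./config.yaml")
    scheduler = logix.LogIXScheduler(run, lora="random", hessian="raw", save="grad")

    logix.watch(model)
    for _ in scheduler:
        for batch in data_loader:
            # forward/backward inside a logix context, as in the other examples
            ...
        logix.finalize()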
examples/cifar (influence-computation script; the file header for this diff was not captured on the page)

@@ -1,8 +1,10 @@
 import time
 import argparse

+from tqdm import tqdm
 import torch

-from logix import LogIX
+from logix import LogIX, LogIXScheduler
 from logix.utils import DataIDGenerator
 from logix.analysis import InfluenceFunction

@@ -14,12 +16,15 @@
 parser = argparse.ArgumentParser("CIFAR Influence Analysis")
 parser.add_argument("--data", type=str, default="cifar10", help="cifar10/100")
 parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
-parser.add_argument("--damping", type=float, default=1e-5)
-parser.add_argument("--resume", action="store_true")
+parser.add_argument("--damping", type=float, default=None)
+parser.add_argument("--lora", type=str, default="none")
+parser.add_argument("--hessian", type=str, default="raw")
+parser.add_argument("--save", type=str, default="grad")
 args = parser.parse_args()

 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+
 model = construct_rn9().to(DEVICE)

 # Get a single checkpoint (first model_id and last epoch).

@@ -32,52 +37,51 @@
 train_loader = dataloader_fn(
     batch_size=512, split="train", shuffle=False, subsample=True, augment=False
 )
-query_loader = dataloader_fn(
+test_loader = dataloader_fn(
     batch_size=1, split="valid", shuffle=False, indices=args.eval_idxs, augment=False
 )

 logix = LogIX(project="test", config="./config.yaml")
+logix_scheduler = LogIXScheduler(
+    logix, lora=args.lora, hessian=args.hessian, save=args.save
+)

 # Gradient & Hessian logging
 logix.watch(model)
-logix.setup({"log": "grad", "save": "grad", "statistic": "kfac"})

-if not args.resume:
-    id_gen = DataIDGenerator()
-    for inputs, targets in train_loader:
-        data_id = id_gen(inputs)
-        with logix(data_id=data_id):
+id_gen = DataIDGenerator()
+for epoch in logix_scheduler:
+    for inputs, targets in tqdm(train_loader, desc="Extracting log"):
+        with logix(data_id=id_gen(inputs)):
             inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
             model.zero_grad()
             outs = model(inputs)
             loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum")
             loss.backward()
     logix.finalize()
-else:
-    logix.initialize_from_log()

 # Influence Analysis
-logix.eval()
 log_loader = logix.build_log_dataloader()

-logix.add_analysis({"influence": InfluenceFunction})
-query_iter = iter(query_loader)
-with logix(log=["grad"]) as al:
-    test_input, test_target = next(query_iter)
-    test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
-    model.zero_grad()
-    test_out = model(test_input)
-    test_loss = torch.nn.functional.cross_entropy(
-        test_out, test_target, reduction="sum"
-    )
-    test_loss.backward()
-    test_log = al.get_log()
-    start = time.time()
-    result = logix.influence.compute_influence_all(
-        test_log, log_loader, damping=args.damping
-    )
+logix.eval()
+logix.setup({"log": "grad"})
+for test_input, test_target in test_loader:
+    with logix(data_id=id_gen(test_input)):
+        test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
+        model.zero_grad()
+        test_out = model(test_input)
+        test_loss = torch.nn.functional.cross_entropy(
+            test_out, test_target, reduction="sum"
+        )
+        test_loss.backward()
+        test_log = logix.get_log()
+
+    # Influence computation
+    result = logix.influence.compute_influence_all(
+        test_log, log_loader, damping=args.damping
+    )
+    break

 # Save
 if_scores = result["influence"].numpy().tolist()
-torch.save(if_scores, "./if_baseline.pt")
-print("Computation time:", time.time() - start)
+torch.save(if_scores, "if_logix.pt")
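The result dict returned by compute_influence_all is consumed the same way across these examples. A minimal sketch (names as in the diffs; the shape comment is an inference from the MNIST example's [0] indexing below, not a documented guarantee):

    scores = result["influence"]  # presumably (num_test, num_train) influence scores
    _, top_influential = torch.topk(scores, k=10)  # highest-influence training points
    torch.save(scores.numpy().tolist(), "if_logix.pt")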
examples/cifar/compute_influences_pca.py (87 changes: 0 additions & 87 deletions)

This file was deleted.
examples/cifar/config.yaml (2 changes: 2 additions & 0 deletions)

@@ -0,0 +1,2 @@
+lora:
+  init: pca
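This new config presumably switches the LoRA projection initialization to PCA (extract_log.py above exposes the alternative through --lora random), which would fold the standalone compute_influences_pca.py example deleted by this commit into the main script.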
examples/mnist/compare.py (13 changes: 3 additions & 10 deletions)

@@ -3,14 +3,7 @@


 kfac = torch.load("if_kfac.pt")
-kfac_true = torch.load("if_kfac_true.pt")
 ekfac = torch.load("if_ekfac.pt")
-ekfac_true = torch.load("if_ekfac_true.pt")
-logix_kfac = torch.load("if_logix.pt")
-logix_lora = torch.load("if_logix_lora64_pca.pt")
-logix_ekfac = torch.load("if_logix_ekfac.pt")
-print("[KFAC (base) vs KFAC (logix)] pearson:", pearsonr(kfac, logix_kfac))
-print("[KFAC (base) vs LoRA (logix)] pearson:", pearsonr(kfac, logix_lora))
-print("[EKFAC (base) vs EKFAC (logix)] pearson:", pearsonr(ekfac, logix_ekfac))
-print("[EKFAC (base) vs KFAC (logix)] pearson:", pearsonr(ekfac, logix_kfac))
-print("[EKFAC (base) vs LoRA (logix)] pearson:", pearsonr(ekfac, logix_lora))
+logix = torch.load("if_logix.pt")
+print("[KFAC (base) vs LogIX] pearson:", pearsonr(kfac, logix))
+print("[EKFAC (base) vs LogIX] pearson:", pearsonr(ekfac, logix))
examples/mnist/compute_influences.py (59 changes: 29 additions & 30 deletions)

@@ -1,9 +1,11 @@
 import time
 import argparse

+from tqdm import tqdm
 import torch
-from logix import LogIX
+from logix import LogIX, LogIXScheduler
 from logix.utils import DataIDGenerator
 from logix.analysis import InfluenceFunction

 from train import (
     get_mnist_dataloader,

@@ -15,7 +17,9 @@
 parser.add_argument("--data", type=str, default="mnist", help="mnist or fmnist")
 parser.add_argument("--eval-idxs", type=int, nargs="+", default=[0])
 parser.add_argument("--damping", type=float, default=1e-5)
-parser.add_argument("--resume", action="store_true")
+parser.add_argument("--hessian", type=str, default="none")
+parser.add_argument("--lora", type=str, default="none")
+parser.add_argument("--save", type=str, default="grad")
 args = parser.parse_args()

 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -30,56 +34,51 @@
 train_loader = dataloader_fn(
     batch_size=512, split="train", shuffle=False, subsample=True
 )
-query_loader = dataloader_fn(
+test_loader = dataloader_fn(
     batch_size=1, split="valid", shuffle=False, indices=args.eval_idxs
 )

-logix = LogIX(project="test")
+logix = LogIX(project="test", config="./config.yaml")
+scheduler = LogIXScheduler(logix, lora=args.lora, hessian=args.hessian, save=args.save)

 # Gradient & Hessian logging
 logix.watch(model)
-logix.setup({"log": "grad", "save": "grad", "statistic": "kfac"})
+id_gen = DataIDGenerator()

-start = time.time()
-if not args.resume:
-    id_gen = DataIDGenerator()
-    for inputs, targets in train_loader:
-        data_id = id_gen(inputs)
-        with logix(data_id=data_id):
+for epoch in scheduler:
+    for inputs, targets in tqdm(train_loader):
+        with logix(data_id=id_gen(inputs)):
             inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
             model.zero_grad()
             outs = model(inputs)
             loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum")
             loss.backward()
     logix.finalize()
-else:
-    logix.initialize_from_log()
-print("logging time:", time.time() - start)

 # Influence Analysis
-log_loader = logix.build_log_dataloader()
+log_loader = logix.build_log_dataloader(batch_size=64, num_workers=0)

-query_iter = iter(query_loader)
-# logix.add_analysis({"influence": InfluenceFunction})
+logix.setup({"log": "grad"})
 logix.eval()
-with logix(data_id=["test"]):
-    test_input, test_target = next(query_iter)
-    test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
-    model.zero_grad()
-    test_out = model(test_input)
-    test_loss = torch.nn.functional.cross_entropy(
-        test_out, test_target, reduction="sum"
-    )
-    test_loss.backward()
-    test_log = logix.get_log()
-start = time.time()
-result = logix.influence.compute_influence_all(
-    test_log, log_loader, damping=args.damping
-)
+for test_input, test_target in test_loader:
+    with logix(data_id=id_gen(test_input)):
+        test_input, test_target = test_input.to(DEVICE), test_target.to(DEVICE)
+        model.zero_grad()
+        test_out = model(test_input)
+        test_loss = torch.nn.functional.cross_entropy(
+            test_out, test_target, reduction="sum"
+        )
+        test_loss.backward()

+    test_log = logix.get_log()
+    result = logix.influence.compute_influence_all(
+        test_log, log_loader, damping=args.damping
+    )
+    break
 _, top_influential_data = torch.topk(result["influence"], k=10)

 # Save
 if_scores = result["influence"].cpu().numpy().tolist()[0]
 torch.save(if_scores, "if_logix.pt")
-print("Computation time:", time.time() - start)
 print("Top influential data indices:", top_influential_data.cpu().numpy().tolist())