-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
8e84706
commit 7dbddc0
Showing
4 changed files
with
280 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
import time | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from analog import AnaLog | ||
from analog.analysis import InfluenceFunction | ||
from analog.utils import DataIDGenerator | ||
from tqdm import tqdm | ||
|
||
from pipeline import construct_model, get_loaders | ||
from utils import set_seed | ||
|
||
|
||
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
set_seed(0) | ||
|
||
|
||
def single_checkpoint_influence( | ||
data_name: str, | ||
model_name: str, | ||
ckpt_path: str, | ||
save_name: str, | ||
train_batch_size=4, | ||
test_batch_size=4, | ||
train_indices=None, | ||
test_indices=None, | ||
): | ||
# model | ||
model = construct_model(model_name, ckpt_path) | ||
model.to(DEVICE) | ||
model.eval() | ||
|
||
# data | ||
_, eval_train_loader, test_loader = get_loaders(data_name=data_name) | ||
|
||
# Set-up | ||
analog = AnaLog(project="test", config="/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/config.yaml") | ||
|
||
# Hessian logging | ||
analog.watch(model, type_filter=[torch.nn.Linear], lora=False) | ||
id_gen = DataIDGenerator() | ||
for batch in tqdm(eval_train_loader, desc="Hessian logging"): | ||
data_id = id_gen(batch["input_ids"]) | ||
with analog(data_id=data_id, log=[], save=False): | ||
inputs = ( | ||
batch["input_ids"].to(DEVICE), | ||
batch["token_type_ids"].to(DEVICE), | ||
batch["attention_mask"].to(DEVICE), | ||
) | ||
model.zero_grad() | ||
outputs = model(*inputs) | ||
|
||
logits = outputs.view(-1, outputs.shape[-1]) | ||
labels = batch["labels"].view(-1).to(DEVICE) | ||
loss = F.cross_entropy( | ||
logits, labels, reduction="sum", ignore_index=-100 | ||
) | ||
loss.backward() | ||
analog.finalize() | ||
|
||
# Compressed gradient logging | ||
analog.add_lora(model, parameter_sharing=False) | ||
for batch in tqdm(eval_train_loader, desc="Compressed gradient logging"): | ||
data_id = id_gen(batch["input_ids"]) | ||
with analog(data_id=data_id, log=["grad"], save=True): | ||
inputs = ( | ||
batch["input_ids"].to(DEVICE), | ||
batch["token_type_ids"].to(DEVICE), | ||
batch["attention_mask"].to(DEVICE), | ||
) | ||
model.zero_grad() | ||
outputs = model(*inputs) | ||
|
||
logits = outputs.view(-1, outputs.shape[-1]) | ||
labels = batch["labels"].view(-1).to(DEVICE) | ||
loss = F.cross_entropy( | ||
logits, labels, reduction="sum", ignore_index=-100, | ||
) | ||
loss.backward() | ||
analog.finalize() | ||
|
||
# Compute influence | ||
log_loader = analog.build_log_dataloader() | ||
analog.add_analysis({"influence": InfluenceFunction}) | ||
test_iter = iter(test_loader) | ||
with analog(log=["grad"], test=True) as al: | ||
test_batch = next(test_iter) | ||
test_inputs = ( | ||
test_batch["input_ids"].to(DEVICE), | ||
test_batch["token_type_ids"].to(DEVICE), | ||
test_batch["attention_mask"].to(DEVICE), | ||
) | ||
test_target = test_batch["labels"].to(DEVICE) | ||
model.zero_grad() | ||
test_outputs = model(*test_inputs) | ||
|
||
test_logits = test_outputs.view(-1, outputs.shape[-1]) | ||
test_labels = test_batch["labels"].view(-1).to(DEVICE) | ||
test_loss = F.cross_entropy( | ||
test_logits, test_labels, reduction="sum", ignore_index=-100, | ||
) | ||
test_loss.backward() | ||
|
||
test_log = al.get_log() | ||
|
||
start = time.time() | ||
if_scores = analog.influence.compute_influence_all(test_log, log_loader) | ||
print("Computation time:", time.time() - start) | ||
|
||
# Save | ||
torch.save(if_scores, "if_analog.pt") | ||
|
||
|
||
def main(): | ||
data_name = "sst2" | ||
model_name = "bert-base-uncased" | ||
ckpt_path = "/data/tir/projects/tir6/general/hahn2/analog/examples/bert_influence/files/checkpoints/0/sst2_epoch_3.pt" | ||
save_name = "sst2_score_if.pt" | ||
|
||
single_checkpoint_influence( | ||
data_name=data_name, | ||
model_name=model_name, | ||
ckpt_path=ckpt_path, | ||
save_name=save_name, | ||
) |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import os | ||
import time | ||
import argparse | ||
from typing import Optional, Tuple | ||
|
||
from tqdm import tqdm | ||
import math | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import evaluate | ||
from torch.nn import CrossEntropyLoss | ||
|
||
from utils import clear_gpu_cache, set_seed, construct_model, get_loaders | ||
|
||
|
||
parser = argparse.ArgumentParser("MNIST Influence Analysis") | ||
parser.add_argument("--data_name", type=str, default="sst2") | ||
parser.add_argument("--num_train", type=int, default=1) | ||
args = parser.parse_args() | ||
|
||
os.makedirs("files/", exist_ok=True) | ||
os.makedirs("files/checkpoints", exist_ok=True) | ||
|
||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | ||
train_loader, _, valid_loader = get_loaders(data_name=args.data_name) | ||
model = construct_model(data_name=args.data_name).to(device) | ||
|
||
def train( | ||
model: nn.Module, | ||
loader: torch.utils.data.DataLoader, | ||
model_id: int = 0, | ||
lr: float = 2e-5, | ||
weight_decay: float = 0.0, | ||
save_name: Optional[str] = None, | ||
) -> nn.Module: | ||
save = save_name is not None | ||
if save: | ||
os.makedirs(f"files/checkpoints/{model_id}", exist_ok=True) | ||
torch.save( | ||
model.state_dict(), | ||
f"files/checkpoints/{model_id}/{save_name}_epoch_0.pt", | ||
) | ||
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay) | ||
loss_fn = CrossEntropyLoss() | ||
epochs = 3 | ||
|
||
num_update_steps_per_epoch = math.ceil(len(loader)) | ||
assert math.ceil(len(loader)) == num_update_steps_per_epoch | ||
|
||
model.train() | ||
num_iter = 0 | ||
for epoch in range(epochs): | ||
for batch in tqdm(loader): | ||
batch = {k: v.to(device) for k, v in batch.items()} | ||
optimizer.zero_grad() | ||
outputs = model( | ||
batch["input_ids"], batch["token_type_ids"], batch["attention_mask"] | ||
) | ||
loss = loss_fn(outputs, batch["labels"]) | ||
loss.backward() | ||
optimizer.step() | ||
num_iter += 1 | ||
|
||
if save: | ||
torch.save( | ||
model.state_dict(), | ||
f"files/checkpoints/{model_id}/{save_name}_epoch_{epoch + 1}.pt", | ||
) | ||
return model | ||
|
||
|
||
def model_evaluate( | ||
model: nn.Module, loader: torch.utils.data.DataLoader | ||
) -> Tuple[float, float]: | ||
model.eval() | ||
# Task name does not really matter here. | ||
metric = evaluate.load("glue", "qnli") | ||
total_loss, total_num = 0.0, 0.0 | ||
for step, batch in enumerate(loader): | ||
batch = {k: v.to(device) for k, v in batch.items()} | ||
with torch.no_grad(): | ||
outputs = model( | ||
batch["input_ids"], batch["token_type_ids"], batch["attention_mask"] | ||
) | ||
total_loss += ( | ||
F.cross_entropy(outputs, batch["labels"], reduction="sum").cpu().item() | ||
) | ||
total_num += batch["input_ids"].shape[0] | ||
predictions = outputs.argmax(dim=-1) | ||
metric.add_batch( | ||
predictions=predictions, | ||
references=batch["labels"], | ||
) | ||
eval_metric = metric.compute() | ||
return total_loss / total_num, eval_metric["accuracy"] | ||
|
||
|
||
for i in range(args.num_train): | ||
print(f"Training {i}th model ...") | ||
start_time = time.time() | ||
|
||
set_seed(i) | ||
|
||
train( | ||
model=model, | ||
loader=train_loader, | ||
model_id=i, | ||
save_name=args.data_name, | ||
) | ||
|
||
_, valid_acc = model_evaluate(model=model, loader=valid_loader) | ||
print(f"Validation Accuracy: {valid_acc}") | ||
del model | ||
clear_gpu_cache() | ||
print(f"Took {time.time() - start_time} seconds.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters