Merge pull request #27 from sangkeun00/mnist_ood_test
OOD example with MNIST and FMNIST
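
This example computes self-influence scores with AnaLog for in-distribution (ID) and out-of-distribution (OOD) query data: gradient and Hessian statistics are logged on the ID training set (MNIST or FMNIST), and the validation splits of both datasets are then scored. The premise is that OOD inputs tend to receive higher self-influence scores than ID inputs.

A typical invocation, assuming the repository root is importable and a checkpoint from the mnist_influence example exists at checkpoints/mnist_0_epoch_9.pt:

    python -m examples.mnist_uncertainty.compute_ood_influences --id-data mnist --damping 1e-5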
sangkeun00 authored Nov 27, 2023
2 parents 7440dda + 02e846d commit c44bb5f
Showing 2 changed files with 250 additions and 0 deletions.
137 changes: 137 additions & 0 deletions examples/mnist_uncertainty/compute_ood_influences.py
@@ -0,0 +1,137 @@
import argparse

import tqdm
import numpy as np
import torch

from analog import AnaLog
from analog.utils import DataIDGenerator
from analog.analysis import InfluenceFunction
from examples.mnist_influence.utils import (
get_mnist_dataloader,
get_fmnist_dataloader,
construct_mlp,
)
from examples.mnist_uncertainty.ood_utils import (
get_ood_input_processor,
)

parser = argparse.ArgumentParser("OOD Self-influence Score Analysis")
parser.add_argument(
"--id-data",
type=str,
default="mnist",
help="mnist or fmnist; OOD is set to the other one",
)
parser.add_argument("--damping", type=float, default=1e-5)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ood_data = "fmnist" if args.id_data == "mnist" else "mnist"
model = construct_mlp().to(DEVICE)

# Get a single checkpoint (first model_id and last epoch).
model.load_state_dict(
torch.load(f"checkpoints/{args.id_data}_0_epoch_9.pt", map_location="cpu")
)
model.eval()

id_dataloader_fn = (
get_mnist_dataloader if args.id_data == "mnist" else get_fmnist_dataloader
)
ood_dataloader_fn = (
get_mnist_dataloader if ood_data == "mnist" else get_fmnist_dataloader
)
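
# subsample=True keeps the loaders small so the example runs quickly;
# the ID train split is used for gradient/Hessian logging, and the
# validation splits of both datasets are scored below.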
id_train_loader = id_dataloader_fn(
batch_size=512, split="train", shuffle=False, subsample=True
)
id_query_loader = id_dataloader_fn(
batch_size=512,
split="valid",
shuffle=False,
subsample=True,
)
ood_query_loader = ood_dataloader_fn(
batch_size=512,
split="valid",
shuffle=False,
subsample=True,
)
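# get_ood_input_processor returns a transform that re-normalizes OOD
# inputs to match the ID training data's input statistics (see
# ood_utils.py).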
ood_input_processor = get_ood_input_processor(
source_data=ood_data, target_model=args.id_data
)

# Set-up
analog = AnaLog(project="test")

# Gradient & Hessian logging
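# One pass over the ID training data logs per-sample gradients and
# accumulates the statistics AnaLog uses for its Hessian approximation.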
analog.watch(model, lora=False)
id_gen = DataIDGenerator()
for inputs, targets in id_train_loader:
data_id = id_gen(inputs)
with analog(data_id=data_id, log=["grad"], hessian=True, save=True):
inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
model.zero_grad()
outs = model(inputs)
loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum")
loss.backward()
analog.finalize()

# Influence Analysis
analog.add_analysis({"influence": InfluenceFunction})
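
# Self-influence scores how strongly each query example influences its
# own loss; the premise here is that OOD inputs tend to score higher.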

print("Computing OOD self-influence scores...")
ood_self_influence_scores = []
for ood_test_input, ood_test_target in tqdm.tqdm(ood_query_loader):
with analog(log=["grad"], test=True) as al:
ood_test_input = ood_input_processor(ood_test_input)
ood_test_input, ood_test_target = (
ood_test_input.to(DEVICE),
ood_test_target.to(DEVICE),
)
model.zero_grad()
ood_test_out = model(ood_test_input)
ood_test_loss = torch.nn.functional.cross_entropy(
ood_test_out, ood_test_target, reduction="sum"
)
ood_test_loss.backward()
ood_test_log = al.get_log()
if_scores = analog.influence.compute_self_influence(
ood_test_log, damping=args.damping
)
    # .detach().cpu() keeps this safe if scores are returned on the GPU.
    ood_self_influence_scores.append(if_scores.detach().cpu().numpy().flatten())

print("Computing ID self-influence scores...")
id_self_influence_scores = []
for id_test_input, id_test_target in tqdm.tqdm(id_query_loader):
with analog(log=["grad"], test=True) as al:
id_test_input, id_test_target = id_test_input.to(DEVICE), id_test_target.to(
DEVICE
)
model.zero_grad()
id_test_out = model(id_test_input)
id_test_loss = torch.nn.functional.cross_entropy(
id_test_out, id_test_target, reduction="sum"
)
id_test_loss.backward()
id_test_log = al.get_log()
if_scores = analog.influence.compute_self_influence(
id_test_log, damping=args.damping
)
    # .detach().cpu() keeps this safe if scores are returned on the GPU.
    id_self_influence_scores.append(if_scores.detach().cpu().numpy().flatten())

# Save
ood_sif_scores = np.concatenate(ood_self_influence_scores)
id_sif_scores = np.concatenate(id_self_influence_scores)
torch.save(ood_sif_scores, "ood_sif_scores.pt")
torch.save(id_sif_scores, "id_sif_scores.pt")
print(
f"OOD self-influence scores: mean={ood_sif_scores.mean():.2f}, std={ood_sif_scores.std():.2f}"
)
print(
f"ID self-influence scores: mean={id_sif_scores.mean():.2f}, std={id_sif_scores.std():.2f}"
)
113 changes: 113 additions & 0 deletions examples/mnist_uncertainty/ood_utils.py
@@ -0,0 +1,113 @@
import os

import numpy as np
import torch

from examples.mnist_influence.utils import (
get_mnist_dataloader,
get_fmnist_dataloader,
)

MNIST = "mnist"
FMNIST = "fmnist"
CIFAR10 = "cifar10"
DATA_LIST = [MNIST, FMNIST, CIFAR10]


def compute_input_stats(dataloader: torch.utils.data.DataLoader):
"""
Compute the mean and standard deviation of the input data.
"""
# Initialize sum and sum of squares tensors
sum_tensor = torch.zeros_like(next(iter(dataloader))[0][0])
sum_sq_tensor = torch.zeros_like(sum_tensor)

# Make sure the dimensions are (channel, height, width)
assert len(sum_tensor.shape) == 3

num_images = 0
# Loop over the DataLoader
for images, _ in dataloader:
sum_tensor += images.sum(dim=0) # Summing over the batch dimension
sum_sq_tensor += (images**2).sum(dim=0) # Sum of squares
num_images += images.size(0) # Counting the number of images

# Calculate the mean and standard deviation
mean_per_pixel = sum_tensor / num_images
    # E[x^2] - E[x]^2 can dip slightly below zero from floating-point
    # error, so clamp before taking the square root.
    std_per_pixel = torch.sqrt(
        torch.clamp((sum_sq_tensor / num_images) - (mean_per_pixel**2), min=0.0)
    )

# Make sure the dimensions are (channel, height, width)
assert len(mean_per_pixel.shape) == 3

return mean_per_pixel.detach().cpu().numpy(), std_per_pixel.detach().cpu().numpy()


def save_input_stats(dataloader: torch.utils.data.DataLoader, save_name: str):
"""
Compute and save the mean and standard deviation of the input data.
"""
mean_per_pixel, std_per_pixel = compute_input_stats(dataloader)
np.save(f"{save_name}_mean.npy", mean_per_pixel)
np.save(f"{save_name}_std.npy", std_per_pixel)


def get_ood_input_processor(source_data: str, target_model: str):
"""
Process OOD input data for the target model.
"""
assert source_data in DATA_LIST
assert target_model in DATA_LIST
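
    # Compute and cache per-pixel train-set statistics for both datasets
    # on first use; later calls reuse the saved .npy files.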
for d in [source_data, target_model]:
if not os.path.exists(f"{d}_train_mean.npy") or not os.path.exists(
f"{d}_train_std.npy"
):
if d == MNIST:
dataloader = get_mnist_dataloader(
batch_size=10000, split="train", shuffle=False
)
elif d == FMNIST:
dataloader = get_fmnist_dataloader(
batch_size=10000, split="train", shuffle=False
)
else:
raise ValueError(f"Unsupported data: {d}")
save_input_stats(dataloader, f"{d}_train")

    # Only the MNIST <-> FMNIST pairing is supported here.
    if {source_data, target_model} != {MNIST, FMNIST}:
        raise ValueError(
            f"Unsupported source/target pair: {source_data} -> {target_model}"
        )

    def load_train_stats(data: str):
        mean = torch.from_numpy(np.load(f"{data}_train_mean.npy"))
        std = torch.from_numpy(np.load(f"{data}_train_std.npy"))
        return mean, std

    source_tr_mean, source_tr_std = load_train_stats(source_data)
    target_tr_mean, target_tr_std = load_train_stats(target_model)

def ood_input_processor(x):
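        # Standardize with the source dataset's train statistics, then
        # rescale to the target dataset's, so OOD inputs land in the input
        # range the target model was trained on.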
device = x.device
x_transform = (x - source_tr_mean.to(device)) / source_tr_std.to(device)
x_transform = x_transform * target_tr_std.to(device) + target_tr_mean.to(device)
return x_transform

return ood_input_processor
