Merge pull request #27 from sangkeun00/mnist_ood_test
OOD example with MNIST and FMNIST
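At a high level, the example flags out-of-distribution (OOD) inputs via their self-influence scores: after logging per-sample gradients and the damped Hessian on the in-distribution (ID) training set, each query example z is scored by how strongly its own gradient aligns with itself under the inverse damped Hessian. In the standard influence-function form (AnaLog's exact estimator may differ in its Hessian approximation), with H the Hessian of the training loss and \lambda the --damping argument:

I_{\text{self}}(z) = \nabla_\theta \ell(z,\theta)^\top \left(H + \lambda I\right)^{-1} \nabla_\theta \ell(z,\theta)

Atypical (OOD) examples tend to receive larger self-influence scores than ID ones, so comparing the two score distributions gives a simple OOD-detection signal.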
Showing 2 changed files with 250 additions and 0 deletions.
@@ -0,0 +1,137 @@
import argparse

import tqdm
import numpy as np
import torch

from analog import AnaLog
from analog.utils import DataIDGenerator
from analog.analysis import InfluenceFunction
from examples.mnist_influence.utils import (
    get_mnist_dataloader,
    get_fmnist_dataloader,
    construct_mlp,
)
from examples.mnist_uncertainty.ood_utils import (
    get_ood_input_processor,
)

parser = argparse.ArgumentParser("OOD Self-influence Score Analysis")
parser.add_argument(
    "--id-data",
    type=str,
    default="mnist",
    help="mnist or fmnist; OOD is set to the other one",
)
parser.add_argument("--damping", type=float, default=1e-5)
args = parser.parse_args()

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ood_data = "fmnist" if args.id_data == "mnist" else "mnist"
model = construct_mlp().to(DEVICE)

# Get a single checkpoint (first model_id and last epoch).
model.load_state_dict(
    torch.load(f"checkpoints/{args.id_data}_0_epoch_9.pt", map_location="cpu")
)
model.eval()

id_dataloader_fn = (
    get_mnist_dataloader if args.id_data == "mnist" else get_fmnist_dataloader
)
ood_dataloader_fn = (
    get_mnist_dataloader if ood_data == "mnist" else get_fmnist_dataloader
)
id_train_loader = id_dataloader_fn(
    batch_size=512, split="train", shuffle=False, subsample=True
)
id_query_loader = id_dataloader_fn(
    batch_size=512,
    split="valid",
    shuffle=False,
    subsample=True,
)
ood_query_loader = ood_dataloader_fn(
    batch_size=512,
    split="valid",
    shuffle=False,
    subsample=True,
)
ood_input_processor = get_ood_input_processor(
    source_data=ood_data, target_model=args.id_data
)

# Set-up
analog = AnaLog(project="test")

# Gradient & Hessian logging over the ID training set
analog.watch(model, lora=False)
id_gen = DataIDGenerator()
for inputs, targets in id_train_loader:
    data_id = id_gen(inputs)
    with analog(data_id=data_id, log=["grad"], hessian=True, save=True):
        inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)
        model.zero_grad()
        outs = model(inputs)
        loss = torch.nn.functional.cross_entropy(outs, targets, reduction="sum")
        loss.backward()
analog.finalize()

# Influence Analysis
analog.add_analysis({"influence": InfluenceFunction})

print("Computing OOD self-influence scores...")
ood_self_influence_scores = []
for ood_test_input, ood_test_target in tqdm.tqdm(ood_query_loader):
    with analog(log=["grad"], test=True) as al:
        # Rescale OOD inputs to match the ID training statistics.
        ood_test_input = ood_input_processor(ood_test_input)
        ood_test_input, ood_test_target = (
            ood_test_input.to(DEVICE),
            ood_test_target.to(DEVICE),
        )
        model.zero_grad()
        ood_test_out = model(ood_test_input)
        ood_test_loss = torch.nn.functional.cross_entropy(
            ood_test_out, ood_test_target, reduction="sum"
        )
        ood_test_loss.backward()
        ood_test_log = al.get_log()
    if_scores = analog.influence.compute_self_influence(
        ood_test_log, damping=args.damping
    )
    ood_self_influence_scores.append(if_scores.cpu().numpy().flatten())

print("Computing ID self-influence scores...")
id_self_influence_scores = []
for id_test_input, id_test_target in tqdm.tqdm(id_query_loader):
    with analog(log=["grad"], test=True) as al:
        id_test_input, id_test_target = id_test_input.to(DEVICE), id_test_target.to(
            DEVICE
        )
        model.zero_grad()
        id_test_out = model(id_test_input)
        id_test_loss = torch.nn.functional.cross_entropy(
            id_test_out, id_test_target, reduction="sum"
        )
        id_test_loss.backward()
        id_test_log = al.get_log()
    if_scores = analog.influence.compute_self_influence(
        id_test_log, damping=args.damping
    )
    id_self_influence_scores.append(if_scores.cpu().numpy().flatten())

# Save
ood_sif_scores = np.concatenate(ood_self_influence_scores)
id_sif_scores = np.concatenate(id_self_influence_scores)
torch.save(ood_sif_scores, "ood_sif_scores.pt")
torch.save(id_sif_scores, "id_sif_scores.pt")
print(
    f"OOD self-influence scores: mean={ood_sif_scores.mean():.2f}, std={ood_sif_scores.std():.2f}"
)
print(
    f"ID self-influence scores: mean={id_sif_scores.mean():.2f}, std={id_sif_scores.std():.2f}"
)
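To quantify the separation beyond the mean/std printout, the saved score files can be turned into an OOD-detection AUROC. A minimal sketch, assuming scikit-learn is installed; the file names match the torch.save calls above, and since the files hold NumPy arrays rather than model weights, recent PyTorch versions may require weights_only=False:

import numpy as np
import torch
from sklearn.metrics import roc_auc_score

# Files hold NumPy arrays saved with torch.save; pass weights_only=False
# on PyTorch versions where torch.load defaults to weights_only=True.
ood_scores = torch.load("ood_sif_scores.pt")
id_scores = torch.load("id_sif_scores.pt")

# Label OOD queries 1 and ID queries 0; higher self-influence should rank
# OOD examples above ID ones, so an AUROC near 1.0 means good separation.
labels = np.concatenate([np.ones_like(ood_scores), np.zeros_like(id_scores)])
scores = np.concatenate([ood_scores, id_scores])
print(f"OOD-detection AUROC: {roc_auc_score(labels, scores):.4f}")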
@@ -0,0 +1,113 @@
import os

import numpy as np
import torch

from examples.mnist_influence.utils import (
    get_mnist_dataloader,
    get_fmnist_dataloader,
)

MNIST = "mnist"
FMNIST = "fmnist"
CIFAR10 = "cifar10"
DATA_LIST = [MNIST, FMNIST, CIFAR10]


def compute_input_stats(dataloader: torch.utils.data.DataLoader):
    """
    Compute the per-pixel mean and standard deviation of the input data.
    """
    # Initialize sum and sum-of-squares accumulators shaped like one image
    sum_tensor = torch.zeros_like(next(iter(dataloader))[0][0])
    sum_sq_tensor = torch.zeros_like(sum_tensor)

    # Make sure the dimensions are (channel, height, width)
    assert len(sum_tensor.shape) == 3

    num_images = 0
    # Accumulate first and second moments over the DataLoader
    for images, _ in dataloader:
        sum_tensor += images.sum(dim=0)  # Summing over the batch dimension
        sum_sq_tensor += (images**2).sum(dim=0)  # Sum of squares
        num_images += images.size(0)  # Counting the number of images

    # Mean and standard deviation via Var[x] = E[x^2] - E[x]^2
    mean_per_pixel = sum_tensor / num_images
    std_per_pixel = torch.sqrt((sum_sq_tensor / num_images) - (mean_per_pixel**2))

    # Make sure the dimensions are (channel, height, width)
    assert len(mean_per_pixel.shape) == 3

    return mean_per_pixel.detach().cpu().numpy(), std_per_pixel.detach().cpu().numpy()


def save_input_stats(dataloader: torch.utils.data.DataLoader, save_name: str):
    """
    Compute and save the mean and standard deviation of the input data.
    """
    mean_per_pixel, std_per_pixel = compute_input_stats(dataloader)
    np.save(f"{save_name}_mean.npy", mean_per_pixel)
    np.save(f"{save_name}_std.npy", std_per_pixel)


def get_ood_input_processor(source_data: str, target_model: str):
    """
    Return a function that rescales inputs from the source dataset so their
    per-pixel statistics match the target model's training data.
    """
    assert source_data in DATA_LIST
    assert target_model in DATA_LIST
    # Compute and cache per-pixel training statistics if they are missing.
    for d in [source_data, target_model]:
        if not os.path.exists(f"{d}_train_mean.npy") or not os.path.exists(
            f"{d}_train_std.npy"
        ):
            if d == MNIST:
                dataloader = get_mnist_dataloader(
                    batch_size=10000, split="train", shuffle=False
                )
            elif d == FMNIST:
                dataloader = get_fmnist_dataloader(
                    batch_size=10000, split="train", shuffle=False
                )
            else:
                raise ValueError(f"Unsupported data: {d}")
            save_input_stats(dataloader, f"{d}_train")

    if source_data == MNIST:
        source_tr_mean = torch.from_numpy(np.load("mnist_train_mean.npy"))
        source_tr_std = torch.from_numpy(np.load("mnist_train_std.npy"))
        if target_model == FMNIST:
            target_tr_mean = torch.from_numpy(np.load("fmnist_train_mean.npy"))
            target_tr_std = torch.from_numpy(np.load("fmnist_train_std.npy"))
        else:
            raise ValueError(f"Unsupported target model: {target_model}")
    elif source_data == FMNIST:
        source_tr_mean = torch.from_numpy(np.load("fmnist_train_mean.npy"))
        source_tr_std = torch.from_numpy(np.load("fmnist_train_std.npy"))
        if target_model == MNIST:
            target_tr_mean = torch.from_numpy(np.load("mnist_train_mean.npy"))
            target_tr_std = torch.from_numpy(np.load("mnist_train_std.npy"))
        else:
            raise ValueError(f"Unsupported target model: {target_model}")
    else:
        raise ValueError(f"Unsupported source data: {source_data}")

    def ood_input_processor(x):
        device = x.device
        # Standardize with the source statistics (the clamp avoids division
        # by zero on constant pixels, e.g. empty image borders), then rescale
        # to the target statistics.
        x_transform = (x - source_tr_mean.to(device)) / source_tr_std.to(
            device
        ).clamp_min(1e-8)
        x_transform = x_transform * target_tr_std.to(device) + target_tr_mean.to(
            device
        )
        return x_transform

    return ood_input_processor
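The returned closure standardizes each pixel using the source dataset's per-pixel statistics and then rescales to the target dataset's, so the target model sees inputs whose per-pixel mean and standard deviation match its own training data; any remaining gap in self-influence scores then reflects image content rather than trivial intensity statistics. A hypothetical usage sketch, mirroring the main script (module paths assumed from the imports above):

from examples.mnist_influence.utils import get_fmnist_dataloader
from examples.mnist_uncertainty.ood_utils import get_ood_input_processor

# Map FMNIST images into MNIST per-pixel statistics before feeding them
# to an MNIST-trained model.
processor = get_ood_input_processor(source_data="fmnist", target_model="mnist")
loader = get_fmnist_dataloader(batch_size=8, split="valid", shuffle=False)
images, _ = next(iter(loader))
matched = processor(images)  # same shape; per-pixel moments now match MNIST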