diff --git a/cub2011.py b/cub2011.py
index d3a38c3..0f9ace2 100644
--- a/cub2011.py
+++ b/cub2011.py
@@ -12,9 +12,13 @@ class Cub2011(ImageFolder, CIFAR10):
     tgz_md5 = '97eceeb196236b17998738112f37df78'
 
     train_list = [
-        ['180.Wilson_Warbler/Wilson_Warbler_0002_175571.jpg', '0c763ca2ad60ed3ae43e76a04df63983']
+        ['001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg', '4c84da568f89519f84640c54b7fba7c2'],
+        ['002.Laysan_Albatross/Laysan_Albatross_0001_545.jpg', 'e7db63424d0e384dba02aacaf298cdc0'],
+    ]
+    test_list = [
+        ['198.Rock_Wren/Rock_Wren_0001_189289.jpg', '487d082f1fbd58faa7b08aa5ede3cc00'],
+        ['200.Common_Yellowthroat/Common_Yellowthroat_0003_190521.jpg', '96fd60ce4b4805e64368efc32bf5c6fe']
     ]
-    test_list = []
 
     def __init__(self, root, transform=None, target_transform=None, download=False, **kwargs):
         self.root = root
@@ -25,11 +29,3 @@ def __init__(self, root, transform=None, target_transform=None, download=False,
             raise RuntimeError('Dataset not found or corrupted.' + ' You can use download=True to download it')
         ImageFolder.__init__(self, os.path.join(root, self.base_folder), transform = transform, target_transform = target_transform, **kwargs)
 
-
-    def recall(self, embeddings, labels, K = 1):
-        prod = torch.mm(embeddings, embeddings.t())
-        norm = prod.diag().unsqueeze(1).expand_as(prod)
-        D = norm + norm.t() - 2 * prod
-
-        knn_inds = D.topk(1 + K, dim = 1, largest = False)[1][:, 1:]
-        return torch.Tensor([labels[knn_inds[i]].eq(labels[i]).max() for i in range(len(embeddings))]).mean()
diff --git a/model.py b/model.py
new file mode 100644
index 0000000..ab125df
--- /dev/null
+++ b/model.py
@@ -0,0 +1,55 @@
+import random
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+class EmbedderSimple(nn.Module):
+    def __init__(self, base_model, embedding_size = 128):
+        super(EmbedderSimple, self).__init__()
+        self.base_model = base_model
+        self.embedder = nn.Linear(base_model.output_size, embedding_size)
+
+    def forward(self, input):
+        return self.embedder(self.base_model(input).view(input.size(0), -1))
+
+    def sampler(self, batch_size, dataset, train_classes):
+        '''lazy sampling, unlike in lifted_struct: they add all positive combinations to the pool, then compute the average number of positive pairs per image, and then sample the same number of negative pairs for every image'''
+        images_by_class = {class_label_ind_train : [example_idx for example_idx, (image_file_name, class_label_ind) in enumerate(dataset.imgs) if class_label_ind == class_label_ind_train] for class_label_ind_train in range(train_classes)}
+        sample_from_class = lambda class_label_ind: images_by_class[class_label_ind][random.randrange(len(images_by_class[class_label_ind]))]
+        while True:
+            example_indices = []
+            for i in range(0, batch_size, 2):
+                perm = random.sample(xrange(train_classes), 2)
+                example_indices += [sample_from_class(perm[0]), sample_from_class(perm[0 if i == 0 or random.random() > 0.5 else 1])]
+            yield example_indices
+
+    optim_algo = optim.SGD
+    optim_params = dict(lr = 1e-5, momentum = 0.9, weight_decay = 2e-4, dampening = 0.9)
+
+def pdist_squared(A):
+    prod = torch.mm(A, A.t())
+    norm = prod.diag().unsqueeze(1).expand_as(prod)
+    return (norm + norm.t() - 2 * prod).clamp(min = 0)
+
+class LiftedStruct(EmbedderSimple):
+    def criterion(self, input, labels, margin = 1.0, eps = 1e-4):
+        d = (pdist_squared(input) + eps).sqrt() + eps
+        pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(input)
+        neg_i = torch.mul((margin - d).exp(), 1 - pos).sum(1).expand_as(d)
+        return torch.sum(torch.mul(pos.triu(1), torch.log(neg_i + neg_i.t()) + d).clamp(min = 0).pow(2)) / (pos.sum() - len(d))
+
+class Triplet(EmbedderSimple):
+    def criterion(self, input, labels, margin = 1.0):
+        d = pdist_squared(input)
+        pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(input)
+        T = d.unsqueeze(1).expand((len(d),) * 3) # T[i][k][j] = d[i][j]
+        M = pos.unsqueeze(1).expand_as(T) * (1 - pos.unsqueeze(2).expand_as(T)) # M[i][k][j] = pos[i][j] * (1 - pos[i][k]): j positive, k negative for anchor i
+        return (M * torch.clamp(T + T.transpose(1, 2) - margin, min = 0)).sum() / M.sum() # summand[i][k][j] = max(0, d[i][j] + d[i][k] - margin)
+
+class TripletRatio(EmbedderSimple):
+    def criterion(self, input, labels, margin = 0.1, eps = 1e-4):
+        d = (pdist_squared(input) + eps).sqrt() + eps
+        pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(input)
+        T = d.unsqueeze(1).expand((len(d),) * 3) # T[i][k][j] = d[i][j]
+        M = pos.unsqueeze(1).expand_as(T) * (1 - pos.unsqueeze(2).expand_as(T)) # M[i][k][j] = pos[i][j] * (1 - pos[i][k]): j positive, k negative for anchor i
+        return (M * T.div(T.transpose(1, 2) + margin)).sum() / M.sum() # summand[i][k][j] = d[i][j] / (d[i][k] + margin)
diff --git a/train.py b/train.py
index 9274d82..086919d 100644
--- a/train.py
+++ b/train.py
@@ -4,14 +4,13 @@
 import itertools
 import hickle
 
 import torch
-import torch.nn as nn
-import torch.optim as optim
 import torch.utils.data
 import torchvision.transforms as transforms
 from torch.autograd import Variable
 
 import googlenet
 import cub2011
+import model
 
 assert os.environ.get('CUDA_VISIBLE_DEVICES')
@@ -26,55 +25,27 @@ BATCH_SIZE = 16
 )
 
 
-def pairwise_euclidean_distance(A, eps = 1e-4):
-    prod = torch.mm(A, A.t())
-    norm = prod.diag().unsqueeze(1).expand_as(prod)
-    return torch.sqrt((norm + norm.t() - 2 * prod).clamp(min = 0) + eps) + eps
-
-class LiftedStruct(nn.Module):
-    def __init__(self, base_model, embedding_size = 128):
-        super(LiftedStruct, self).__init__()
-        self.base_model = base_model
-        self.embedder = nn.Linear(base_model.output_size, embedding_size)
-
-    def forward(self, input):
-        return self.embedder(self.base_model(input).view(input.size(0), -1))
-
-    def criterion(self, input, labels, margin = 1.0, eps = 1e-4):
-        d = pairwise_euclidean_distance(input, eps = eps)
-        pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(input)
-        m_d = margin - d
-        max_elem = m_d.max().unsqueeze(1).expand_as(m_d)
-        neg_i = torch.mul((m_d - max_elem).exp(), 1 - pos).sum(1).expand_as(d)
-        return torch.sum(torch.mul(pos.triu(1), torch.log(neg_i + neg_i.t()) + d + max_elem).clamp(min = 0).pow(2)) / (pos.sum() - len(d))
-
-    def sampler(self, batch_size, dataset, train_classes):
-        '''lazy sampling, not like in lifted_struct. they add to the pool all postiive combinations, then compute the average number of positive pairs per image, then sample for every image the same number of negative pairs'''
-        images_by_class = {class_label_ind_train : [example_idx for example_idx, (image_file_name, class_label_ind) in enumerate(dataset.imgs) if class_label_ind == class_label_ind_train] for class_label_ind_train in range(train_classes)}
-        sample_from_class = lambda class_label_ind: images_by_class[class_label_ind][random.randrange(len(images_by_class[class_label_ind]))]
-        while True:
-            example_indices = []
-            for i in range(0, batch_size, 2):
-                perm = random.sample(xrange(train_classes), 2)
-                example_indices += [sample_from_class(perm[0]), sample_from_class(perm[0 if i == 0 or random.random() > 0.5 else 1])]
-            yield example_indices
-
-    optim_algo = optim.SGD
-    optim_params = dict(lr = 1e-5, momentum = 0.9, weight_decay = 2e-4, dampening = 0.9)
-
 def adapt_sampler(batch_size, dataset, sampler, **kwargs):
     return type('', (), dict(
         __len__ = dataset.__len__,
         __iter__ = lambda _: itertools.chain.from_iterable(sampler(batch_size, dataset, **kwargs))
     ))()
 
+def recall(embeddings, labels, K = 1):
+    prod = torch.mm(embeddings, embeddings.t())
+    norm = prod.diag().unsqueeze(1).expand_as(prod)
+    D = norm + norm.t() - 2 * prod
+
+    knn_inds = D.topk(1 + K, dim = 1, largest = False)[1][:, 1:]
+    return torch.Tensor([labels[knn_inds[i]].eq(labels[i]).max() for i in range(len(embeddings))]).mean()
+
 
 for set_random_seed in [random.seed, torch.manual_seed, torch.cuda.manual_seed_all]:
     set_random_seed(opts['SEED'])
 
 base_model = googlenet.GoogLeNet()
 base_model_weights = hickle.load(opts['BASE_MODEL_WEIGHTS'])
 base_model.load_state_dict({k : torch.from_numpy(v) for k, v in base_model_weights.items()})
-model = LiftedStruct(base_model)
+model = model.LiftedStruct(base_model)
 
 normalize = transforms.Compose([
     transforms.ToTensor(),
@@ -123,4 +94,4 @@ def adapt_sampler(batch_size, dataset, sampler, **kwargs):
         embeddings_all.append(output.data.cpu())
         labels_all.append(labels.data.cpu())
         print('eval {:>3}.{:05}'.format(epoch, batch_idx))
-    log.write('recall@1 epoch {}: {}\n'.format(epoch, dataset_eval.recall(torch.cat(embeddings_all, 0), torch.cat(labels_all, 0))))
+    log.write('recall@1 epoch {}: {}\n'.format(epoch, recall(torch.cat(embeddings_all, 0), torch.cat(labels_all, 0))))
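
Reviewer note: the snippet below is a quick standalone sanity check, not part of the diff. It verifies the vectorized pdist_squared from the new model.py against a brute-force reference, and reproduces the recall@1 computation that train.py's recall() performs (nearest neighbour with self excluded, hit if it shares the query's label). It is a minimal sketch under assumptions: model.py is importable, and the environment is Python 3 with a recent PyTorch (model.py itself still uses the Python 2 xrange in its sampler); the names emb, ref, D and nn_idx are local to the example.

    import torch
    from model import pdist_squared  # model.py from this diff

    torch.manual_seed(0)
    emb = torch.randn(8, 4)  # 8 embeddings of dimension 4
    labels = torch.tensor([0, 0, 1, 1, 2, 2, 3, 3])

    # brute-force reference: ref[i][j] = ||emb[i] - emb[j]||^2
    ref = torch.stack([((emb - e) ** 2).sum(1) for e in emb])
    assert (pdist_squared(emb) - ref).abs().max().item() < 1e-5

    # recall@1: mask self-distances, take the nearest neighbour, count label hits
    D = ref + torch.eye(len(emb)) * 1e9
    nn_idx = D.argmin(dim = 1)
    print('recall@1:', (labels[nn_idx] == labels).float().mean().item())

With K = 1, masking the zero self-distance and taking argmin is equivalent to train.py's topk(1 + K, largest = False)[1][:, 1:] indexing.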