Skip to content

Commit

Permalink
model.py file
Browse files Browse the repository at this point in the history
  • Loading branch information
vadimkantorov committed Apr 14, 2017
1 parent e27aeeb commit 0fb089a
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 50 deletions.
16 changes: 6 additions & 10 deletions cub2011.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@ class Cub2011(ImageFolder, CIFAR10):
tgz_md5 = '97eceeb196236b17998738112f37df78'

train_list = [
['180.Wilson_Warbler/Wilson_Warbler_0002_175571.jpg', '0c763ca2ad60ed3ae43e76a04df63983']
['001.Black_footed_Albatross/Black_Footed_Albatross_0001_796111.jpg', '4c84da568f89519f84640c54b7fba7c2'],
['002.Laysan_Albatross/Laysan_Albatross_0001_545.jpg', 'e7db63424d0e384dba02aacaf298cdc0'],
]
test_list = [
['198.Rock_Wren/Rock_Wren_0001_189289.jpg', '487d082f1fbd58faa7b08aa5ede3cc00'],
['200.Common_Yellowthroat/Common_Yellowthroat_0003_190521.jpg', '96fd60ce4b4805e64368efc32bf5c6fe']
]
test_list = []

def __init__(self, root, transform=None, target_transform=None, download=False, **kwargs):
self.root = root
Expand All @@ -25,11 +29,3 @@ def __init__(self, root, transform=None, target_transform=None, download=False,
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
ImageFolder.__init__(self, os.path.join(root, self.base_folder), transform = transform, target_transform = target_transform, **kwargs)

def recall(self, embeddings, labels, K = 1):
    """Return recall@K of ``embeddings`` with respect to ``labels``.

    Squared Euclidean distances between all rows of ``embeddings`` are
    derived from the Gram matrix. For every row the ``K + 1`` smallest
    distances are looked up and the first one is dropped (it is assumed to
    be the row itself -- with duplicated embeddings that is not guaranteed).
    A row counts as a hit when at least one remaining neighbour carries the
    same label. ``self`` is not used.

    Args:
        embeddings: 2-D tensor with one embedding per row.
        labels: 1-D tensor with one label per row of ``embeddings``.
        K: number of neighbours inspected per row.

    Returns:
        The mean of the per-row hit indicators.
    """
    prod = torch.mm(embeddings, embeddings.t())
    norm = prod.diag().unsqueeze(1).expand_as(prod)
    # ||a||^2 + ||b||^2 - 2 <a, b>
    D = norm + norm.t() - 2 * prod

    knn_inds = D.topk(1 + K, dim = 1, largest = False)[1][:, 1:]
    # float() makes each hit a plain number, so building the result tensor
    # does not depend on whether ``max()`` returns a scalar or a tensor.
    hits = [float(labels[knn_inds[i]].eq(labels[i]).max()) for i in range(len(embeddings))]
    return torch.Tensor(hits).mean()
55 changes: 55 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import random
import torch
import torch.nn as nn
import torch.optim as optim

class EmbedderSimple(nn.Module):
    """Feature extractor followed by a linear layer producing the embedding.

    ``base_model`` must expose an ``output_size`` attribute; it is used as
    the input width of the linear embedding layer.
    """

    def __init__(self, base_model, embedding_size = 128):
        super(EmbedderSimple, self).__init__()
        self.base_model = base_model
        self.embedder = nn.Linear(base_model.output_size, embedding_size)

    def forward(self, input):
        """Flatten the base model output per example and project it."""
        return self.embedder(self.base_model(input).view(input.size(0), -1))

    def sampler(self, batch_size, dataset, train_classes):
        '''Yield endless batches of example indices made of random pairs.

        Lazy sampling, unlike lifted_struct: that implementation adds all
        positive combinations to the pool, computes the average number of
        positive pairs per image, then samples the same number of negative
        pairs for every image.

        ``dataset.imgs`` is read as a sequence of
        ``(image_file_name, class_label_ind)`` pairs; only labels in
        ``range(train_classes)`` are used, and each of those classes needs at
        least one image. Examples are appended two at a time, so an odd
        ``batch_size`` produces ``batch_size + 1`` indices. ``self`` is not
        used.
        '''
        # Group example indices by class in a single pass over the dataset.
        images_by_class = {class_label_ind: [] for class_label_ind in range(train_classes)}
        for example_idx, (image_file_name, class_label_ind) in enumerate(dataset.imgs):
            if class_label_ind in images_by_class:
                images_by_class[class_label_ind].append(example_idx)

        def sample_from_class(class_label_ind):
            candidates = images_by_class[class_label_ind]
            return candidates[random.randrange(len(candidates))]

        while True:
            example_indices = []
            for i in range(0, batch_size, 2):
                # range (not xrange) keeps this runnable on Python 3 as well.
                perm = random.sample(range(train_classes), 2)
                first = sample_from_class(perm[0])
                # The first pair of a batch is always a positive pair, so every
                # batch holds at least one; later pairs are positive about
                # half of the time.
                same_class = i == 0 or random.random() > 0.5
                second = sample_from_class(perm[0 if same_class else 1])
                example_indices += [first, second]
            yield example_indices

    # Optimizer class and its keyword arguments -- presumably picked up by the
    # training script; confirm against the caller.
    optim_algo = optim.SGD
    optim_params = dict(lr = 1e-5, momentum = 0.9, weight_decay = 2e-4, dampening = 0.9)

def pdist_squared(A):
    """Return the matrix of squared Euclidean distances between rows of ``A``.

    Uses ``||a||^2 + ||b||^2 - 2 <a, b>`` on the Gram matrix of ``A`` and
    clamps the result at zero so rounding errors cannot yield negative
    values.
    """
    gram = torch.mm(A, A.t())
    # Squared row norms sit on the diagonal of the Gram matrix.
    sq_norms = gram.diag().unsqueeze(1).expand_as(gram)
    squared = sq_norms + sq_norms.t() - 2 * gram
    return squared.clamp(min = 0)

class LiftedStruct(EmbedderSimple):
    """Embedder with a lifted-structure style loss over in-batch pairs."""

    def criterion(self, input, labels, margin = 1.0, eps = 1e-4):
        """Return the loss for a batch of embeddings and their labels.

        ``eps`` is added both under and after the square root of the squared
        pairwise distances. For every same-label pair in the strict upper
        triangle, the log of the summed ``exp(margin - distance)`` mass over
        the different-label entries of both examples is added to the pair
        distance, clamped at zero and squared. The total is divided by
        ``pos.sum() - len(d)``, i.e. the number of off-diagonal same-label
        entries.
        """
        dist = (pdist_squared(input) + eps).sqrt() + eps
        same = torch.eq(
            labels.unsqueeze(0).expand_as(dist),
            labels.unsqueeze(1).expand_as(dist),
        ).type_as(input)
        # Per-example mass of exp(margin - distance) over different-label entries.
        neg_mass = ((margin - dist).exp() * (1 - same)).sum(1).expand_as(dist)
        # Adding the transpose combines the masses of both pair members.
        per_pair = same.triu(1) * (torch.log(neg_mass + neg_mass.t()) + dist)
        return per_pair.clamp(min = 0).pow(2).sum() / (same.sum() - len(dist))

class Triplet(EmbedderSimple):
    """Embedder with a triplet margin loss over all in-batch triplets."""

    def criterion(self, input, labels, margin = 1.0):
        """Return the mean of ``max(d(a, p) - d(a, n) + margin, 0)``.

        ``d`` is the squared Euclidean distance between embeddings. A triplet
        (anchor ``i``, positive ``j``, negative ``k``) is selected when ``j``
        shares the label of ``i`` and ``k`` does not; note that ``j == i`` is
        selected as well, because an example always matches its own label.
        The result is undefined (division by zero) when no triplet is
        selected, e.g. when the whole batch has one label.

        Args:
            input: 2-D tensor with one embedding per row.
            labels: 1-D tensor with one label per row of ``input``.
            margin: required gap between negative and positive distances.
        """
        d = pdist_squared(input)
        pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(input)
        # T[i][k][j] = d[i][j] (anchor-positive);
        # T.transpose(1, 2)[i][k][j] = d[i][k] (anchor-negative).
        T = d.unsqueeze(1).expand((len(d),) * 3)
        # M[i][k][j] = 1 when j has the label of i and k has a different one.
        M = pos.unsqueeze(1).expand_as(T) * (1 - pos.unsqueeze(2).expand_as(T))
        # Penalise triplets whose negative is not at least ``margin`` farther
        # from the anchor than the positive.
        return (M * torch.clamp(T - T.transpose(1, 2) + margin, min = 0)).sum() / M.sum()

class TripletRatio(EmbedderSimple):
    """Embedder with a distance-ratio loss over all in-batch triplets."""

    def criterion(self, input, labels, margin = 0.1, eps = 1e-4):
        """Return the mean of ``d(a, p) / (d(a, n) + margin)`` over triplets.

        ``d`` is the Euclidean distance with ``eps`` added both under and
        after the square root. A triplet (anchor ``i``, positive ``j``,
        negative ``k``) is selected when ``j`` shares the label of ``i`` and
        ``k`` does not.
        """
        dist = (pdist_squared(input) + eps).sqrt() + eps
        same = torch.eq(
            labels.unsqueeze(0).expand_as(dist),
            labels.unsqueeze(1).expand_as(dist),
        ).type_as(input)
        size = len(dist)
        anchor_pos = dist.unsqueeze(1).expand((size, size, size))  # [i][k][j] = dist[i][j]
        anchor_neg = anchor_pos.transpose(1, 2)  # [i][k][j] = dist[i][k]
        mask = same.unsqueeze(1).expand_as(anchor_pos) * (1 - same.unsqueeze(2).expand_as(anchor_pos))
        ratio = anchor_pos / (anchor_neg + margin)
        return (mask * ratio).sum() / mask.sum()
51 changes: 11 additions & 40 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
import itertools
import hickle
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
from torch.autograd import Variable

import googlenet
import cub2011
import model

assert os.environ.get('CUDA_VISIBLE_DEVICES')

Expand All @@ -26,55 +25,27 @@
BATCH_SIZE = 16
)

def pairwise_euclidean_distance(A, eps = 1e-4):
    """Return the matrix of Euclidean distances between rows of ``A``.

    Squared distances come from the Gram matrix and are clamped at zero;
    ``eps`` is added once under the square root and once to its result.
    """
    gram = torch.mm(A, A.t())
    # Squared row norms sit on the diagonal of the Gram matrix.
    sq_norms = gram.diag().unsqueeze(1).expand_as(gram)
    squared = (sq_norms + sq_norms.t() - 2 * gram).clamp(min = 0)
    return torch.sqrt(squared + eps) + eps

class LiftedStruct(nn.Module):
    """Feature extractor plus linear embedding with a lifted-structure loss.

    ``base_model`` must expose an ``output_size`` attribute; it is used as
    the input width of the linear embedding layer.
    """

    def __init__(self, base_model, embedding_size = 128):
        super(LiftedStruct, self).__init__()
        self.base_model = base_model
        self.embedder = nn.Linear(base_model.output_size, embedding_size)

    def forward(self, input):
        """Flatten the base model output per example and project it."""
        return self.embedder(self.base_model(input).view(input.size(0), -1))

    def criterion(self, input, labels, margin = 1.0, eps = 1e-4):
        """Return the loss for a batch of embeddings and their labels.

        For every same-label pair in the strict upper triangle, the log of
        the summed ``exp(margin - distance)`` mass over the different-label
        entries of both examples is added to the pair distance, clamped at
        zero and squared. The total is divided by ``pos.sum() - len(d)``,
        i.e. the number of off-diagonal same-label entries.
        """
        d = pairwise_euclidean_distance(input, eps = eps)
        pos = torch.eq(*[labels.unsqueeze(dim).expand_as(d) for dim in [0, 1]]).type_as(input)
        m_d = margin - d
        # The global maximum is subtracted before exp() and added back after
        # log(). view(1, 1) works whether max() yields a 0-dim or a
        # one-element result.
        max_elem = m_d.max().view(1, 1).expand_as(m_d)
        neg_i = torch.mul((m_d - max_elem).exp(), 1 - pos).sum(1).expand_as(d)
        return torch.sum(torch.mul(pos.triu(1), torch.log(neg_i + neg_i.t()) + d + max_elem).clamp(min = 0).pow(2)) / (pos.sum() - len(d))

    def sampler(self, batch_size, dataset, train_classes):
        '''Yield endless batches of example indices made of random pairs.

        Lazy sampling, unlike lifted_struct: that implementation adds all
        positive combinations to the pool, computes the average number of
        positive pairs per image, then samples the same number of negative
        pairs for every image.

        ``dataset.imgs`` is read as a sequence of
        ``(image_file_name, class_label_ind)`` pairs; only labels in
        ``range(train_classes)`` are used, and each of those classes needs at
        least one image. Examples are appended two at a time, so an odd
        ``batch_size`` produces ``batch_size + 1`` indices. ``self`` is not
        used.
        '''
        # Group example indices by class in a single pass over the dataset.
        images_by_class = {class_label_ind: [] for class_label_ind in range(train_classes)}
        for example_idx, (image_file_name, class_label_ind) in enumerate(dataset.imgs):
            if class_label_ind in images_by_class:
                images_by_class[class_label_ind].append(example_idx)

        def sample_from_class(class_label_ind):
            candidates = images_by_class[class_label_ind]
            return candidates[random.randrange(len(candidates))]

        while True:
            example_indices = []
            for i in range(0, batch_size, 2):
                # range (not xrange) keeps this runnable on Python 3 as well.
                perm = random.sample(range(train_classes), 2)
                first = sample_from_class(perm[0])
                # The first pair of a batch is always a positive pair; later
                # pairs are positive about half of the time.
                same_class = i == 0 or random.random() > 0.5
                second = sample_from_class(perm[0 if same_class else 1])
                example_indices += [first, second]
            yield example_indices

    # Optimizer class and its keyword arguments -- presumably picked up by the
    # training loop; confirm against the caller.
    optim_algo = optim.SGD
    optim_params = dict(lr = 1e-5, momentum = 0.9, weight_decay = 2e-4, dampening = 0.9)

def adapt_sampler(batch_size, dataset, sampler, **kwargs):
    """Wrap a batch generator into an object with ``__len__`` and ``__iter__``.

    The returned object reports the length of ``dataset`` and, every time it
    is iterated, calls ``sampler(batch_size, dataset, **kwargs)`` and chains
    the batches it produces into one flat stream of items.
    """
    def flat_indices(_):
        batches = sampler(batch_size, dataset, **kwargs)
        return itertools.chain.from_iterable(batches)

    # Anonymous one-off class; __len__ is the dataset's own bound method.
    namespace = {'__len__': dataset.__len__, '__iter__': flat_indices}
    adapter_cls = type('', (), namespace)
    return adapter_cls()

def recall(embeddings, labels, K = 1):
    """Return recall@K of ``embeddings`` with respect to ``labels``.

    This is a module-level function, so it takes no ``self``; it is called
    as ``recall(embeddings, labels)``.

    Squared Euclidean distances between all rows of ``embeddings`` are
    derived from the Gram matrix. For every row the ``K + 1`` smallest
    distances are looked up and the first one is dropped (it is assumed to
    be the row itself -- with duplicated embeddings that is not guaranteed).
    A row counts as a hit when at least one remaining neighbour carries the
    same label.

    Args:
        embeddings: 2-D tensor with one embedding per row.
        labels: 1-D tensor with one label per row of ``embeddings``.
        K: number of neighbours inspected per row.

    Returns:
        The mean of the per-row hit indicators.
    """
    prod = torch.mm(embeddings, embeddings.t())
    norm = prod.diag().unsqueeze(1).expand_as(prod)
    # ||a||^2 + ||b||^2 - 2 <a, b>
    D = norm + norm.t() - 2 * prod

    knn_inds = D.topk(1 + K, dim = 1, largest = False)[1][:, 1:]
    # float() makes each hit a plain number, so building the result tensor
    # does not depend on whether ``max()`` returns a scalar or a tensor.
    hits = [float(labels[knn_inds[i]].eq(labels[i]).max()) for i in range(len(embeddings))]
    return torch.Tensor(hits).mean()

for set_random_seed in [random.seed, torch.manual_seed, torch.cuda.manual_seed_all]:
set_random_seed(opts['SEED'])

base_model = googlenet.GoogLeNet()
base_model_weights = hickle.load(opts['BASE_MODEL_WEIGHTS'])
base_model.load_state_dict({k : torch.from_numpy(v) for k, v in base_model_weights.items()})
model = LiftedStruct(base_model)
model = model.LiftedStruct(base_model)

normalize = transforms.Compose([
transforms.ToTensor(),
Expand Down Expand Up @@ -123,4 +94,4 @@ def adapt_sampler(batch_size, dataset, sampler, **kwargs):
embeddings_all.append(output.data.cpu())
labels_all.append(labels.data.cpu())
print('eval {:>3}.{:05}'.format(epoch, batch_idx))
log.write('recall@1 epoch {}: {}\n'.format(epoch, dataset_eval.recall(torch.cat(embeddings_all, 0), torch.cat(labels_all, 0))))
log.write('recall@1 epoch {}: {}\n'.format(epoch, recall(torch.cat(embeddings_all, 0), torch.cat(labels_all, 0))))

0 comments on commit 0fb089a

Please sign in to comment.