diff --git a/configs/env.yml b/configs/env.yml index 8e1d688e..b025c3a9 100644 --- a/configs/env.yml +++ b/configs/env.yml @@ -1 +1 @@ -root_dir: /path/where/to/store/results/ +root_dir: /content/data/ diff --git a/configs/pretext/simclr_birds.yml b/configs/pretext/simclr_birds.yml new file mode 100644 index 00000000..f1a34bdc --- /dev/null +++ b/configs/pretext/simclr_birds.yml @@ -0,0 +1,57 @@ +# Setup +setup: simclr + +# Model +backbone: resnet50 +model_kwargs: + head: mlp + features_dim: 2048 + +# Dataset +train_db_name: birds +val_db_name: birds +num_classes: 200 + +# Loss +criterion: simclr +criterion_kwargs: + temperature: 0.07 + +# Hyperparameters +epochs: 100 +optimizer: sgd +optimizer_kwargs: + nesterov: False + weight_decay: 0.0001 + momentum: 0.9 + lr: 0.005625 +scheduler: cosine +scheduler_kwargs: + lr_decay_rate: 0.1 +batch_size: 48 +num_workers: 8 + +# Transformations +augmentation_strategy: simclr +augmentation_kwargs: + random_resized_crop: + size: 224 + scale: [0.2, 1.0] + color_jitter_random_apply: + p: 0.8 + color_jitter: + brightness: 0.4 + contrast: 0.4 + saturation: 0.4 + hue: 0.1 + random_grayscale: + p: 0.2 + normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +transformation_kwargs: + crop_size: 224 + normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] diff --git a/configs/pretext/simclr_mnist.yml b/configs/pretext/simclr_mnist.yml new file mode 100644 index 00000000..c33588eb --- /dev/null +++ b/configs/pretext/simclr_mnist.yml @@ -0,0 +1,57 @@ +# Setup +setup: simclr + +# Model +backbone: resnet18 +model_kwargs: + head: mlp + features_dim: 128 + +# Dataset +train_db_name: mnist +val_db_name: mnist +num_classes: 10 + +# Loss +criterion: simclr +criterion_kwargs: + temperature: 0.1 + +# Hyperparameters +epochs: 5 +optimizer: sgd +optimizer_kwargs: + nesterov: False + weight_decay: 0.0001 + momentum: 0.9 + lr: 0.4 +scheduler: cosine +scheduler_kwargs: + lr_decay_rate: 0.1 +batch_size: 48 +num_workers: 8 
+ +# Transformations +augmentation_strategy: simclr +augmentation_kwargs: + random_resized_crop: + size: 32 + scale: [0.2, 1.0] + color_jitter_random_apply: + p: 0.8 + color_jitter: + brightness: 0.4 + contrast: 0.4 + saturation: 0.4 + hue: 0.1 + random_grayscale: + p: 0.2 + normalize: + mean: [0.5] + std: [0.5] + +transformation_kwargs: + crop_size: 32 + normalize: + mean: [0.5] + std: [0.5] diff --git a/configs/scan/scan_birds.yml b/configs/scan/scan_birds.yml new file mode 100644 index 00000000..030e793d --- /dev/null +++ b/configs/scan/scan_birds.yml @@ -0,0 +1,59 @@ +# setup +setup: scan + +# Loss +criterion: scan +criterion_kwargs: + entropy_weight: 5.0 + +# Model +backbone: resnet50 + +# Weight update +update_cluster_head_only: True # Train only linear layer during SCAN +num_heads: 10 # Use multiple heads + +# Dataset +train_db_name: birds +val_db_name: birds +num_classes: 200 +num_neighbors: 50 + +# Transformations +augmentation_strategy: simclr +augmentation_kwargs: + random_resized_crop: + size: 224 + scale: [0.2, 1.0] + color_jitter_random_apply: + p: 0.8 + color_jitter: + brightness: 0.4 + contrast: 0.4 + saturation: 0.4 + hue: 0.1 + random_grayscale: + p: 0.2 + normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +transformation_kwargs: + crop_size: 224 + normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +# Hyperparameters +optimizer: sgd +optimizer_kwargs: + lr: 5.0 + weight_decay: 0.0000 + nesterov: False + momentum: 0.9 +epochs: 10 +batch_size: 32 +num_workers: 8 + +# Scheduler +scheduler: constant diff --git a/configs/scan/scan_mnist.yml b/configs/scan/scan_mnist.yml new file mode 100644 index 00000000..78bde7fc --- /dev/null +++ b/configs/scan/scan_mnist.yml @@ -0,0 +1,51 @@ +# setup +setup: scan + +# Loss +criterion: scan +criterion_kwargs: + entropy_weight: 5.0 + +# Weight update +update_cluster_head_only: False # Update full network in SCAN +num_heads: 1 # Only use one head + +# Model +backbone: resnet18 
+ +# Dataset +train_db_name: mnist +val_db_name: mnist +num_classes: 10 +num_neighbors: 20 + +# Transformations +augmentation_strategy: ours +augmentation_kwargs: + crop_size: 32 + normalize: + mean: [0.5] + std: [0.5] + num_strong_augs: 4 + cutout_kwargs: + n_holes: 1 + length: 16 + random: True + +transformation_kwargs: + crop_size: 32 + normalize: + mean: [0.5] + std: [0.5] + +# Hyperparameters +optimizer: adam +optimizer_kwargs: + lr: 0.0001 + weight_decay: 0.0001 +epochs: 5 +batch_size: 48 +num_workers: 8 + +# Scheduler +scheduler: constant diff --git a/configs/selflabel/selflabel_birds.yml b/configs/selflabel/selflabel_birds.yml new file mode 100644 index 00000000..54a6caa2 --- /dev/null +++ b/configs/selflabel/selflabel_birds.yml @@ -0,0 +1,56 @@ +# setup +setup: selflabel + +# Threshold +confidence_threshold: 0.99 + +# EMA +use_ema: True +ema_alpha: 0.999 + +# Loss +criterion: confidence-cross-entropy +criterion_kwargs: + apply_class_balancing: False + +# Model +backbone: resnet50 +num_heads: 1 + +# Dataset +train_db_name: birds +val_db_name: birds +num_classes: 200 + +# Transformations +augmentation_strategy: ours +augmentation_kwargs: + crop_size: 224 + normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + num_strong_augs: 4 + cutout_kwargs: + n_holes: 1 + length: 75 + random: True + +transformation_kwargs: + crop_size: 224 + normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +# Hyperparameters +optimizer: sgd +optimizer_kwargs: + lr: 0.03 + weight_decay: 0.0 + nesterov: False + momentum: 0.9 +epochs: 10 +batch_size: 32 +num_workers: 8 + +# Scheduler +scheduler: constant diff --git a/configs/selflabel/selflabel_mnist.yml b/configs/selflabel/selflabel_mnist.yml new file mode 100644 index 00000000..7b3c7967 --- /dev/null +++ b/configs/selflabel/selflabel_mnist.yml @@ -0,0 +1,53 @@ +# setup +setup: selflabel + +# ema +use_ema: False + +# Threshold +confidence_threshold: 0.99 + +# Criterion +criterion: 
confidence-cross-entropy +criterion_kwargs: + apply_class_balancing: True + +# Model +backbone: resnet18 +num_heads: 1 + +# Dataset +train_db_name: mnist +val_db_name: mnist +num_classes: 10 + +# Transformations +augmentation_strategy: ours +augmentation_kwargs: + crop_size: 32 + normalize: + mean: [0.5] + std: [0.5] + num_strong_augs: 4 + cutout_kwargs: + n_holes: 1 + length: 16 + random: True + +transformation_kwargs: + crop_size: 32 + normalize: + mean: [0.5] + std: [0.5] + +# Hyperparameters +epochs: 5 +optimizer: adam +optimizer_kwargs: + lr: 0.0001 + weight_decay: 0.0001 +batch_size: 48 +num_workers: 8 + +# Scheduler +scheduler: constant diff --git a/data/imagenet.py b/data/imagenet.py index ec2a0285..4c3db2aa 100644 --- a/data/imagenet.py +++ b/data/imagenet.py @@ -16,10 +16,10 @@ class ImageNet(datasets.ImageFolder): def __init__(self, root=MyPath.db_root_dir('imagenet'), split='train', transform=None): super(ImageNet, self).__init__(root=os.path.join(root, 'ILSVRC2012_img_%s' %(split)), transform=None) - self.transform = transform + self.transform = transform self.split = split self.resize = tf.Resize(256) - + def __len__(self): return len(self.imgs) @@ -41,12 +41,42 @@ def get_image(self, index): path, target = self.imgs[index] with open(path, 'rb') as f: img = Image.open(f).convert('RGB') - img = self.resize(img) + img = self.resize(img) return img +class Birds(datasets.ImageFolder): + def __init__(self, root=MyPath.db_root_dir('birds'), split='train', transform=None): + super(Birds, self).__init__(root=os.path.join(root, split), transform=None) + self.transform = transform + self.split = split + self.resize = tf.Resize(256) + + def __len__(self): + return len(self.imgs) + + def __getitem__(self, index): + path, target = self.imgs[index] + with open(path, 'rb') as f: + img = Image.open(f).convert('RGB') + im_size = img.size + img = self.resize(img) + + if self.transform is not None: + img = self.transform(img) + + out = {'image': img, 'target': target, 
'meta': {'im_size': im_size, 'index': index}} + + return out + + def get_image(self, index): + path, target = self.imgs[index] + with open(path, 'rb') as f: + img = Image.open(f).convert('RGB') + img = self.resize(img) + return img class ImageNetSubset(data.Dataset): - def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='train', + def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='train', transform=None): super(ImageNetSubset, self).__init__() @@ -69,10 +99,10 @@ def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='trai subdir_path = os.path.join(self.root, subdir) files = sorted(glob(os.path.join(self.root, subdir, '*.JPEG'))) for f in files: - imgs.append((f, i)) - self.imgs = imgs + imgs.append((f, i)) + self.imgs = imgs self.classes = class_names - + # Resize self.resize = tf.Resize(256) @@ -80,7 +110,7 @@ def get_image(self, index): path, target = self.imgs[index] with open(path, 'rb') as f: img = Image.open(f).convert('RGB') - img = self.resize(img) + img = self.resize(img) return img def __len__(self): @@ -91,7 +121,7 @@ def __getitem__(self, index): with open(path, 'rb') as f: img = Image.open(f).convert('RGB') im_size = img.size - img = self.resize(img) + img = self.resize(img) class_name = self.classes[target] if self.transform is not None: diff --git a/data/mnist.py b/data/mnist.py new file mode 100644 index 00000000..fd2acabd --- /dev/null +++ b/data/mnist.py @@ -0,0 +1,174 @@ +from torchvision.datasets.vision import VisionDataset +import warnings +from PIL import Image +import os +import os.path +import numpy as np +import torch +import codecs +import string +import gzip +import lzma +from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union +from torchvision.datasets.utils import download_url, download_and_extract_archive, extract_archive, \ + verify_str_arg +from utils.mypath import MyPath + +class MNIST(VisionDataset): + """`MNIST `_ Dataset. 
+ + Args: + root (string): Root directory of dataset where ``MNIST/processed/training.pt`` + and ``MNIST/processed/test.pt`` exist. + train (bool, optional): If True, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + download (bool, optional): If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + """ + + resources = [ + ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"), + ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"), + ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"), + ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c") + ] + + training_file = 'training.pt' + test_file = 'test.pt' + classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', + '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine'] + + @property + def train_labels(self): + warnings.warn("train_labels has been renamed targets") + return self.targets + + @property + def test_labels(self): + warnings.warn("test_labels has been renamed targets") + return self.targets + + @property + def train_data(self): + warnings.warn("train_data has been renamed data") + return self.data + + @property + def test_data(self): + warnings.warn("test_data has been renamed data") + return self.data + + def __init__( + self, + root: str = MyPath.db_root_dir('mnist'), + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: + super(MNIST, 
self).__init__(root, transform=transform,
+                                    target_transform=target_transform)
+        self.train = train  # training set or test set
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError('Dataset not found.' +
+                               ' You can use download=True to download it')
+
+        if self.train:
+            data_file = self.training_file
+        else:
+            data_file = self.test_file
+        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
+
+    def __getitem__(self, index: int) -> Dict[str, Any]:
+        """Return the sample at ``index`` as a dictionary.
+
+        Args:
+            index (int): Index
+
+        Returns:
+            dict: ``{'image': img, 'target': target, 'meta': {'im_size': ..., 'index': index}}``
+            where ``img`` is the (optionally transformed) image, ``target`` is the
+            index of the target class and ``im_size`` is the PIL size of the image
+            before ``self.transform`` was applied.
+        """
+        img, target = self.data[index], int(self.targets[index])
+
+        # doing this so that it is consistent with all other datasets
+        # to return a PIL Image
+        img = Image.fromarray(img.numpy(), mode='L')
+
+        # Record the PIL size now: self.transform below rebinds img to whatever
+        # the transform returns, so the size has to be captured first.
+        img_size = img.size
+
+        if self.transform is not None:
+            img = self.transform(img)
+
+        if self.target_transform is not None:
+            target = self.target_transform(target)
+
+        out = {'image': img, 'target': target, 'meta': {'im_size': img_size, 'index': index}}
+
+        return out
+
+    # Note: For eval.py. From /data/cifar.py
+    def get_image(self, index):
+        """Return the raw (untransformed) entry of ``self.data`` at ``index``."""
+        img = self.data[index]
+        return img
+
+    def __len__(self) -> int:
+        return len(self.data)
+
+    @property
+    def raw_folder(self) -> str:
+        return os.path.join(self.root, self.__class__.__name__, 'raw')
+
+    @property
+    def processed_folder(self) -> str:
+        return os.path.join(self.root, self.__class__.__name__, 'processed')
+
+    @property
+    def class_to_idx(self) -> Dict[str, int]:
+        return {_class: i for i, _class in enumerate(self.classes)}
+
+    def _check_exists(self) -> bool:
+        """Return True only if both processed files (train and test) are present."""
+        return (os.path.exists(os.path.join(self.processed_folder,
+                                            self.training_file)) and
+                os.path.exists(os.path.join(self.processed_folder,
+                                            self.test_file)))
+
+    def download(self) -> None:
+        """Download the MNIST data if it doesn't exist in processed_folder already."""
+
+        if self._check_exists():
+            return
+
+        # read_image_file / read_label_file are neither defined nor imported in
+        # this file; take them from torchvision's own MNIST module so the
+        # processing step below does not fail with NameError.
+        from torchvision.datasets.mnist import read_image_file, read_label_file
+
+        os.makedirs(self.raw_folder, exist_ok=True)
+        os.makedirs(self.processed_folder, exist_ok=True)
+
+        # download files
+        for url, md5 in self.resources:
+            filename = url.rpartition('/')[2]
+            download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5)
+
+        # process and save as torch files
+        print('Processing...')
+
+        training_set = (
+            read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')),
+            read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte'))
+        )
+        test_set = (
+            read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')),
+            read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte'))
+        )
+        with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f:
+            torch.save(training_set, f)
+        with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f:
+            torch.save(test_set, f)
+
+        print('Done!')
+
+    def extra_repr(self) -> str:
+        return "Split: {}".format("Train" if self.train is True else "Test")
diff --git a/models/resnet_mnist.py b/models/resnet_mnist.py
new file mode 100644
index 00000000..c215b847
--- /dev/null
+++ b/models/resnet_mnist.py @@ -0,0 +1,125 @@ +""" +This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. +""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1, is_last=False): + super(BasicBlock, self).__init__() + self.is_last = is_last + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + preact = out + out = F.relu(out) + if self.is_last: + return out, preact + else: + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, is_last=False): + super(Bottleneck, self).__init__() + self.is_last = is_last + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = 
F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + preact = out + out = F.relu(out) + if self.is_last: + return out, preact + else: + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, in_channel=1, zero_init_residual=False): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves + # like an identity. 
This improves the model by 0.2~0.3% according to: + # https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for i in range(num_blocks): + stride = strides[i] + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = self.avgpool(out) + out = torch.flatten(out, 1) + return out + + +def resnet18(**kwargs): + return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512} diff --git a/organize_train_test.py b/organize_train_test.py new file mode 100644 index 00000000..529ed35e --- /dev/null +++ b/organize_train_test.py @@ -0,0 +1,37 @@ +# https://github.com/ecm200/caltech_birds/blob/master/scripts/organise_train_test.py + +import os +import pandas as pd +import shutil + +# Script runtime options +root_dir = '/content/Unsupervised-Classification/data/CUB_200_2011' # change as needed +data_dir = os.path.join(root_dir,'images') + +image_fnames = pd.read_csv(filepath_or_buffer=os.path.join(root_dir,'images.txt'), + header=None, + delimiter=' ', + names=['Img ID', 'file path']) + +image_fnames['is training image?'] = pd.read_csv(filepath_or_buffer=os.path.join(root_dir,'train_test_split.txt'), + header=None, delimiter=' ', + names=['Img ID','is training image?'])['is training image?'] + +os.makedirs(os.path.join(data_dir,'train'), exist_ok=True) +os.makedirs(os.path.join(data_dir,'test'), exist_ok=True) + +for i_image, image_fname in enumerate(image_fnames['file path']): + if image_fnames['is training image?'].iloc[i_image]: + 
new_dir = os.path.join(data_dir,'train',image_fname.split('/')[0]) + os.makedirs(new_dir, exist_ok=True) + shutil.copy(src=os.path.join(data_dir,image_fname), dst=os.path.join(new_dir, image_fname.split('/')[1])) + print(i_image, ':: Image is in training set. [', bool(image_fnames['is training image?'].iloc[i_image]),']') + print('Image:: ', image_fname) + print('Destination:: ', new_dir) + else: + new_dir = os.path.join(data_dir,'test',image_fname.split('/')[0]) + os.makedirs(new_dir, exist_ok=True) + shutil.copy(src=os.path.join(data_dir,image_fname), dst=os.path.join(new_dir, image_fname.split('/')[1])) + print(i_image, ':: Image is in testing set. [', bool(image_fnames['is training image?'].iloc[i_image]),']') + print('Source Image:: ', image_fname) + print('Destination:: ', new_dir) diff --git a/utils/common_config.py b/utils/common_config.py index 60d5bdf6..3958eb78 100755 --- a/utils/common_config.py +++ b/utils/common_config.py @@ -10,7 +10,7 @@ from data.augment import Augment, Cutout from utils.collate import collate_custom - + def get_criterion(p): if p['criterion'] == 'simclr': from losses.losses import SimCLRLoss @@ -51,17 +51,26 @@ def get_model(p, pretrain_path=None): elif p['train_db_name'] == 'stl-10': from models.resnet_stl import resnet18 backbone = resnet18() - + + elif p['train_db_name'] == 'mnist': + from models.resnet_mnist import resnet18 + backbone = resnet18() + else: raise NotImplementedError elif p['backbone'] == 'resnet50': if 'imagenet' in p['train_db_name']: from models.resnet import resnet50 - backbone = resnet50() + backbone = resnet50() + + # added birds with resnet50 + elif p['train_db_name'] in ['birds', 'mnist']: + from models.resnet import resnet50 + backbone = resnet50() else: - raise NotImplementedError + raise NotImplementedError else: raise ValueError('Invalid backbone {}'.format(p['backbone'])) @@ -83,16 +92,16 @@ def get_model(p, pretrain_path=None): # Load pretrained weights if pretrain_path is not None and 
os.path.exists(pretrain_path): state = torch.load(pretrain_path, map_location='cpu') - + if p['setup'] == 'scan': # Weights are supposed to be transfered from contrastive training missing = model.load_state_dict(state, strict=False) assert(set(missing[1]) == { - 'contrastive_head.0.weight', 'contrastive_head.0.bias', + 'contrastive_head.0.weight', 'contrastive_head.0.bias', 'contrastive_head.2.weight', 'contrastive_head.2.bias'} or set(missing[1]) == { 'contrastive_head.weight', 'contrastive_head.bias'}) - elif p['setup'] == 'selflabel': # Weights are supposed to be transfered from scan + elif p['setup'] == 'selflabel': # Weights are supposed to be transfered from scan # We only continue with the best head (pop all heads first, then copy back the best head) model_state = state['model'] all_heads = [k for k in model_state.keys() if 'cluster_head' in k] @@ -132,6 +141,10 @@ def get_train_dataset(p, transform, to_augmented_dataset=False, from data.stl import STL10 dataset = STL10(split=split, transform=transform, download=True) + elif p['train_db_name'] == 'mnist': + from data.mnist import MNIST + dataset = MNIST(train=True, transform=transform, download=True) + elif p['train_db_name'] == 'imagenet': from data.imagenet import ImageNet dataset = ImageNet(split='train', transform=transform) @@ -141,9 +154,14 @@ def get_train_dataset(p, transform, to_augmented_dataset=False, subset_file = './data/imagenet_subsets/%s.txt' %(p['train_db_name']) dataset = ImageNetSubset(subset_file=subset_file, split='train', transform=transform) + # added birds train dataset + elif p['train_db_name'] == 'birds': + from data.imagenet import Birds + dataset = Birds(split='train', transform=transform) + else: raise ValueError('Invalid train dataset {}'.format(p['train_db_name'])) - + # Wrap into other dataset (__getitem__ changes) if to_augmented_dataset: # Dataset returns an image and an augmentation of that image. 
from data.custom_dataset import AugmentedDataset @@ -153,7 +171,7 @@ def get_train_dataset(p, transform, to_augmented_dataset=False, from data.custom_dataset import NeighborsDataset indices = np.load(p['topk_neighbors_train_path']) dataset = NeighborsDataset(dataset, indices, p['num_neighbors']) - + return dataset @@ -162,28 +180,37 @@ def get_val_dataset(p, transform=None, to_neighbors_dataset=False): if p['val_db_name'] == 'cifar-10': from data.cifar import CIFAR10 dataset = CIFAR10(train=False, transform=transform, download=True) - + elif p['val_db_name'] == 'cifar-20': from data.cifar import CIFAR20 dataset = CIFAR20(train=False, transform=transform, download=True) + elif p['val_db_name'] == 'mnist': + from data.mnist import MNIST + dataset = MNIST(train=False, transform=transform, download=True) + elif p['val_db_name'] == 'stl-10': from data.stl import STL10 dataset = STL10(split='test', transform=transform, download=True) - + elif p['val_db_name'] == 'imagenet': from data.imagenet import ImageNet dataset = ImageNet(split='val', transform=transform) - + elif p['val_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']: from data.imagenet import ImageNetSubset subset_file = './data/imagenet_subsets/%s.txt' %(p['val_db_name']) dataset = ImageNetSubset(subset_file=subset_file, split='val', transform=transform) - + + # added birds test dataset + elif p['val_db_name'] == 'birds': + from data.imagenet import Birds + dataset = Birds(split='test', transform=transform) + else: raise ValueError('Invalid validation dataset {}'.format(p['val_db_name'])) - - # Wrap into other dataset (__getitem__ changes) + + # Wrap into other dataset (__getitem__ changes) if to_neighbors_dataset: # Dataset returns an image and one of its nearest neighbors. 
from data.custom_dataset import NeighborsDataset indices = np.load(p['topk_neighbors_val_path']) @@ -193,7 +220,7 @@ def get_val_dataset(p, transform=None, to_neighbors_dataset=False): def get_train_dataloader(p, dataset): - return torch.utils.data.DataLoader(dataset, num_workers=p['num_workers'], + return torch.utils.data.DataLoader(dataset, num_workers=p['num_workers'], batch_size=p['batch_size'], pin_memory=True, collate_fn=collate_custom, drop_last=True, shuffle=True) @@ -213,25 +240,40 @@ def get_train_transformations(p): transforms.ToTensor(), transforms.Normalize(**p['augmentation_kwargs']['normalize']) ]) - + elif p['augmentation_strategy'] == 'simclr': # Augmentation strategy from the SimCLR paper - return transforms.Compose([ - transforms.RandomResizedCrop(**p['augmentation_kwargs']['random_resized_crop']), - transforms.RandomHorizontalFlip(), - transforms.RandomApply([ - transforms.ColorJitter(**p['augmentation_kwargs']['color_jitter']) - ], p=p['augmentation_kwargs']['color_jitter_random_apply']['p']), - transforms.RandomGrayscale(**p['augmentation_kwargs']['random_grayscale']), - transforms.ToTensor(), - transforms.Normalize(**p['augmentation_kwargs']['normalize']) - ]) - + + if p['train_db_name'] == 'mnist': + return transforms.Compose([ + transforms.Resize(48), + transforms.RandomResizedCrop(**p['augmentation_kwargs']['random_resized_crop']), + transforms.RandomHorizontalFlip(), + transforms.RandomApply([ + transforms.ColorJitter(**p['augmentation_kwargs']['color_jitter']) + ], p=p['augmentation_kwargs']['color_jitter_random_apply']['p']), + transforms.RandomGrayscale(**p['augmentation_kwargs']['random_grayscale']), + transforms.ToTensor(), + transforms.Normalize(**p['augmentation_kwargs']['normalize']) + ]) + + else: + return transforms.Compose([ + transforms.RandomResizedCrop(**p['augmentation_kwargs']['random_resized_crop']), + transforms.RandomHorizontalFlip(), + transforms.RandomApply([ + 
transforms.ColorJitter(**p['augmentation_kwargs']['color_jitter']) + ], p=p['augmentation_kwargs']['color_jitter_random_apply']['p']), + transforms.RandomGrayscale(**p['augmentation_kwargs']['random_grayscale']), + transforms.ToTensor(), + transforms.Normalize(**p['augmentation_kwargs']['normalize']) + ]) + elif p['augmentation_strategy'] == 'ours': - # Augmentation strategy from our paper + # Augmentation strategy from our paper return transforms.Compose([ transforms.RandomHorizontalFlip(), - transforms.RandomCrop(p['augmentation_kwargs']['crop_size']), + transforms.RandomCrop(p['augmentation_kwargs']['crop_size'], pad_if_needed=True), Augment(p['augmentation_kwargs']['num_strong_augs']), transforms.ToTensor(), transforms.Normalize(**p['augmentation_kwargs']['normalize']), @@ -239,38 +281,46 @@ def get_train_transformations(p): n_holes = p['augmentation_kwargs']['cutout_kwargs']['n_holes'], length = p['augmentation_kwargs']['cutout_kwargs']['length'], random = p['augmentation_kwargs']['cutout_kwargs']['random'])]) - + else: raise ValueError('Invalid augmentation strategy {}'.format(p['augmentation_strategy'])) def get_val_transformations(p): - return transforms.Compose([ - transforms.CenterCrop(p['transformation_kwargs']['crop_size']), - transforms.ToTensor(), - transforms.Normalize(**p['transformation_kwargs']['normalize'])]) + if p['train_db_name'] == 'mnist': + return transforms.Compose([ + transforms.Resize(48), + transforms.CenterCrop(p['transformation_kwargs']['crop_size']), + transforms.ToTensor(), + transforms.Normalize(**p['transformation_kwargs']['normalize'])]) + + else: + return transforms.Compose([ + transforms.CenterCrop(p['transformation_kwargs']['crop_size']), + transforms.ToTensor(), + transforms.Normalize(**p['transformation_kwargs']['normalize'])]) def get_optimizer(p, model, cluster_head_only=False): - if cluster_head_only: # Only weights in the cluster head will be updated + if cluster_head_only: # Only weights in the cluster head will be 
updated for name, param in model.named_parameters(): if 'cluster_head' in name: - param.requires_grad = True + param.requires_grad = True else: - param.requires_grad = False + param.requires_grad = False params = list(filter(lambda p: p.requires_grad, model.parameters())) assert(len(params) == 2 * p['num_heads']) else: params = model.parameters() - + if p['optimizer'] == 'sgd': optimizer = torch.optim.SGD(params, **p['optimizer_kwargs']) elif p['optimizer'] == 'adam': optimizer = torch.optim.Adam(params, **p['optimizer_kwargs']) - + else: raise ValueError('Invalid optimizer {}'.format(p['optimizer'])) @@ -279,11 +329,11 @@ def get_optimizer(p, model, cluster_head_only=False): def adjust_learning_rate(p, optimizer, epoch): lr = p['optimizer_kwargs']['lr'] - + if p['scheduler'] == 'cosine': eta_min = lr * (p['scheduler_kwargs']['lr_decay_rate'] ** 3) lr = eta_min + (lr - eta_min) * (1 + math.cos(math.pi * epoch / p['epochs'])) / 2 - + elif p['scheduler'] == 'step': steps = np.sum(epoch > np.array(p['scheduler_kwargs']['lr_decay_epochs'])) if steps > 0: diff --git a/utils/mypath.py b/utils/mypath.py index 22b86161..2f45493f 100644 --- a/utils/mypath.py +++ b/utils/mypath.py @@ -8,20 +8,26 @@ class MyPath(object): @staticmethod def db_root_dir(database=''): - db_names = {'cifar-10', 'stl-10', 'cifar-20', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200'} + db_names = {'cifar-10', 'stl-10', 'cifar-20', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200', 'birds', 'mnist'} assert(database in db_names) if database == 'cifar-10': return '/path/to/cifar-10/' - + elif database == 'cifar-20': return '/path/to/cifar-20/' elif database == 'stl-10': return '/path/to/stl-10/' - + elif database in ['imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200']: return '/path/to/imagenet/' - + + elif database == 'birds': + return '/content/Unsupervised-Classification/data/CUB_200_2011/images/' + + elif database == 'mnist': + return '/content/mnist/' + else: raise 
NotImplementedError