From 135bdf0d9b19b520ef566e17b96d30b1cdac06e3 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Tue, 16 Feb 2021 23:44:58 +0100 Subject: [PATCH 01/25] feat: modify simclr code to include birds --- configs/pretext/simclr_birds.yml | 57 ++++++++++++++++++++++++ utils/common_config.py | 75 ++++++++++++++++++++------------ 2 files changed, 104 insertions(+), 28 deletions(-) create mode 100644 configs/pretext/simclr_birds.yml diff --git a/configs/pretext/simclr_birds.yml b/configs/pretext/simclr_birds.yml new file mode 100644 index 00000000..966b367f --- /dev/null +++ b/configs/pretext/simclr_birds.yml @@ -0,0 +1,57 @@ +# Setup +setup: simclr + +# Model +backbone: resnet50 +model_kwargs: + head: mlp + features_dim: 128 + +# Dataset +train_db_name: birds +val_db_name: birds +num_classes: 200 + +# Loss +criterion: simclr +criterion_kwargs: + temperature: 0.1 + +# Hyperparameters +epochs: 10 +optimizer: sgd +optimizer_kwargs: + nesterov: False + weight_decay: 0.0001 + momentum: 0.9 + lr: 0.4 +scheduler: cosine +scheduler_kwargs: + lr_decay_rate: 0.1 +batch_size: 32 +num_workers: 8 + +# Transformations +augmentation_strategy: simclr +augmentation_kwargs: + random_resized_crop: + size: 224 + scale: [0.2, 1.0] + color_jitter_random_apply: + p: 0.8 + color_jitter: + brightness: 0.4 + contrast: 0.4 + saturation: 0.4 + hue: 0.1 + random_grayscale: + p: 0.2 + normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +transformation_kwargs: + crop_size: 224 + normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] diff --git a/utils/common_config.py b/utils/common_config.py index 60d5bdf6..207b4913 100755 --- a/utils/common_config.py +++ b/utils/common_config.py @@ -10,7 +10,7 @@ from data.augment import Augment, Cutout from utils.collate import collate_custom - + def get_criterion(p): if p['criterion'] == 'simclr': from losses.losses import SimCLRLoss @@ -51,17 +51,22 @@ def get_model(p, pretrain_path=None): elif p['train_db_name'] == 'stl-10': from models.resnet_stl import resnet18 backbone = resnet18() - + else: raise NotImplementedError elif p['backbone'] == 'resnet50': if 'imagenet' in p['train_db_name']: from models.resnet import resnet50 - backbone = resnet50() + backbone = resnet50() + + # added birds with resnet50 + elif p['train_db_name'] == 'birds': + from models.resnet import resnet50 + backbone = resnet50() else: - raise NotImplementedError + raise NotImplementedError else: raise ValueError('Invalid backbone {}'.format(p['backbone'])) @@ -83,16 +88,16 @@ def get_model(p, pretrain_path=None): # Load pretrained weights if pretrain_path is not None and os.path.exists(pretrain_path): state = torch.load(pretrain_path, map_location='cpu') - + if p['setup'] == 'scan': # Weights are supposed to be transfered from contrastive training missing = model.load_state_dict(state, strict=False) assert(set(missing[1]) == { - 'contrastive_head.0.weight', 'contrastive_head.0.bias', + 'contrastive_head.0.weight', 'contrastive_head.0.bias', 'contrastive_head.2.weight', 'contrastive_head.2.bias'} or set(missing[1]) == { 'contrastive_head.weight', 'contrastive_head.bias'}) - elif p['setup'] == 'selflabel': # Weights are supposed to be transfered from scan + elif p['setup'] == 'selflabel': # Weights are supposed to be transfered from scan # We only continue with the best head (pop all heads first, then copy back the best head) model_state = state['model'] all_heads = [k for k in model_state.keys() if 'cluster_head' in k] @@ -141,9 +146,16 @@ def get_train_dataset(p, transform, 
to_augmented_dataset=False, subset_file = './data/imagenet_subsets/%s.txt' %(p['train_db_name']) dataset = ImageNetSubset(subset_file=subset_file, split='train', transform=transform) + # added birds train dataset + elif p['train_db_name'] == 'birds': + from torchvision.datasets import ImageFolder + birds_data_dir = '/content/data/CUB_200_2011/images/' # Colab + birds_train_dir = os.path.join(birds_data_dir, 'train') + dataset = ImageFolder(birds_train_dir, transform=transform) + else: raise ValueError('Invalid train dataset {}'.format(p['train_db_name'])) - + # Wrap into other dataset (__getitem__ changes) if to_augmented_dataset: # Dataset returns an image and an augmentation of that image. from data.custom_dataset import AugmentedDataset @@ -153,7 +165,7 @@ def get_train_dataset(p, transform, to_augmented_dataset=False, from data.custom_dataset import NeighborsDataset indices = np.load(p['topk_neighbors_train_path']) dataset = NeighborsDataset(dataset, indices, p['num_neighbors']) - + return dataset @@ -162,7 +174,7 @@ def get_val_dataset(p, transform=None, to_neighbors_dataset=False): if p['val_db_name'] == 'cifar-10': from data.cifar import CIFAR10 dataset = CIFAR10(train=False, transform=transform, download=True) - + elif p['val_db_name'] == 'cifar-20': from data.cifar import CIFAR20 dataset = CIFAR20(train=False, transform=transform, download=True) @@ -170,20 +182,27 @@ def get_val_dataset(p, transform=None, to_neighbors_dataset=False): elif p['val_db_name'] == 'stl-10': from data.stl import STL10 dataset = STL10(split='test', transform=transform, download=True) - + elif p['val_db_name'] == 'imagenet': from data.imagenet import ImageNet dataset = ImageNet(split='val', transform=transform) - + elif p['val_db_name'] in ['imagenet_50', 'imagenet_100', 'imagenet_200']: from data.imagenet import ImageNetSubset subset_file = './data/imagenet_subsets/%s.txt' %(p['val_db_name']) dataset = ImageNetSubset(subset_file=subset_file, split='val', transform=transform) - + + # added birds test dataset + elif p['val_db_name'] == 'birds': + from torchvision.datasets import ImageFolder + birds_data_dir = '/content/data/CUB_200_2011/images/' # Colab + birds_test_dir = os.path.join(birds_data_dir, 'test') + dataset = ImageFolder(birds_test_dir, transform=transform) + else: raise ValueError('Invalid validation dataset {}'.format(p['val_db_name'])) - - # Wrap into other dataset (__getitem__ changes) + + # Wrap into other dataset (__getitem__ changes) if to_neighbors_dataset: # Dataset returns an image and one of its nearest neighbors. 
from data.custom_dataset import NeighborsDataset indices = np.load(p['topk_neighbors_val_path']) @@ -193,7 +212,7 @@ def get_val_dataset(p, transform=None, to_neighbors_dataset=False): def get_train_dataloader(p, dataset): - return torch.utils.data.DataLoader(dataset, num_workers=p['num_workers'], + return torch.utils.data.DataLoader(dataset, num_workers=p['num_workers'], batch_size=p['batch_size'], pin_memory=True, collate_fn=collate_custom, drop_last=True, shuffle=True) @@ -213,7 +232,7 @@ def get_train_transformations(p): transforms.ToTensor(), transforms.Normalize(**p['augmentation_kwargs']['normalize']) ]) - + elif p['augmentation_strategy'] == 'simclr': # Augmentation strategy from the SimCLR paper return transforms.Compose([ @@ -226,9 +245,9 @@ def get_train_transformations(p): transforms.ToTensor(), transforms.Normalize(**p['augmentation_kwargs']['normalize']) ]) - + elif p['augmentation_strategy'] == 'ours': - # Augmentation strategy from our paper + # Augmentation strategy from our paper return transforms.Compose([ transforms.RandomHorizontalFlip(), transforms.RandomCrop(p['augmentation_kwargs']['crop_size']), @@ -239,7 +258,7 @@ def get_train_transformations(p): n_holes = p['augmentation_kwargs']['cutout_kwargs']['n_holes'], length = p['augmentation_kwargs']['cutout_kwargs']['length'], random = p['augmentation_kwargs']['cutout_kwargs']['random'])]) - + else: raise ValueError('Invalid augmentation strategy {}'.format(p['augmentation_strategy'])) @@ -247,30 +266,30 @@ def get_train_transformations(p): def get_val_transformations(p): return transforms.Compose([ transforms.CenterCrop(p['transformation_kwargs']['crop_size']), - transforms.ToTensor(), + transforms.ToTensor(), transforms.Normalize(**p['transformation_kwargs']['normalize'])]) def get_optimizer(p, model, cluster_head_only=False): - if cluster_head_only: # Only weights in the cluster head will be updated + if cluster_head_only: # Only weights in the cluster head will be updated for name, param in model.named_parameters(): if 'cluster_head' in name: - param.requires_grad = True + param.requires_grad = True else: - param.requires_grad = False + param.requires_grad = False params = list(filter(lambda p: p.requires_grad, model.parameters())) assert(len(params) == 2 * p['num_heads']) else: params = model.parameters() - + if p['optimizer'] == 'sgd': optimizer = torch.optim.SGD(params, **p['optimizer_kwargs']) elif p['optimizer'] == 'adam': optimizer = torch.optim.Adam(params, **p['optimizer_kwargs']) - + else: raise ValueError('Invalid optimizer {}'.format(p['optimizer'])) @@ -279,11 +298,11 @@ def get_optimizer(p, model, cluster_head_only=False): def adjust_learning_rate(p, optimizer, epoch): lr = p['optimizer_kwargs']['lr'] - + if p['scheduler'] == 'cosine': eta_min = lr * (p['scheduler_kwargs']['lr_decay_rate'] ** 3) lr = eta_min + (lr - eta_min) * (1 + math.cos(math.pi * epoch / p['epochs'])) / 2 - + elif p['scheduler'] == 'step': steps = np.sum(epoch > np.array(p['scheduler_kwargs']['lr_decay_epochs'])) if steps > 0: From a1dbb32cc194ecc4b9d3e26033b6913911aa409b Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Tue, 16 Feb 2021 23:56:01 +0100 Subject: [PATCH 02/25] chore: add root_dir for storing results --- configs/env.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/env.yml b/configs/env.yml index 8e1d688e..b025c3a9 100644 --- a/configs/env.yml +++ b/configs/env.yml @@ -1 +1 @@ -root_dir: /path/where/to/store/results/ +root_dir: /content/data/ From 7d9c5b7d045a998f02fe089b553c626e789a5422 Mon Sep 
17 00:00:00 2001 From: Bob Bell Date: Wed, 17 Feb 2021 01:27:23 +0100 Subject: [PATCH 03/25] chore: change directory to see if config works --- utils/common_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/common_config.py b/utils/common_config.py index 207b4913..ff22628d 100755 --- a/utils/common_config.py +++ b/utils/common_config.py @@ -149,7 +149,7 @@ def get_train_dataset(p, transform, to_augmented_dataset=False, # added birds train dataset elif p['train_db_name'] == 'birds': from torchvision.datasets import ImageFolder - birds_data_dir = '/content/data/CUB_200_2011/images/' # Colab + birds_data_dir = './data/CUB_200_2011/images/' # Colab birds_train_dir = os.path.join(birds_data_dir, 'train') dataset = ImageFolder(birds_train_dir, transform=transform) @@ -195,7 +195,7 @@ def get_val_dataset(p, transform=None, to_neighbors_dataset=False): # added birds test dataset elif p['val_db_name'] == 'birds': from torchvision.datasets import ImageFolder - birds_data_dir = '/content/data/CUB_200_2011/images/' # Colab + birds_data_dir = './data/CUB_200_2011/images/' # Colab birds_test_dir = os.path.join(birds_data_dir, 'test') dataset = ImageFolder(birds_test_dir, transform=transform) From 569731204c60f48924555a67092a499f0a262c9e Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Wed, 17 Feb 2021 01:58:03 +0100 Subject: [PATCH 04/25] fix: add dataset class for birds --- utils/mypath.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/utils/mypath.py b/utils/mypath.py index 22b86161..692e96cb 100644 --- a/utils/mypath.py +++ b/utils/mypath.py @@ -13,15 +13,18 @@ def db_root_dir(database=''): if database == 'cifar-10': return '/path/to/cifar-10/' - + elif database == 'cifar-20': return '/path/to/cifar-20/' elif database == 'stl-10': return '/path/to/stl-10/' - + elif database in ['imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200']: return '/path/to/imagenet/' + + elif database == 'birds': + return '/content/Unsupervised-Classification/data/CUB_200_2011/images/' else: raise NotImplementedError From 2c43e5fa8350038c431a365198832c89b5f2fc64 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Wed, 17 Feb 2021 02:01:54 +0100 Subject: [PATCH 05/25] fix: add birds class (really) --- data/imagenet.py | 41 ++++++++++++++++++++++++++++++++--------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/data/imagenet.py b/data/imagenet.py index ec2a0285..e9365bc1 100644 --- a/data/imagenet.py +++ b/data/imagenet.py @@ -16,10 +16,10 @@ class ImageNet(datasets.ImageFolder): def __init__(self, root=MyPath.db_root_dir('imagenet'), split='train', transform=None): super(ImageNet, self).__init__(root=os.path.join(root, 'ILSVRC2012_img_%s' %(split)), transform=None) - self.transform = transform + self.transform = transform self.split = split self.resize = tf.Resize(256) - + def __len__(self): return len(self.imgs) @@ -41,12 +41,35 @@ def get_image(self, index): path, target = self.imgs[index] with open(path, 'rb') as f: img = Image.open(f).convert('RGB') - img = self.resize(img) + img = self.resize(img) return img +class Birds(datasets.ImageFolder): + def __init__(self, root=MyPath.db_root_dir('birds'), split='train', transform=None): + super(Birds, self).__init__(root=os.path.join(root, split), transform=None) + self.transform = transform + self.split = split + self.resize = tf.Resize(256) + + def __len__(self): + return len(self.imgs) + + def __getitem__(self, index): + path, target = self.imgs[index] + with open(path, 'rb') as f: + img = 
Image.open(f).convert('RGB') + im_size = img.size + img = self.resize(img) + + if self.transform is not None: + img = self.transform(img) + + out = {'image': img, 'target': target, 'meta': {'im_size': im_size, 'index': index}} + + return out class ImageNetSubset(data.Dataset): - def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='train', + def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='train', transform=None): super(ImageNetSubset, self).__init__() @@ -69,10 +92,10 @@ def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='trai subdir_path = os.path.join(self.root, subdir) files = sorted(glob(os.path.join(self.root, subdir, '*.JPEG'))) for f in files: - imgs.append((f, i)) - self.imgs = imgs + imgs.append((f, i)) + self.imgs = imgs self.classes = class_names - + # Resize self.resize = tf.Resize(256) @@ -80,7 +103,7 @@ def get_image(self, index): path, target = self.imgs[index] with open(path, 'rb') as f: img = Image.open(f).convert('RGB') - img = self.resize(img) + img = self.resize(img) return img def __len__(self): @@ -91,7 +114,7 @@ def __getitem__(self, index): with open(path, 'rb') as f: img = Image.open(f).convert('RGB') im_size = img.size - img = self.resize(img) + img = self.resize(img) class_name = self.classes[target] if self.transform is not None: From 03b11029b4ba71b572357a20b95c64f3ff445e97 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Wed, 17 Feb 2021 02:05:05 +0100 Subject: [PATCH 06/25] fix: add `get_image` method to Birds --- data/imagenet.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/data/imagenet.py b/data/imagenet.py index e9365bc1..4c3db2aa 100644 --- a/data/imagenet.py +++ b/data/imagenet.py @@ -68,6 +68,13 @@ def __getitem__(self, index): return out + def get_image(self, index): + path, target = self.imgs[index] + with open(path, 'rb') as f: + img = Image.open(f).convert('RGB') + img = self.resize(img) + return img + class ImageNetSubset(data.Dataset): def __init__(self, subset_file, root=MyPath.db_root_dir('imagenet'), split='train', transform=None): From b30b52405c5644397066f7c8e2b18bbaba5a6875 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Wed, 17 Feb 2021 02:13:46 +0100 Subject: [PATCH 07/25] fix: create right path to birds dataset --- utils/common_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/common_config.py b/utils/common_config.py index ff22628d..5178007f 100755 --- a/utils/common_config.py +++ b/utils/common_config.py @@ -149,7 +149,7 @@ def get_train_dataset(p, transform, to_augmented_dataset=False, # added birds train dataset elif p['train_db_name'] == 'birds': from torchvision.datasets import ImageFolder - birds_data_dir = './data/CUB_200_2011/images/' # Colab + birds_data_dir = '/content/Unsupervised-Classification/data/CUB_200_2011/images/' # Colab birds_train_dir = os.path.join(birds_data_dir, 'train') dataset = ImageFolder(birds_train_dir, transform=transform) @@ -195,7 +195,7 @@ def get_val_dataset(p, transform=None, to_neighbors_dataset=False): # added birds test dataset elif p['val_db_name'] == 'birds': from torchvision.datasets import ImageFolder - birds_data_dir = './data/CUB_200_2011/images/' # Colab + birds_data_dir = '/content/Unsupervised-Classification/data/CUB_200_2011/images/' # Colab birds_test_dir = os.path.join(birds_data_dir, 'test') dataset = ImageFolder(birds_test_dir, transform=transform) From 5cea2f98e504073fc3052ba899136c0391efc6ca Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Wed, 17 Feb 2021 02:20:21 
+0100 Subject: [PATCH 08/25] fix: get birds dataset from Birds class --- utils/common_config.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/utils/common_config.py b/utils/common_config.py index 5178007f..dbf76bf0 100755 --- a/utils/common_config.py +++ b/utils/common_config.py @@ -148,10 +148,8 @@ def get_train_dataset(p, transform, to_augmented_dataset=False, # added birds train dataset elif p['train_db_name'] == 'birds': - from torchvision.datasets import ImageFolder - birds_data_dir = '/content/Unsupervised-Classification/data/CUB_200_2011/images/' # Colab - birds_train_dir = os.path.join(birds_data_dir, 'train') - dataset = ImageFolder(birds_train_dir, transform=transform) + from data.imagenet import Birds + dataset = Birds(split='train', transform=transform) else: raise ValueError('Invalid train dataset {}'.format(p['train_db_name'])) @@ -194,10 +192,8 @@ def get_val_dataset(p, transform=None, to_neighbors_dataset=False): # added birds test dataset elif p['val_db_name'] == 'birds': - from torchvision.datasets import ImageFolder - birds_data_dir = '/content/Unsupervised-Classification/data/CUB_200_2011/images/' # Colab - birds_test_dir = os.path.join(birds_data_dir, 'test') - dataset = ImageFolder(birds_test_dir, transform=transform) + from data.imagenet import Birds + dataset = Birds(split='test', transform=transform) else: raise ValueError('Invalid validation dataset {}'.format(p['val_db_name'])) From 107938a7fe13a588ff7ffb99a07013cd39c4bee0 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Wed, 17 Feb 2021 02:29:23 +0100 Subject: [PATCH 09/25] fix: add 'birds' to `db_names` --- utils/mypath.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils/mypath.py b/utils/mypath.py index 692e96cb..c5c6645c 100644 --- a/utils/mypath.py +++ b/utils/mypath.py @@ -8,7 +8,7 @@ class MyPath(object): @staticmethod def db_root_dir(database=''): - db_names = {'cifar-10', 'stl-10', 'cifar-20', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200'} + db_names = {'cifar-10', 'stl-10', 'cifar-20', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200', 'birds'} assert(database in db_names) if database == 'cifar-10': @@ -25,6 +25,6 @@ def db_root_dir(database=''): elif database == 'birds': return '/content/Unsupervised-Classification/data/CUB_200_2011/images/' - + else: raise NotImplementedError From 9cd47119fc3346e6470f44b7c43652a2c7d9dab3 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Wed, 17 Feb 2021 02:39:56 +0100 Subject: [PATCH 10/25] feat: add train/test split script to repo --- organize_train_test.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 organize_train_test.py diff --git a/organize_train_test.py b/organize_train_test.py new file mode 100644 index 00000000..529ed35e --- /dev/null +++ b/organize_train_test.py @@ -0,0 +1,37 @@ +# https://github.com/ecm200/caltech_birds/blob/master/scripts/organise_train_test.py + +import os +import pandas as pd +import shutil + +# Script runtime options +root_dir = '/content/Unsupervised-Classification/data/CUB_200_2011' # change as needed +data_dir = os.path.join(root_dir,'images') + +image_fnames = pd.read_csv(filepath_or_buffer=os.path.join(root_dir,'images.txt'), + header=None, + delimiter=' ', + names=['Img ID', 'file path']) + +image_fnames['is training image?'] = pd.read_csv(filepath_or_buffer=os.path.join(root_dir,'train_test_split.txt'), + header=None, delimiter=' ', + names=['Img ID','is training image?'])['is training image?'] + 
+os.makedirs(os.path.join(data_dir,'train'), exist_ok=True) +os.makedirs(os.path.join(data_dir,'test'), exist_ok=True) + +for i_image, image_fname in enumerate(image_fnames['file path']): + if image_fnames['is training image?'].iloc[i_image]: + new_dir = os.path.join(data_dir,'train',image_fname.split('/')[0]) + os.makedirs(new_dir, exist_ok=True) + shutil.copy(src=os.path.join(data_dir,image_fname), dst=os.path.join(new_dir, image_fname.split('/')[1])) + print(i_image, ':: Image is in training set. [', bool(image_fnames['is training image?'].iloc[i_image]),']') + print('Image:: ', image_fname) + print('Destination:: ', new_dir) + else: + new_dir = os.path.join(data_dir,'test',image_fname.split('/')[0]) + os.makedirs(new_dir, exist_ok=True) + shutil.copy(src=os.path.join(data_dir,image_fname), dst=os.path.join(new_dir, image_fname.split('/')[1])) + print(i_image, ':: Image is in testing set. [', bool(image_fnames['is training image?'].iloc[i_image]),']') + print('Source Image:: ', image_fname) + print('Destination:: ', new_dir) From ee030224f82486d0316f026f2b1cf6dae2b999b2 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Wed, 17 Feb 2021 03:07:30 +0100 Subject: [PATCH 11/25] feat: create config file for SCAN --- configs/scan/scan_birds.yml | 59 +++++++++++++++++++++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 configs/scan/scan_birds.yml diff --git a/configs/scan/scan_birds.yml b/configs/scan/scan_birds.yml new file mode 100644 index 00000000..030e793d --- /dev/null +++ b/configs/scan/scan_birds.yml @@ -0,0 +1,59 @@ +# setup +setup: scan + +# Loss +criterion: scan +criterion_kwargs: + entropy_weight: 5.0 + +# Model +backbone: resnet50 + +# Weight update +update_cluster_head_only: True # Train only linear layer during SCAN +num_heads: 10 # Use multiple heads + +# Dataset +train_db_name: birds +val_db_name: birds +num_classes: 200 +num_neighbors: 50 + +# Transformations +augmentation_strategy: simclr +augmentation_kwargs: + random_resized_crop: + size: 224 + scale: [0.2, 1.0] + color_jitter_random_apply: + p: 0.8 + color_jitter: + brightness: 0.4 + contrast: 0.4 + saturation: 0.4 + hue: 0.1 + random_grayscale: + p: 0.2 + normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +transformation_kwargs: + crop_size: 224 + normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +# Hyperparameters +optimizer: sgd +optimizer_kwargs: + lr: 5.0 + weight_decay: 0.0000 + nesterov: False + momentum: 0.9 +epochs: 10 +batch_size: 32 +num_workers: 8 + +# Scheduler +scheduler: constant From c6a77cb9cb60fa0652f6c258df6da250998ca52b Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Wed, 17 Feb 2021 03:10:49 +0100 Subject: [PATCH 12/25] feat: add config file for self-labeling --- configs/selflabel/selflabel_birds.yml | 56 +++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) create mode 100644 configs/selflabel/selflabel_birds.yml diff --git a/configs/selflabel/selflabel_birds.yml b/configs/selflabel/selflabel_birds.yml new file mode 100644 index 00000000..54a6caa2 --- /dev/null +++ b/configs/selflabel/selflabel_birds.yml @@ -0,0 +1,56 @@ +# setup +setup: selflabel + +# Threshold +confidence_threshold: 0.99 + +# EMA +use_ema: True +ema_alpha: 0.999 + +# Loss +criterion: confidence-cross-entropy +criterion_kwargs: + apply_class_balancing: False + +# Model +backbone: resnet50 +num_heads: 1 + +# Dataset +train_db_name: birds +val_db_name: birds +num_classes: 200 + +# Transformations +augmentation_strategy: ours +augmentation_kwargs: + crop_size: 224 + 
normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + num_strong_augs: 4 + cutout_kwargs: + n_holes: 1 + length: 75 + random: True + +transformation_kwargs: + crop_size: 224 + normalize: + mean: [0.485, 0.456, 0.406] + std: [0.229, 0.224, 0.225] + +# Hyperparameters +optimizer: sgd +optimizer_kwargs: + lr: 0.03 + weight_decay: 0.0 + nesterov: False + momentum: 0.9 +epochs: 10 +batch_size: 32 +num_workers: 8 + +# Scheduler +scheduler: constant From 3d48c88ffd236145a0ef3e4a59ed77baa423ef70 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Thu, 18 Feb 2021 02:30:34 +0100 Subject: [PATCH 13/25] chore: choose hyperparameters --- configs/pretext/simclr_birds.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/configs/pretext/simclr_birds.yml b/configs/pretext/simclr_birds.yml index 966b367f..f1a34bdc 100644 --- a/configs/pretext/simclr_birds.yml +++ b/configs/pretext/simclr_birds.yml @@ -5,7 +5,7 @@ setup: simclr backbone: resnet50 model_kwargs: head: mlp - features_dim: 128 + features_dim: 2048 # Dataset train_db_name: birds @@ -15,20 +15,20 @@ num_classes: 200 # Loss criterion: simclr criterion_kwargs: - temperature: 0.1 + temperature: 0.07 # Hyperparameters -epochs: 10 +epochs: 100 optimizer: sgd optimizer_kwargs: nesterov: False weight_decay: 0.0001 momentum: 0.9 - lr: 0.4 + lr: 0.005625 scheduler: cosine scheduler_kwargs: lr_decay_rate: 0.1 -batch_size: 32 +batch_size: 48 num_workers: 8 # Transformations From 8c2f0e522fb9bd93887049218cf271ce4e014e03 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Tue, 2 Mar 2021 21:51:31 +0100 Subject: [PATCH 14/25] feat: add MNIST dataset --- configs/pretext/simclr_mnist.yml | 57 +++++++++++ data/mnist.py | 160 +++++++++++++++++++++++++++++++ models/resnet_mnist.py | 125 ++++++++++++++++++++++++ utils/common_config.py | 67 +++++++++---- utils/mypath.py | 5 +- 5 files changed, 397 insertions(+), 17 deletions(-) create mode 100644 configs/pretext/simclr_mnist.yml create mode 100644 data/mnist.py create mode 100644 models/resnet_mnist.py diff --git a/configs/pretext/simclr_mnist.yml b/configs/pretext/simclr_mnist.yml new file mode 100644 index 00000000..c33588eb --- /dev/null +++ b/configs/pretext/simclr_mnist.yml @@ -0,0 +1,57 @@ +# Setup +setup: simclr + +# Model +backbone: resnet18 +model_kwargs: + head: mlp + features_dim: 128 + +# Dataset +train_db_name: mnist +val_db_name: mnist +num_classes: 10 + +# Loss +criterion: simclr +criterion_kwargs: + temperature: 0.1 + +# Hyperparameters +epochs: 5 +optimizer: sgd +optimizer_kwargs: + nesterov: False + weight_decay: 0.0001 + momentum: 0.9 + lr: 0.4 +scheduler: cosine +scheduler_kwargs: + lr_decay_rate: 0.1 +batch_size: 48 +num_workers: 8 + +# Transformations +augmentation_strategy: simclr +augmentation_kwargs: + random_resized_crop: + size: 32 + scale: [0.2, 1.0] + color_jitter_random_apply: + p: 0.8 + color_jitter: + brightness: 0.4 + contrast: 0.4 + saturation: 0.4 + hue: 0.1 + random_grayscale: + p: 0.2 + normalize: + mean: [0.5] + std: [0.5] + +transformation_kwargs: + crop_size: 32 + normalize: + mean: [0.5] + std: [0.5] diff --git a/data/mnist.py b/data/mnist.py new file mode 100644 index 00000000..8796a6d0 --- /dev/null +++ b/data/mnist.py @@ -0,0 +1,160 @@ +from torch.utils.data import Dataset +import warnings +from PIL import Image +import os +import os.path +import numpy as np +import torch +import codecs +import string +import gzip +import lzma +from torchvision.datasets.utils import download_url, download_and_extract_archive, extract_archive, \ + 
verify_str_arg
+
+class MNIST(Dataset):
+    """`MNIST <http://yann.lecun.com/exdb/mnist/>`_ Dataset.
+
+    Args:
+        root (string): Root directory of dataset where ``MNIST/processed/training.pt``
+            and ``MNIST/processed/test.pt`` exist.
+        train (bool, optional): If True, creates dataset from ``training.pt``,
+            otherwise from ``test.pt``.
+        download (bool, optional): If true, downloads the dataset from the internet and
+            puts it in root directory. If dataset is already downloaded, it is not
+            downloaded again.
+        transform (callable, optional): A function/transform that takes in an PIL image
+            and returns a transformed version. E.g, ``transforms.RandomCrop``
+        target_transform (callable, optional): A function/transform that takes in the
+            target and transforms it.
+    """
+
+    resources = [
+        ("http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
+        ("http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
+        ("http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
+        ("http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c")
+    ]
+
+    training_file = 'training.pt'
+    test_file = 'test.pt'
+    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
+               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
+
+    @property
+    def train_labels(self):
+        warnings.warn("train_labels has been renamed targets")
+        return self.targets
+
+    @property
+    def test_labels(self):
+        warnings.warn("test_labels has been renamed targets")
+        return self.targets
+
+    @property
+    def train_data(self):
+        warnings.warn("train_data has been renamed data")
+        return self.data
+
+    @property
+    def test_data(self):
+        warnings.warn("test_data has been renamed data")
+        return self.data
+
+    def __init__(self, root=MyPath.db_root_dir('mnist'), train=True,
+                 transform=None, download=False):
+        super(MNIST, self).__init__()
+        self.root = root
+        self.transform = transform
+        self.train = train  # training set or test set
+
+        if download:
+            self.download()
+
+        if not self._check_exists():
+            raise RuntimeError('Dataset not found.' +
+                               ' You can use download=True to download it')
+
+        if self.train:
+            data_file = self.training_file
+        else:
+            data_file = self.test_file
+        self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file))
+
+    def __getitem__(self, index):
+        """
+        Args:
+            index (int): Index
+
+        Returns:
+            tuple: (image, target) where target is index of the target class.
+ """ + img, target = self.data[index], int(self.targets[index]) + + # doing this so that it is consistent with all other datasets + # to return a PIL Image + img = Image.fromarray(img.numpy(), mode='L') + img_size = img.size + + if self.transform is not None: + img = self.transform(img) + + out = {'image': img, 'target': target, 'meta': {'im_size': img_size, 'index': index}} + + return img, target + + def __len__(self): + return len(self.data) + + @property + def raw_folder(self): + return os.path.join(self.root, self.__class__.__name__, 'raw') + + @property + def processed_folder(self): + return os.path.join(self.root, self.__class__.__name__, 'processed') + + @property + def class_to_idx(self): + return {_class: i for i, _class in enumerate(self.classes)} + + def _check_exists(self): + return (os.path.exists(os.path.join(self.processed_folder, + self.training_file)) and + os.path.exists(os.path.join(self.processed_folder, + self.test_file))) + + def download(self): + """Download the MNIST data if it doesn't exist in processed_folder already.""" + + if self._check_exists(): + return + + os.makedirs(self.raw_folder, exist_ok=True) + os.makedirs(self.processed_folder, exist_ok=True) + + # download files + for url, md5 in self.resources: + filename = url.rpartition('/')[2] + download_and_extract_archive(url, download_root=self.raw_folder, filename=filename, md5=md5) + + # process and save as torch files + print('Processing...') + + training_set = ( + read_image_file(os.path.join(self.raw_folder, 'train-images-idx3-ubyte')), + read_label_file(os.path.join(self.raw_folder, 'train-labels-idx1-ubyte')) + ) + test_set = ( + read_image_file(os.path.join(self.raw_folder, 't10k-images-idx3-ubyte')), + read_label_file(os.path.join(self.raw_folder, 't10k-labels-idx1-ubyte')) + ) + with open(os.path.join(self.processed_folder, self.training_file), 'wb') as f: + torch.save(training_set, f) + with open(os.path.join(self.processed_folder, self.test_file), 'wb') as f: + torch.save(test_set, f) + + print('Done!') + + def extra_repr(self): + return "Split: {}".format("Train" if self.train is True else "Test") diff --git a/models/resnet_mnist.py b/models/resnet_mnist.py new file mode 100644 index 00000000..c215b847 --- /dev/null +++ b/models/resnet_mnist.py @@ -0,0 +1,125 @@ +""" +This code is based on the Torchvision repository, which was licensed under the BSD 3-Clause. 
+""" +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, in_planes, planes, stride=1, is_last=False): + super(BasicBlock, self).__init__() + self.is_last = is_last + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.bn2(self.conv2(out)) + out += self.shortcut(x) + preact = out + out = F.relu(out) + if self.is_last: + return out, preact + else: + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, in_planes, planes, stride=1, is_last=False): + super(Bottleneck, self).__init__() + self.is_last = is_last + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes) + self.conv3 = nn.Conv2d(planes, self.expansion * planes, kernel_size=1, bias=False) + self.bn3 = nn.BatchNorm2d(self.expansion * planes) + + self.shortcut = nn.Sequential() + if stride != 1 or in_planes != self.expansion * planes: + self.shortcut = nn.Sequential( + nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(self.expansion * planes) + ) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = F.relu(self.bn2(self.conv2(out))) + out = self.bn3(self.conv3(out)) + out += self.shortcut(x) + preact = out + out = F.relu(out) + if self.is_last: + return out, preact + else: + return out + + +class ResNet(nn.Module): + def __init__(self, block, num_blocks, in_channel=1, zero_init_residual=False): + super(ResNet, self).__init__() + self.in_planes = 64 + + self.conv1 = nn.Conv2d(in_channel, 64, kernel_size=3, stride=1, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64) + self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) + self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2) + self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) + self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves + # like an identity. 
This improves the model by 0.2~0.3% according to: + # https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, num_blocks, stride): + strides = [stride] + [1] * (num_blocks - 1) + layers = [] + for i in range(num_blocks): + stride = strides[i] + layers.append(block(self.in_planes, planes, stride)) + self.in_planes = planes * block.expansion + return nn.Sequential(*layers) + + def forward(self, x): + out = F.relu(self.bn1(self.conv1(x))) + out = self.layer1(out) + out = self.layer2(out) + out = self.layer3(out) + out = self.layer4(out) + out = self.avgpool(out) + out = torch.flatten(out, 1) + return out + + +def resnet18(**kwargs): + return {'backbone': ResNet(BasicBlock, [2, 2, 2, 2], **kwargs), 'dim': 512} diff --git a/utils/common_config.py b/utils/common_config.py index dbf76bf0..8e5cd5d6 100755 --- a/utils/common_config.py +++ b/utils/common_config.py @@ -52,6 +52,10 @@ def get_model(p, pretrain_path=None): from models.resnet_stl import resnet18 backbone = resnet18() + elif p['train_db_name'] == 'mnist': + from models.resnet_stl import resnet18 + backbone = resnet18() + else: raise NotImplementedError @@ -61,7 +65,7 @@ def get_model(p, pretrain_path=None): backbone = resnet50() # added birds with resnet50 - elif p['train_db_name'] == 'birds': + elif p['train_db_name'] in ['birds', 'mnist']: from models.resnet import resnet50 backbone = resnet50() @@ -137,6 +141,10 @@ def get_train_dataset(p, transform, to_augmented_dataset=False, from data.stl import STL10 dataset = STL10(split=split, transform=transform, download=True) + elif p['train_db_name'] == 'mnist': + from data.mnist import MNIST + dataset = MNIST(train=True, transform=transform, download=True) + elif p['train_db_name'] == 'imagenet': from data.imagenet import ImageNet dataset = ImageNet(split='train', transform=transform) @@ -177,6 +185,10 @@ def get_val_dataset(p, transform=None, to_neighbors_dataset=False): from data.cifar import CIFAR20 dataset = CIFAR20(train=False, transform=transform, download=True) + elif p['val_db_name'] == 'mnist': + from data.mnist import MNIST + dataset = MNIST(train=False, transform=transform, download=True) + elif p['val_db_name'] == 'stl-10': from data.stl import STL10 dataset = STL10(split='test', transform=transform, download=True) @@ -231,22 +243,37 @@ def get_train_transformations(p): elif p['augmentation_strategy'] == 'simclr': # Augmentation strategy from the SimCLR paper - return transforms.Compose([ - transforms.RandomResizedCrop(**p['augmentation_kwargs']['random_resized_crop']), - transforms.RandomHorizontalFlip(), - transforms.RandomApply([ - transforms.ColorJitter(**p['augmentation_kwargs']['color_jitter']) - ], p=p['augmentation_kwargs']['color_jitter_random_apply']['p']), - transforms.RandomGrayscale(**p['augmentation_kwargs']['random_grayscale']), - transforms.ToTensor(), - transforms.Normalize(**p['augmentation_kwargs']['normalize']) - ]) + + if p['train_db_name'] == 'mnist': + return transforms.Compose([ + transforms.Resize(48), + transforms.RandomResizedCrop(**p['augmentation_kwargs']['random_resized_crop']), + transforms.RandomHorizontalFlip(), + transforms.RandomApply([ + transforms.ColorJitter(**p['augmentation_kwargs']['color_jitter']) + ], p=p['augmentation_kwargs']['color_jitter_random_apply']['p']), + 
transforms.RandomGrayscale(**p['augmentation_kwargs']['random_grayscale']), + transforms.ToTensor(), + transforms.Normalize(**p['augmentation_kwargs']['normalize']) + ]) + + else: + return transforms.Compose([ + transforms.RandomResizedCrop(**p['augmentation_kwargs']['random_resized_crop']), + transforms.RandomHorizontalFlip(), + transforms.RandomApply([ + transforms.ColorJitter(**p['augmentation_kwargs']['color_jitter']) + ], p=p['augmentation_kwargs']['color_jitter_random_apply']['p']), + transforms.RandomGrayscale(**p['augmentation_kwargs']['random_grayscale']), + transforms.ToTensor(), + transforms.Normalize(**p['augmentation_kwargs']['normalize']) + ]) elif p['augmentation_strategy'] == 'ours': # Augmentation strategy from our paper return transforms.Compose([ transforms.RandomHorizontalFlip(), - transforms.RandomCrop(p['augmentation_kwargs']['crop_size']), + transforms.RandomCrop(p['augmentation_kwargs']['crop_size'], pad_if_needed=True), Augment(p['augmentation_kwargs']['num_strong_augs']), transforms.ToTensor(), transforms.Normalize(**p['augmentation_kwargs']['normalize']), @@ -260,10 +287,18 @@ def get_train_transformations(p): def get_val_transformations(p): - return transforms.Compose([ - transforms.CenterCrop(p['transformation_kwargs']['crop_size']), - transforms.ToTensor(), - transforms.Normalize(**p['transformation_kwargs']['normalize'])]) + if p['train_db_name'] == 'mnist': + return transforms.Compose([ + transforms.Resize(48), + transforms.CenterCrop(p['transformation_kwargs']['crop_size']), + transforms.ToTensor(), + transforms.Normalize(**p['transformation_kwargs']['normalize'])]) + + else: + return transforms.Compose([ + transforms.CenterCrop(p['transformation_kwargs']['crop_size']), + transforms.ToTensor(), + transforms.Normalize(**p['transformation_kwargs']['normalize'])]) def get_optimizer(p, model, cluster_head_only=False): diff --git a/utils/mypath.py b/utils/mypath.py index c5c6645c..2f45493f 100644 --- a/utils/mypath.py +++ b/utils/mypath.py @@ -8,7 +8,7 @@ class MyPath(object): @staticmethod def db_root_dir(database=''): - db_names = {'cifar-10', 'stl-10', 'cifar-20', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200', 'birds'} + db_names = {'cifar-10', 'stl-10', 'cifar-20', 'imagenet', 'imagenet_50', 'imagenet_100', 'imagenet_200', 'birds', 'mnist'} assert(database in db_names) if database == 'cifar-10': @@ -26,5 +26,8 @@ def db_root_dir(database=''): elif database == 'birds': return '/content/Unsupervised-Classification/data/CUB_200_2011/images/' + elif database == 'mnist': + return '/content/mnist/' + else: raise NotImplementedError From e6fecc925570a11abb45515533f11083f6ef4d0d Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Tue, 2 Mar 2021 21:59:59 +0100 Subject: [PATCH 15/25] fix: address indent error --- utils/common_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/common_config.py b/utils/common_config.py index 8e5cd5d6..0da56364 100755 --- a/utils/common_config.py +++ b/utils/common_config.py @@ -65,7 +65,7 @@ def get_model(p, pretrain_path=None): backbone = resnet50() # added birds with resnet50 - elif p['train_db_name'] in ['birds', 'mnist']: + elif p['train_db_name'] in ['birds', 'mnist']: from models.resnet import resnet50 backbone = resnet50() From f5ca5943eed361c41d44ee12864ff6acd21607a8 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Tue, 2 Mar 2021 22:04:53 +0100 Subject: [PATCH 16/25] fix: add MyPath to import --- data/mnist.py | 1 + 1 file changed, 1 insertion(+) diff --git a/data/mnist.py b/data/mnist.py 
index 8796a6d0..2643f9b4 100644 --- a/data/mnist.py +++ b/data/mnist.py @@ -6,6 +6,7 @@ import numpy as np import torch import codecs +from utils.mypath import MyPath import string import gzip import lzma From 31fc871f1d4e8366225813195cd55053cf242cb6 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Tue, 2 Mar 2021 22:12:20 +0100 Subject: [PATCH 17/25] fix: add missing code to mnist data class --- data/mnist.py | 61 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/data/mnist.py b/data/mnist.py index 2643f9b4..d1981ec9 100644 --- a/data/mnist.py +++ b/data/mnist.py @@ -159,3 +159,64 @@ def download(self): def extra_repr(self): return "Split: {}".format("Train" if self.train is True else "Test") + +def get_int(b): + return int(codecs.encode(b, 'hex'), 16) + +def open_maybe_compressed_file(path): + """Return a file object that possibly decompresses 'path' on the fly. + Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'. + """ + if not isinstance(path, torch._six.string_classes): + return path + if path.endswith('.gz'): + return gzip.open(path, 'rb') + if path.endswith('.xz'): + return lzma.open(path, 'rb') + return open(path, 'rb') + + +SN3_PASCALVINCENT_TYPEMAP = { + 8: (torch.uint8, np.uint8, np.uint8), + 9: (torch.int8, np.int8, np.int8), + 11: (torch.int16, np.dtype('>i2'), 'i2'), + 12: (torch.int32, np.dtype('>i4'), 'i4'), + 13: (torch.float32, np.dtype('>f4'), 'f4'), + 14: (torch.float64, np.dtype('>f8'), 'f8') +} + + +def read_sn3_pascalvincent_tensor(path, strict=True): + """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). + Argument may be a filename, compressed filename, or file object. + """ + # read + with open_maybe_compressed_file(path) as f: + data = f.read() + # parse + magic = get_int(data[0:4]) + nd = magic % 256 + ty = magic // 256 + assert nd >= 1 and nd <= 3 + assert ty >= 8 and ty <= 14 + m = SN3_PASCALVINCENT_TYPEMAP[ty] + s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] + parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) + assert parsed.shape[0] == np.prod(s) or not strict + return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) + + +def read_label_file(path): + with open(path, 'rb') as f: + x = read_sn3_pascalvincent_tensor(f, strict=False) + assert(x.dtype == torch.uint8) + assert(x.ndimension() == 1) + return x.long() + + +def read_image_file(path): + with open(path, 'rb') as f: + x = read_sn3_pascalvincent_tensor(f, strict=False) + assert(x.dtype == torch.uint8) + assert(x.ndimension() == 3) + return x From 79325446d429a46128b66bc003c0268452d81379 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Tue, 2 Mar 2021 22:20:18 +0100 Subject: [PATCH 18/25] fix: fix __getitem__ in mnist data class --- data/mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/mnist.py b/data/mnist.py index d1981ec9..096820be 100644 --- a/data/mnist.py +++ b/data/mnist.py @@ -102,7 +102,7 @@ def __getitem__(self, index): out = {'image': img, 'target': target, 'meta': {'im_size': img_size, 'index': index}} - return img, target + return out def __len__(self): return len(self.data) From bc4dc61706fcd3e69a7e1eaa263b84a38030b89f Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Tue, 2 Mar 2021 22:27:15 +0100 Subject: [PATCH 19/25] fix: selected correct resnet backbone --- utils/common_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/common_config.py b/utils/common_config.py index 0da56364..3958eb78 
100755 --- a/utils/common_config.py +++ b/utils/common_config.py @@ -53,7 +53,7 @@ def get_model(p, pretrain_path=None): backbone = resnet18() elif p['train_db_name'] == 'mnist': - from models.resnet_stl import resnet18 + from models.resnet_mnist import resnet18 backbone = resnet18() else: From eff5f727b816f7ad77a8a3cc7a5e9cfa915b7115 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Tue, 2 Mar 2021 23:58:54 +0100 Subject: [PATCH 20/25] feat: add scan, selflabel configs for mnist --- configs/scan/scan_mnist.yml | 51 ++++++++++++++++++++++++++ configs/selflabel/selflabel_mnist.yml | 53 +++++++++++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 configs/scan/scan_mnist.yml create mode 100644 configs/selflabel/selflabel_mnist.yml diff --git a/configs/scan/scan_mnist.yml b/configs/scan/scan_mnist.yml new file mode 100644 index 00000000..dc6bae22 --- /dev/null +++ b/configs/scan/scan_mnist.yml @@ -0,0 +1,51 @@ +# setup +setup: scan + +# Loss +criterion: scan +criterion_kwargs: + entropy_weight: 5.0 + +# Weight update +update_cluster_head_only: False # Update full network in SCAN +num_heads: 1 # Only use one head + +# Model +backbone: resnet18 + +# Dataset +train_db_name: mnist +val_db_name: mnist +num_classes: 10 +num_neighbors: 20 + +# Transformations +augmentation_strategy: ours +augmentation_kwargs: + crop_size: 32 + normalize: + mean: [0.5] + std: [0.5] + num_strong_augs: 4 + cutout_kwargs: + n_holes: 1 + length: 16 + random: True + +transformation_kwargs: + crop_size: 32 + normalize: + mean: [0.5] + std: [0.5] + +# Hyperparameters +optimizer: adam +optimizer_kwargs: + lr: 0.0001 + weight_decay: 0.0001 +epochs: 50 +batch_size: 48 +num_workers: 8 + +# Scheduler +scheduler: constant diff --git a/configs/selflabel/selflabel_mnist.yml b/configs/selflabel/selflabel_mnist.yml new file mode 100644 index 00000000..c5bf4e23 --- /dev/null +++ b/configs/selflabel/selflabel_mnist.yml @@ -0,0 +1,53 @@ +# setup +setup: selflabel + +# ema +use_ema: False + +# Threshold +confidence_threshold: 0.99 + +# Criterion +criterion: confidence-cross-entropy +criterion_kwargs: + apply_class_balancing: True + +# Model +backbone: resnet18 +num_heads: 1 + +# Dataset +train_db_name: mnist +val_db_name: mnist +num_classes: 10 + +# Transformations +augmentation_strategy: ours +augmentation_kwargs: + crop_size: 32 + normalize: + mean: [0.5] + std: [0.5] + num_strong_augs: 4 + cutout_kwargs: + n_holes: 1 + length: 16 + random: True + +transformation_kwargs: + crop_size: 32 + normalize: + mean: [0.5] + std: [0.5] + +# Hyperparameters +epochs: 200 +optimizer: adam +optimizer_kwargs: + lr: 0.0001 + weight_decay: 0.0001 +batch_size: 48 +num_workers: 8 + +# Scheduler +scheduler: constant From 3ac48e30a54c425aca300c5006ba77fb8f07ea47 Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Wed, 3 Mar 2021 18:01:07 +0100 Subject: [PATCH 21/25] fix: add `get_image` and reduce # epochs --- configs/scan/scan_mnist.yml | 2 +- configs/selflabel/selflabel_mnist.yml | 2 +- data/mnist.py | 5 +++++ 3 files changed, 7 insertions(+), 2 deletions(-) diff --git a/configs/scan/scan_mnist.yml b/configs/scan/scan_mnist.yml index dc6bae22..78bde7fc 100644 --- a/configs/scan/scan_mnist.yml +++ b/configs/scan/scan_mnist.yml @@ -43,7 +43,7 @@ optimizer: adam optimizer_kwargs: lr: 0.0001 weight_decay: 0.0001 -epochs: 50 +epochs: 5 batch_size: 48 num_workers: 8 diff --git a/configs/selflabel/selflabel_mnist.yml b/configs/selflabel/selflabel_mnist.yml index c5bf4e23..7b3c7967 100644 --- a/configs/selflabel/selflabel_mnist.yml +++ 
b/configs/selflabel/selflabel_mnist.yml @@ -41,7 +41,7 @@ transformation_kwargs: std: [0.5] # Hyperparameters -epochs: 200 +epochs: 5 optimizer: adam optimizer_kwargs: lr: 0.0001 diff --git a/data/mnist.py b/data/mnist.py index 096820be..71c0de2a 100644 --- a/data/mnist.py +++ b/data/mnist.py @@ -104,6 +104,11 @@ def __getitem__(self, index): return out + # Note: For eval.py. From /data/cifar.py + def get_image(self, index): + img = self.data[index] + return img + def __len__(self): return len(self.data) From cf1c5e9daff2d8bc47b0a486399e098a27c2867f Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Thu, 4 Mar 2021 19:55:38 +0100 Subject: [PATCH 22/25] feat: use VisionDataset for MNIST class --- data/mnist.py | 101 +++++++++++++------------------------------------- 1 file changed, 25 insertions(+), 76 deletions(-) diff --git a/data/mnist.py b/data/mnist.py index 71c0de2a..b99debec 100644 --- a/data/mnist.py +++ b/data/mnist.py @@ -1,4 +1,4 @@ -from torch.utils.data import Dataset +from .vision import VisionDataset import warnings from PIL import Image import os @@ -6,14 +6,15 @@ import numpy as np import torch import codecs -from utils.mypath import MyPath import string import gzip import lzma -from torchvision.datasets.utils import download_url, download_and_extract_archive, extract_archive, \ +from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union +from .utils import download_url, download_and_extract_archive, extract_archive, \ verify_str_arg +from utils.mypath import MyPath -class MNIST(Dataset): +[docs]class MNIST(VisionDataset): """`MNIST `_ Dataset. Args: @@ -62,13 +63,20 @@ def test_data(self): warnings.warn("test_data has been renamed data") return self.data - def __init__(self, root=MyPath.db_root_dir('mnist'), train=True, - transform=None, download=False): + def __init__( + self, + root: str = MyPath.db_root_dir('mnist'), + train: bool = True, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + download: bool = False, + ) -> None: super(MNIST, self).__init__() self.root = root self.transform = transform self.train = train # training set or test set + if download: self.download() @@ -82,7 +90,7 @@ def __init__(self, root=MyPath.db_root_dir('mnist'), train=True, data_file = self.test_file self.data, self.targets = torch.load(os.path.join(self.processed_folder, data_file)) - def __getitem__(self, index): + def __getitem__(self, index: int) -> Tuple[Any, Any]: """ Args: index (int): Index @@ -95,11 +103,13 @@ def __getitem__(self, index): # doing this so that it is consistent with all other datasets # to return a PIL Image img = Image.fromarray(img.numpy(), mode='L') - img_size = img.size if self.transform is not None: img = self.transform(img) + if self.target_transform is not None: + target = self.target_transform(target) + out = {'image': img, 'target': target, 'meta': {'im_size': img_size, 'index': index}} return out @@ -109,28 +119,28 @@ def get_image(self, index): img = self.data[index] return img - def __len__(self): + def __len__(self) -> int: return len(self.data) @property - def raw_folder(self): + def raw_folder(self) -> str: return os.path.join(self.root, self.__class__.__name__, 'raw') @property - def processed_folder(self): + def processed_folder(self) -> str: return os.path.join(self.root, self.__class__.__name__, 'processed') @property - def class_to_idx(self): + def class_to_idx(self) -> Dict[str, int]: return {_class: i for i, _class in enumerate(self.classes)} - def _check_exists(self): + def _check_exists(self) -> 
bool: return (os.path.exists(os.path.join(self.processed_folder, self.training_file)) and os.path.exists(os.path.join(self.processed_folder, self.test_file))) - def download(self): + def download(self) -> None: """Download the MNIST data if it doesn't exist in processed_folder already.""" if self._check_exists(): @@ -162,66 +172,5 @@ def download(self): print('Done!') - def extra_repr(self): + def extra_repr(self) -> str: return "Split: {}".format("Train" if self.train is True else "Test") - -def get_int(b): - return int(codecs.encode(b, 'hex'), 16) - -def open_maybe_compressed_file(path): - """Return a file object that possibly decompresses 'path' on the fly. - Decompression occurs when argument `path` is a string and ends with '.gz' or '.xz'. - """ - if not isinstance(path, torch._six.string_classes): - return path - if path.endswith('.gz'): - return gzip.open(path, 'rb') - if path.endswith('.xz'): - return lzma.open(path, 'rb') - return open(path, 'rb') - - -SN3_PASCALVINCENT_TYPEMAP = { - 8: (torch.uint8, np.uint8, np.uint8), - 9: (torch.int8, np.int8, np.int8), - 11: (torch.int16, np.dtype('>i2'), 'i2'), - 12: (torch.int32, np.dtype('>i4'), 'i4'), - 13: (torch.float32, np.dtype('>f4'), 'f4'), - 14: (torch.float64, np.dtype('>f8'), 'f8') -} - - -def read_sn3_pascalvincent_tensor(path, strict=True): - """Read a SN3 file in "Pascal Vincent" format (Lush file 'libidx/idx-io.lsh'). - Argument may be a filename, compressed filename, or file object. - """ - # read - with open_maybe_compressed_file(path) as f: - data = f.read() - # parse - magic = get_int(data[0:4]) - nd = magic % 256 - ty = magic // 256 - assert nd >= 1 and nd <= 3 - assert ty >= 8 and ty <= 14 - m = SN3_PASCALVINCENT_TYPEMAP[ty] - s = [get_int(data[4 * (i + 1): 4 * (i + 2)]) for i in range(nd)] - parsed = np.frombuffer(data, dtype=m[1], offset=(4 * (nd + 1))) - assert parsed.shape[0] == np.prod(s) or not strict - return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s) - - -def read_label_file(path): - with open(path, 'rb') as f: - x = read_sn3_pascalvincent_tensor(f, strict=False) - assert(x.dtype == torch.uint8) - assert(x.ndimension() == 1) - return x.long() - - -def read_image_file(path): - with open(path, 'rb') as f: - x = read_sn3_pascalvincent_tensor(f, strict=False) - assert(x.dtype == torch.uint8) - assert(x.ndimension() == 3) - return x From d1681e12e2549ca55251ff0cbe156cd3012228ea Mon Sep 17 00:00:00 2001 From: Bob Bell Date: Thu, 4 Mar 2021 20:18:47 +0100 Subject: [PATCH 23/25] fix: remove [docs] from mnist dataset class --- data/mnist.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/data/mnist.py b/data/mnist.py index b99debec..8d1d79db 100644 --- a/data/mnist.py +++ b/data/mnist.py @@ -14,7 +14,7 @@ verify_str_arg from utils.mypath import MyPath -[docs]class MNIST(VisionDataset): +class MNIST(VisionDataset): """`MNIST `_ Dataset. 
     Args:

From 5f27c82ab552425dcd45236bb08904bfc87a8363 Mon Sep 17 00:00:00 2001
From: Bob Bell
Date: Thu, 4 Mar 2021 20:30:58 +0100
Subject: [PATCH 24/25] fix: torchvision.datasets.vision

---
 data/mnist.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/data/mnist.py b/data/mnist.py
index 8d1d79db..7eddfcb4 100644
--- a/data/mnist.py
+++ b/data/mnist.py
@@ -1,4 +1,4 @@
-from .vision import VisionDataset
+from torchvision.datasets.vision import VisionDataset
 import warnings
 from PIL import Image
 import os
@@ -10,7 +10,7 @@
 import gzip
 import lzma
 from typing import Any, Callable, Dict, IO, List, Optional, Tuple, Union
-from .utils import download_url, download_and_extract_archive, extract_archive, \
+from torchvision.datasets.utils import download_url, download_and_extract_archive, extract_archive, \
     verify_str_arg
 from utils.mypath import MyPath
 

From b70e3a916a3a55e97c64bbc09243e4164c629191 Mon Sep 17 00:00:00 2001
From: Bob Bell
Date: Thu, 4 Mar 2021 20:37:03 +0100
Subject: [PATCH 25/25] fix: __init__ params

---
 data/mnist.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/data/mnist.py b/data/mnist.py
index 7eddfcb4..fd2acabd 100644
--- a/data/mnist.py
+++ b/data/mnist.py
@@ -71,12 +71,10 @@ def __init__(
             target_transform: Optional[Callable] = None,
             download: bool = False,
     ) -> None:
-        super(MNIST, self).__init__()
-        self.root = root
-        self.transform = transform
+        super(MNIST, self).__init__(root, transform=transform,
+                                    target_transform=target_transform)
         self.train = train  # training set or test set
-
         if download:
            self.download()
 
         if not self._check_exists():
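
As a closing sanity check, here is a minimal sketch of how the pieces added in this series fit together. It is not part of the patches, and two things are assumptions rather than facts from the diffs above: that the upstream SCAN helper utils/config.py provides create_config(env_file, exp_file) unchanged, and that organize_train_test.py has already split CUB_200_2011/images into train/ and test/.

    # Sketch only: wires up the birds SimCLR config added in this series.
    # Assumption: utils.config.create_config(env_file, exp_file) exists upstream.
    from utils.config import create_config
    from utils.common_config import get_model, get_train_dataset, get_train_transformations

    p = create_config('configs/env.yml', 'configs/pretext/simclr_birds.yml')

    train_tf = get_train_transformations(p)    # SimCLR augmentations from the YAML
    dataset = get_train_dataset(p, train_tf)   # dispatches to Birds(split='train', ...)
    sample = dataset[0]                        # dict with 'image', 'target', 'meta'
    print(len(dataset), sample['image'].shape, sample['target'])

    model = get_model(p)                       # resnet50 backbone + contrastive MLP head

The MNIST path is analogous: simclr_mnist.yml swaps in the resnet_mnist resnet18 backbone (patch 19) and get_train_transformations/get_val_transformations prepend a Resize(48) step for that dataset (patch 14).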