From 898ab026765d07e4cb9f4e5d6db6a7fafc715af6 Mon Sep 17 00:00:00 2001
From: SlowMonk <33189954+SlowMonk@users.noreply.github.com>
Date: Mon, 21 Nov 2022 11:30:03 +0900
Subject: [PATCH] pull request (#7)

* test
* pretrained=false
* main_supcon.py add wandb
* question loader
* utils
* fscore
* calculator.py
* dataset_statistics.py
* mains
* gitignore
* script
* encoder_models/
* ignore
* Delete main_f1score.py
* change
* Delete train_supcon_50p_group4.sh
* gitignore
* calculator: remove comment
* compare in a better way
* file_name not used
* f1_graph removed
* question loader: remove try
* inference group4
* main_supcon
* resolve epsilon
* remove continuing try/except block
* remove print statement
* remove try/except around removal of damaged jpg/jpeg files
* include both jpeg and jpg
* remove resnet from main_supcon
* spaces after each comma in inference_g4.py
* delete unused comment
* fix multiple-line violation
* remove unused import
* remove unused try/except in question_loader
---
 .gitignore                          |  16 +-
 calculator.py                       | 118 +++++++++++
 dataset_statistics.py               |  46 ++++
 encoder_models/__init__.py          |   0
 encoder_models/base_model.py        | 122 +++++++++++
 encoder_models/cluter_base_model.py | 114 ++++++++++
 encoder_models/rank_base_model.py   | 167 +++++++++++++++
 encoder_models/resnet.py            |  30 +++
 encoder_models/vit.py               |  48 +++++
 fscore.py                           |  34 +++
 inference_g4.py                     |  91 ++++++++
 main.py                             | 118 +++++++++++
 main_supcon.py                      |  83 ++++++--
 networks/vit.py                     |   1 +
 question_loader.py                  |  17 +-
 quiz_master.py                      | 311 ++++++++++++++++++++++++++++
 util.py                             |   2 +-
 17 files changed, 1290 insertions(+), 28 deletions(-)
 create mode 100644 calculator.py
 create mode 100644 dataset_statistics.py
 create mode 100644 encoder_models/__init__.py
 create mode 100644 encoder_models/base_model.py
 create mode 100644 encoder_models/cluter_base_model.py
 create mode 100644 encoder_models/rank_base_model.py
 create mode 100644 encoder_models/resnet.py
 create mode 100644 encoder_models/vit.py
 create mode 100644 fscore.py
 create mode 100644 inference_g4.py
 create mode 100644 main.py
 create mode 100644 quiz_master.py

diff --git a/.gitignore b/.gitignore
index 6305b46c..4b501f62 100644
--- a/.gitignore
+++ b/.gitignore
@@ -130,4 +130,18 @@ dmypy.json
 # Pyre type checker
 .pyre/
 
-save/
\ No newline at end of file
+save/
+wandb/
+#encoder_models/
+score_utils/
+
+# files
+./notebooks/
+./sh_files/
+./make_resize.py
+notebooks/
+sh_files/
+utils/
+main_f1score.py
+train_supcon50_group4.sh
+inference_group4.sh
\ No newline at end of file
diff --git a/calculator.py b/calculator.py
new file mode 100644
index 00000000..ce6e337f
--- /dev/null
+++ b/calculator.py
@@ -0,0 +1,118 @@
+
+import typing
+
+
+def increment(*args) -> typing.List[int]:
+    """Increment all given numbers."""
+    return [arg + 1 for arg in args]
+
+
+def print_results(TTP: int, TFP: int, TFN: int, result_dict: typing.Dict) -> typing.Tuple[float, float]:
+    """Print per-category and aggregate F-scores; return (macro, micro)."""
+    fscores = []
+    for k, v in result_dict.items():
+        TP = v['TP']
+        FP = v['FP']
+        FN = v['FN']
+        try:
+            precision = TP / (TP + FP)
+            recall = TP / (TP + FN)
+            f = 2 * ((precision * recall) / (precision + recall))
+        except ZeroDivisionError:
+            f = 0
+        print(f'{k}: {f}')
+        fscores.append(f)
+    print(f"Macro f score: {sum(fscores) / len(fscores)}")
+
+    precision = TTP / (TTP + TFP)
+    recall = TTP / (TTP + TFN)
+
+    f = 2 * ((precision * recall) / (precision + recall))
+    print(f'Micro f score {f}')
+    return sum(fscores) / len(fscores), f
+
+
+def calculate1(answer_file: str) -> typing.Tuple[float, float]:
+    """Calculate F-score for problem 1."""
+    TTP, TFP, TFN = 0, 0, 0  # Total true positives, false positives, false negatives
+    result_dict = {}
+    header_skipped = False
+    with open(answer_file) as fh:
+        for line in fh:
+            if not header_skipped:
+                header_skipped = True
+                continue
+            _, cat, gt, ans = line.rstrip('\n').split(',')
+            gt, ans = int(gt), int(ans)
+            if cat not in result_dict:
+                result_dict[cat] = dict(TP=0, FN=0, FP=0)
+            if ans == gt:
+                result_dict[cat]['TP'], TTP = increment(result_dict[cat]['TP'], TTP)
+            else:
+                result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN = \
+                    increment(result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN)
+    return print_results(TTP, TFP, TFN, result_dict)
+
+
+def calculate3(answer_file: str) -> typing.Tuple[float, float]:
+    """Calculate F-score for problem 3."""
+    TTP, TFP, TFN = 0, 0, 0  # Total true positives, false positives, false negatives
+    result_dict = {}
+    header_skipped = False
+    with open(answer_file) as fh:
+        for line in fh:
+            if not header_skipped:
+                header_skipped = True
+                continue
+            _, cat, qgt, agt, qans, ans = line.rstrip('\n').split(',')
+            if cat not in result_dict:
+                result_dict[cat] = dict(TP=0, FN=0, FP=0)
+            if ans == agt:
+                result_dict[cat]['TP'], TTP = increment(result_dict[cat]['TP'], TTP)
+            else:
+                result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN = \
+                    increment(result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN)
+    return print_results(TTP, TFP, TFN, result_dict)
+
+
+def calculate4(answer_file: str) -> typing.Tuple[float, float]:
+    """Calculate F-score for problem 4 (each answer scored separately)."""
+    TTP, TFP, TFN = 0, 0, 0  # Total true positives, false positives, false negatives
+    result_dict = {}
+    header_skipped = False
+    with open(answer_file) as fh:
+        for line in fh:
+            if not header_skipped:
+                header_skipped = True
+                continue
+            _, cat, gt1, gt2, ans1, ans2 = line.rstrip('\n').split(',')
+            if cat not in result_dict:
+                result_dict[cat] = dict(TP=0, FN=0, FP=0)
+            for ans in (ans1, ans2):
+                if ans in [gt1, gt2]:
+                    result_dict[cat]['TP'], TTP = increment(result_dict[cat]['TP'], TTP)
+                else:
+                    result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN = \
+                        increment(result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN)
+    return print_results(TTP, TFP, TFN, result_dict)
+
+
+def calculate4_both(answer_file: str) -> typing.Tuple[float, float]:
+    """Calculate F-score for problem 4 (both answers must match)."""
+    TTP, TFP, TFN = 0, 0, 0  # Total true positives, false positives, false negatives
+    result_dict = {}
+    header_skipped = False
+    with open(answer_file) as fh:
+        for line in fh:
+            if not header_skipped:
+                header_skipped = True
+                continue
+            _, cat, gt1, gt2, ans1, ans2 = line.rstrip('\n').split(',')
+            if cat not in result_dict:
+                result_dict[cat] = dict(TP=0, FN=0, FP=0)
+            if sorted([ans1, ans2]) == sorted([gt1, gt2]):
+                result_dict[cat]['TP'], TTP = increment(result_dict[cat]['TP'], TTP)
+            else:
+                result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN = \
+                    increment(result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN)
+    return print_results(TTP, TFP, TFN, result_dict)
\ No newline at end of file
diff --git a/dataset_statistics.py b/dataset_statistics.py
new file mode 100644
index 00000000..e3a20174
--- /dev/null
+++ b/dataset_statistics.py
@@ -0,0 +1,46 @@
+import argparse
+import torch
+import torchvision.transforms as transforms
+import torchvision.datasets as datasets
+from tqdm import tqdm
+from PIL import ImageFile
+ImageFile.LOAD_TRUNCATED_IMAGES = True
+
+def main(image_dir: str):
+
+ """Get mean and std.""" + dataloader = torch.utils.data.DataLoader( + datasets.ImageFolder(root=image_dir, transform=transforms.Compose([ + transforms.Resize((224, 224)), + transforms.ToTensor(), + ])), batch_size=4 + ) + channels_sum, channels_squared_sum, num_batches = 0, 0, 0 + count = 0 + for data, _ in tqdm(dataloader): + + # Mean over batch, height and width, but not over the channels + channels_sum += torch.mean(data, dim=[0, 2, 3]) + channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3]) + num_batches += 1 + count +=1 + if count%100==0: + mean = channels_sum / num_batches + std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5 + print(f'count->{count}, mean->{mean}, std->{std}') + mean = channels_sum / num_batches + + # std = sqrt(E[X^2] - (E[X])^2) + std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5 + print(f"Mean: {mean}") + print(f"Std: {std}") + return mean, std + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--image_dir', type=str) + args = parser.parse_args() + main( + image_dir=args.image_dir, + ) diff --git a/encoder_models/__init__.py b/encoder_models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/encoder_models/base_model.py b/encoder_models/base_model.py new file mode 100644 index 00000000..9d44b514 --- /dev/null +++ b/encoder_models/base_model.py @@ -0,0 +1,122 @@ +import os.path +import shutil +import typing +import numpy as np +import faiss +import torch +import glob +from cached_property import cached_property +from PIL import Image +import time + + +class BaseModel: + + NAME = "Please override in subclass" + CACHE_DIR = 'cache' + + @cached_property + def transform(self): + raise NotImplementedError("Data preprocessing must be implemented in the subclass.") + + def encode(self, preprocessed_image) -> typing.List[float]: + raise NotImplementedError("Encoding must be implemented in the subclass") + + @property + def cache_dir(self): + return os.path.join(self.question_dir, self.CACHE_DIR) + + def __init__(self, weights_path: str = ""): + self.weights_path = weights_path + self.vector_cache = set() + + def _preprocess_image(self, image_path: str): + """Applies transform to image.""" + image_tensor = Image.open(image_path).convert('RGB') + return self.transform(image_tensor) + + def encode_images(self, batch_size: int = 128): + """Applies transform to image.""" + counter = 0 + if os.path.exists(self.cache_dir): + shutil.rmtree(self.cache_dir) + os.mkdir(self.cache_dir) + image_vectors = [] + image_names = [] + type = f'{self.question_dir}/*/*.jp*' + num_images = len(glob.glob(type)) + print(f"{self.question_dir}/*/*.jpg", num_images) + + for image_path in glob.glob(type): + file_name = os.path.basename(image_path) + counter += 1 + if file_name in self.vector_cache: + continue + image_vectors.append(self._preprocess_image(image_path)) + image_names.append(file_name) + self.vector_cache.add(file_name) + if len(image_vectors) == batch_size: + preprocessed_images = self.encode(torch.stack(image_vectors)) + for i in range(batch_size): + np.save(os.path.join(self.cache_dir, f"{image_names[i]}.npy"), preprocessed_images[i,:]) + image_vectors, image_names = [], [] + print(f"Preprocessed {counter}/{num_images} images.") + for i, image_name in enumerate(image_names): + np.save(os.path.join(self.cache_dir, f"{image_name}.npy"), preprocessed_images[i,:]) + + def _preprocess_image_list(self, + image_path_list: typing.List[str], + expected_length: int) -> typing.List[np.array]: + """Feature 
extraction for image list.""" + image_vectors = [] + for query_image in image_path_list: + file_name = os.path.basename(query_image) + if file_name in self.vector_cache: + query_vector = np.load(os.path.join(self.cache_dir, f"{file_name}.npy")) + else: + preprocessed_image = self._preprocess_image(query_image) + query_vector = np.squeeze(self.encode(preprocessed_image)) + image_vectors.append(query_vector) + assert len(image_vectors) == expected_length + return image_vectors + + @classmethod + def get_k_nearest_neighbors(cls, query_vectors, answer_vectors, k) -> typing.Tuple[np.array, np.array]: + """Gets numpy arrays representing distance and number of k nearest answer vectors to query vectors""" + start = time.time() + index = faiss.IndexFlatL2(query_vectors.shape[-1]) # build the index + index.add(np.stack(answer_vectors)) + distance, indices = index.search(query_vectors, k) + print(f"Neighbour search took {time.time() - start}") + return distance, indices + + def question1(self, + query_image_paths: typing.List[str], + answer_image_paths: typing.List[str], + ): + raise NotImplementedError("Implement in subclass.") + + def question2(self, + query_image_paths: typing.List[str], + answer_image_paths: typing.List[str], + ): + raise NotImplementedError("Implement in subclass.") + + def question3(self, + query_image_paths1: typing.List[str], + query_image_paths2: typing.List[str], + answer_image_paths: typing.List[str], + ): + raise NotImplementedError("Implement in subclass.") + + def question4(self, + query_image_paths: typing.List[str], + answer_image_paths: typing.List[str], + ): + raise NotImplementedError("Implement in subclass.") + + def group2(self, + query_image_paths: typing.List[str], + answer_image_paths: typing.List[str], + ): + raise NotImplementedError("Implement in subclass.") diff --git a/encoder_models/cluter_base_model.py b/encoder_models/cluter_base_model.py new file mode 100644 index 00000000..2fe45280 --- /dev/null +++ b/encoder_models/cluter_base_model.py @@ -0,0 +1,114 @@ +import typing +import numpy as np +from .base_model import BaseModel + + +class ClusterBaseModel(BaseModel): + + @classmethod + def get_centroid(cls, encoded_vector): + """Get centroid.""" + return np.expand_dims(np.sum(encoded_vector, axis=0) / encoded_vector.shape[0], axis=0) + + def question1(self, + query_image_paths: typing.List[str], + answer_image_paths: typing.List[str], + ) -> int: + """Get response for question 1.""" + query_vectors = np.stack(self._preprocess_image_list(query_image_paths, 1)) + answer_vectors1 = np.stack(self._preprocess_image_list(answer_image_paths[:3], 3)) + answer_vectors2 = np.stack(self._preprocess_image_list(answer_image_paths[3:], 3)) + centroid1 = self.get_centroid(answer_vectors1) + centroid2 = self.get_centroid(answer_vectors2) + # Find which centroid the query image is closest to. 
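+        # (IndexFlatL2 returns squared L2 distances in ascending order, so
+        #  I[0][0] below is the 0-based index of the nearer centroid.)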
+        D, I = self.get_k_nearest_neighbors(query_vectors, np.concatenate((centroid1, centroid2)), 2)
+        print(f"Centroid of group {I[0][0] + 1} (distance {D[0][0]}) is closer to the query "
+              f"than the centroid of group {I[0][1] + 1} (distance {D[0][1]})")
+        return I[0][0] + 1
+
+    def question2(self,
+                  query_image_paths: typing.List[str],
+                  answer_image_paths: typing.List[str],
+                  ) -> bool:
+        """Get response for question 2."""
+        query_vectors = np.stack(self._preprocess_image_list(query_image_paths, 3))
+        answer_vectors = np.stack(self._preprocess_image_list(answer_image_paths, 1))
+        total_vectors = np.concatenate((query_vectors, answer_vectors), axis=0)
+        centroid = self.get_centroid(total_vectors)
+        D, I = self.get_k_nearest_neighbors(centroid, total_vectors, 4)
+        for i, d in zip(I[0], D[0]):
+            if i != 3:
+                print(f"Distance to centroid for query image {i + 1} is {d}")
+            else:
+                print(f"Distance to centroid for answer image is {d}")
+        # Return True if the answer image is not the furthest image from the centroid.
+        is_similar = I[0][-1] != 3
+        print(
+            "Answer is similar!" if is_similar else "Answer is different"
+        )
+        return is_similar
+
+    def question3(self,
+                  query_image_paths1: typing.List[str],
+                  query_image_paths2: typing.List[str],
+                  answer_image_paths: typing.List[str]
+                  ) -> typing.Tuple[int, int]:
+        """Get response for question 3."""
+        query1_vectors = np.stack(self._preprocess_image_list(query_image_paths1, 3))
+        query2_vectors = np.stack(self._preprocess_image_list(query_image_paths2, 3))
+
+        centroid1 = self.get_centroid(query1_vectors)
+        centroid2 = self.get_centroid(query2_vectors)
+
+        D1, _ = self.get_k_nearest_neighbors(centroid1, query1_vectors, 3)
+        D2, _ = self.get_k_nearest_neighbors(centroid2, query2_vectors, 3)
+        # Choose the group of images that is tighter in Euclidean space as the positive image group.
+        print(f"Total Euclidean distance from group 1 query images to group 1 centroid is {sum(D1[0])}")
+        print(f"Total Euclidean distance from group 2 query images to group 2 centroid is {sum(D2[0])}")
+        centroid, query_answer = ((centroid1, 0) if sum(D1[0]) < sum(D2[0]) else (centroid2, 1))
+        print(f"Group {query_answer + 1} is more similar.")
+
+        answer_vectors1 = np.stack(self._preprocess_image_list(answer_image_paths[:3], 3))
+        answer_vectors2 = np.stack(self._preprocess_image_list(answer_image_paths[3:6], 3))
+        answer_vectors3 = np.stack(self._preprocess_image_list(answer_image_paths[6:], 3))
+        # Calculate the distance of all answer image groups to the chosen query centroid.
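+        # (Each sum(Di[0]) below is that group's total squared distance to the
+        #  centroid; the group with the smallest total is chosen.)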
+        D1, _ = self.get_k_nearest_neighbors(centroid, answer_vectors1, 3)
+        D2, _ = self.get_k_nearest_neighbors(centroid, answer_vectors2, 3)
+        D3, _ = self.get_k_nearest_neighbors(centroid, answer_vectors3, 3)
+        distance_to_centroid = [sum(D1[0]), sum(D2[0]), sum(D3[0])]
+        # Choose the group with the smallest total distance to the chosen centroid.
+        print(f"Total Euclidean distance from group 1 answer images to centroid is {sum(D1[0])}")
+        print(f"Total Euclidean distance from group 2 answer images to centroid is {sum(D2[0])}")
+        print(f"Total Euclidean distance from group 3 answer images to centroid is {sum(D3[0])}")
+        chosen_group = (distance_to_centroid.index(min(distance_to_centroid)) + 1)
+        print(f"Closest group to centroid is {chosen_group}")
+        return query_answer, chosen_group
+
+    def question4(self,
+                  query_image_paths: typing.List[str],
+                  answer_image_paths: typing.List[str],
+                  ) -> typing.Tuple[int, int]:
+        """Get response for question 4."""
+        query_vectors = np.stack(self._preprocess_image_list(query_image_paths, 3))
+        answer_vectors = np.stack(self._preprocess_image_list(answer_image_paths, 5))
+        # Get the two answer images closest to the query centroid.
+        centroid = self.get_centroid(query_vectors)
+        D, I = self.get_k_nearest_neighbors(centroid, answer_vectors, 5)
+        for i, d in zip(I[0], D[0]):
+            print(f"Euclidean distance to query centroid for answer image {i + 1} is {d}")
+        print(f"Two closest images are {I[0][0] + 1, I[0][1] + 1}")
+        return I[0][0] + 1, I[0][1] + 1
+
+    def group2(self,
+               query_image_paths: typing.List[str],
+               answer_image_paths: typing.List[str],
+               ) -> int:
+        """Get response for group 2 questions."""
+        query_vectors = np.stack(self._preprocess_image_list(query_image_paths, 1))
+        answer_vectors = np.stack(self._preprocess_image_list(answer_image_paths, 3))
+        # Find which answer image the query is closest to.
+        D, I = self.get_k_nearest_neighbors(query_vectors, answer_vectors, 3)
+        for i, d in zip(I[0], D[0]):
+            print(f"Euclidean distance to query for answer image {i + 1} is {d}")
+        print(f"Closest image is {I[0][0] + 1}")
+        return I[0][0] + 1
diff --git a/encoder_models/rank_base_model.py b/encoder_models/rank_base_model.py
new file mode 100644
index 00000000..13a31bbb
--- /dev/null
+++ b/encoder_models/rank_base_model.py
@@ -0,0 +1,167 @@
+import typing
+import numpy as np
+import faiss
+import sys
+
+from .base_model import BaseModel
+
+
+class RankBaseModel(BaseModel):
+
+    def _get_distance_between_vector_group(self, query_vectors) -> float:
+        """Get the total distance between a group of 3 query vectors."""
+
+        assert len(query_vectors) == 3
+        index = faiss.IndexFlatL2(query_vectors[0].shape[-1])
+        index.add(np.stack(query_vectors[1:]))
+        # Get the distances from vector 0 to the other two vectors.
+        D1, I = index.search(np.expand_dims(query_vectors[0], axis=0), 2)
+
+        # Get the distance between the remaining two vectors.
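+        # (Together with D1, the returned total is d(0,1) + d(0,2) + d(1,2):
+        #  the sum of all pairwise squared L2 distances within the triplet.)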
+ index = faiss.IndexFlatL2(query_vectors[0].shape[-1]) + index.add(np.expand_dims(query_vectors[1], axis=0)) + D2, I = index.search(np.expand_dims(query_vectors[2], axis=0), 1) + return sum(D1[0]) + sum(D2[0]) + + def question1(self, + query_image_paths: typing.List[str], + answer_image_paths: typing.List[str], + k: int = 3) -> int: + """Get response for question 1.""" + query_vectors = np.stack(self._preprocess_image_list(query_image_paths, 1)) + answer_vectors = np.stack(self._preprocess_image_list(answer_image_paths, 6)) + _, I = self.get_k_nearest_neighbors(query_vectors, answer_vectors, k) + return self._process_question_1_response(I) + + def _process_question_1_response(self, indices): + """Process response for question 1.""" + SET = {0, 1, 2} + indices, = indices + # indices contains the 3 nearest neighbours to the query image. + # If 2 or more of those images are in group 1, choose group 1 + # If 2 or more of those images are in group 2, choose group 2 + return 2 if len(list(set(indices) - SET)) > 1 else 1 + + def question2(self, + query_image_paths: typing.List[str], + answer_image_paths: typing.List[str], + k: int = 3) -> bool: + """Get response for question 2.""" + query_vectors = np.stack(self._preprocess_image_list(query_image_paths, 3)) + answer_vectors = np.stack(self._preprocess_image_list(answer_image_paths, 1)) + D, _ = self.get_k_nearest_neighbors(answer_vectors, query_vectors, k) + return self._process_question_2_response(D) + + def _process_question_2_response(self, distances) -> bool: + """Process response for question 2.""" + if (distances[0] < 11000).sum(): + ans = True + elif (distances[0] > 14000).sum(): + ans = False + else: + ans = (distances[0] < 12500).sum() == 3 + return ans + + def question3(self, + query_image_paths1: typing.List[str], + query_image_paths2: typing.List[str], + answer_image_paths: typing.List[str] + ) -> typing.Tuple[int, int]: + """Get response for question 3.""" + query1_vectors = self._preprocess_image_list(query_image_paths1, 3) + query2_vectors = self._preprocess_image_list(query_image_paths2, 3) + + distance_between_query_vectors1 = self._get_distance_between_vector_group(query1_vectors) + distance_between_query_vectors2 = self._get_distance_between_vector_group(query2_vectors) + # Choose which group of images in closer in euclidian space as the positive image group. + query_vectors, query_answer = ((query1_vectors, 0) + if distance_between_query_vectors1 < distance_between_query_vectors2 + else (query2_vectors, 1)) + query_vectors = np.stack(query_vectors) + + answer_vectors = np.stack(self._preprocess_image_list(answer_image_paths, 9)) + # For each of the 3 query images, get k=4 nearest neighbours from the 9 answer images + distance, indices = self.get_k_nearest_neighbors(query_vectors, answer_vectors, 4) + + return query_answer, (self._process_question_3_response(distance, indices) + 1) + + def _process_question_3_response(self, distances, indices): + """Process response for question 3.""" + count = [] + distance_list = [sys.maxsize, sys.maxsize, sys.maxsize] + # TODO: This code could probably be cleaned up! 
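+        # A compact way to see the voting idea (illustrative sketch only, not
+        # equivalent to the per-query logic below):
+        #     from collections import Counter
+        #     votes = Counter(i // 3 for i in indices.ravel())  # answer idx -> group
+        # The loop below instead votes per query image and keeps per-group
+        # distances so ties can be broken by proximity.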
+        # Iterate over the nearest neighbours of each query image.
+        for im_indices, im_dist in zip(indices, distances):
+            # Positions of the neighbours that belong to answer group 1.
+            group1 = [list(im_indices).index(i) for i in im_indices if i in [0, 1, 2]]
+            # Positions of the neighbours that belong to answer group 2.
+            group2 = [list(im_indices).index(i) for i in im_indices if i in [3, 4, 5]]
+            # Positions of the neighbours that belong to answer group 3.
+            group3 = [list(im_indices).index(i) for i in im_indices if i in [6, 7, 8]]
+
+            groups = [group1, group2, group3]
+            group_numbers = [len(group1), len(group2), len(group3)]
+            # Select the group contributing the most neighbours of this query image.
+            selected_group = group_numbers.index(max(group_numbers))
+            im_dist = list(im_dist)
+            distance = sum(sorted([im_dist[i] for i in groups[selected_group]])[:2])
+            if selected_group in count:
+                # If two query images select the same answer group, return that group.
+                return selected_group
+            else:
+                count.append(selected_group)
+                distance_list[selected_group] = distance
+        # If each query image is closest to a different group, select the answer group
+        # closest to the query images in Euclidean space.
+        answer = distance_list.index(min(distance_list))
+        return answer
+
+    def question4(self,
+                  query_image_paths: typing.List[str],
+                  answer_image_paths: typing.List[str],
+                  k: int = 2) -> typing.Tuple[int, int]:
+        """Get response for question 4."""
+        query_vectors = np.stack(self._preprocess_image_list(query_image_paths, 3))
+        answer_vectors = np.stack(self._preprocess_image_list(answer_image_paths, 5))
+        # Get the two closest answer images to each query image.
+        distance, indices = self.get_k_nearest_neighbors(query_vectors, answer_vectors, k)
+        ans1, ans2 = self._process_question_4_response(distance, indices)
+        ans1 += 1
+        ans2 += 1
+        return ans1, ans2
+
+    def _process_question_4_response(self, distances, indices) -> typing.List[int]:
+        """Process response for question 4."""
+        count_dict = {}
+        distance_dict = {}
+        # TODO: This code can probably be cleaned up.
+        for im_indices, im_dist in zip(indices, distances):
+            # Iterate over each query image and record which answer images were chosen as closest.
+            for i, d in zip(im_indices, im_dist):
+                if i in count_dict:
+                    count_dict[i] += 1
+                    distance_dict[i] += d
+                else:
+                    count_dict[i] = 1
+                    distance_dict[i] = d
+        # Sort the answer images by the number of times they were chosen as a k=2 neighbour of a query image.
+        sorted_indices = sorted(list(count_dict.keys()), key=count_dict.get, reverse=True)
+        answers = []
+        while len(answers) < 2:
+            if len(sorted_indices) == 1:
+                # If there is only one image left in the sorted indices list, it must be chosen.
+                answers.append(sorted_indices.pop())
+                continue
+            if count_dict[sorted_indices[0]] > count_dict[sorted_indices[1]]:
+                # If one answer image was selected as a k=2 nearest neighbour more often
+                # than the others, choose it.
+                answers.append(sorted_indices[0])
+                sorted_indices.pop(0)
+                continue
+            # If multiple answer images were selected as a k=2 nearest neighbour with the same
+            # frequency, choose whichever answer image is closest to the query images.
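+            # (distance_dict holds each image's summed neighbour distance, so
+            #  min(..., key=distance_dict.get) picks the closest tied image.)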
+ top_indices = [i for i in sorted_indices if count_dict[i] == count_dict[sorted_indices[0]]] + top_distance = min(top_indices, key=distance_dict.get) + sorted_indices.remove(top_distance) + answers.append(top_distance) + return answers diff --git a/encoder_models/resnet.py b/encoder_models/resnet.py new file mode 100644 index 00000000..74a9348b --- /dev/null +++ b/encoder_models/resnet.py @@ -0,0 +1,30 @@ + +import torch +import timm + +from cached_property import cached_property +from timm.data import resolve_data_config +from timm.data.transforms_factory import create_transform +from .cluter_base_model import ClusterBaseModel as BaseModel + + +class Resnet(BaseModel): + + NAME = 'resnet' + + @cached_property + def transform(self): + config = resolve_data_config({}, model=self.model) + transform = create_transform(**config) + return transform + + def __init__(self, weights_path=""): + """Initialize""" + super().__init__(weights_path) + model = timm.create_model('resnet50', pretrained=True, num_classes=0) + model.eval() + self.model = model + + def encode(self, preprocessed_image): + r = self.model(torch.unsqueeze(preprocessed_image, 0)) + return r.detach().numpy() diff --git a/encoder_models/vit.py b/encoder_models/vit.py new file mode 100644 index 00000000..7dbad670 --- /dev/null +++ b/encoder_models/vit.py @@ -0,0 +1,48 @@ +import typing + +import torch +import timm + +from cached_property import cached_property +from timm.data import resolve_data_config +from timm.data.transforms_factory import create_transform +from .cluter_base_model import ClusterBaseModel +from .rank_base_model import RankBaseModel + + +class Vit: + + @cached_property + def transform(self): + config = resolve_data_config({}, model=self.model) + config["mean"] = self.mean + config["std"] = self.std + transform = create_transform(**config) + return transform + + def __init__(self, weights_path, mean: typing.List[float], std: typing.List[float], question_dir: str): + """Initialize""" + super().__init__(weights_path) + self.std = std + self.mean = mean + model = timm.create_model('vit_base_patch16_224') + weights = torch.load(self.weights_path, map_location=torch.device('cuda')) + model.load_state_dict({k.replace("module.", ""): v for k, v in weights["model"].items()}, strict=False) + model.to("cuda") + model.eval() + self.model = model + self.question_dir = question_dir + + def encode(self, preprocessed_images): + r = self.model.forward_features(preprocessed_images.to("cuda")) + return r.cpu().detach().numpy() + + +class ClusterVit(Vit, ClusterBaseModel): + + NAME = 'cluster_vit' + + +class RankVit(Vit, RankBaseModel): + + NAME = 'rank_vit' diff --git a/fscore.py b/fscore.py new file mode 100644 index 00000000..46080b4f --- /dev/null +++ b/fscore.py @@ -0,0 +1,34 @@ +import argparse + +from calculator import calculate1, calculate3, calculate4 + + +CALCULATOR = { + 1: calculate1, + 2: calculate1, + 3: calculate3, + 4: calculate4, + 5: calculate1, + 7: calculate1, + 8: calculate4, +} + + +def main( + answer_file: str, + question_number: int, +): + CALCULATOR[question_number]( + answer_file=answer_file, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--answer_file', type=str) + parser.add_argument('--question_number', type=int) + args = parser.parse_args() + main( + answer_file=args.answer_file, + question_number=args.question_number, + ) diff --git a/inference_g4.py b/inference_g4.py new file mode 100644 index 00000000..c05537dc --- /dev/null +++ b/inference_g4.py @@ -0,0 
+1,91 @@
+from __future__ import print_function
+
+import argparse
+from calculator import calculate1, calculate3, calculate4, calculate4_both
+from quiz_master import quiz3, quiz1, quiz2, quiz4, group2, group4_question1, group4_question2
+from encoder_models.vit import RankVit, ClusterVit
+
+MODELS = {
+    'vit': ClusterVit,
+    'rankvit': RankVit
+}
+
+CALCULATOR = {
+    1: calculate1,
+    2: calculate1,
+    3: calculate3,
+    4: calculate4,
+    5: calculate1,
+    7: calculate1,
+    8: calculate4_both,
+}
+
+QUIZ_OPTIONS = {
+    1: quiz1,
+    2: quiz2,
+    3: quiz3,
+    4: quiz4,
+    5: group2,
+    7: group4_question1,
+    8: group4_question2,
+}
+
+def f1_graph(question_number, model_name, weights_path, opt, question_dir, file_name, title):
+    score_model = MODELS[model_name](
+        weights_path=weights_path,
+        mean=eval(opt.mean),
+        std=eval(opt.std),
+        question_dir=question_dir,
+    )
+
+    score_model.encode_images()
+
+    QUIZ_OPTIONS[question_number](
+        score_model,
+        question_dir,
+        file_name,
+    )
+    # print_results returns (macro, micro), in that order.
+    macro_f1, micro_f1 = CALCULATOR[question_number](
+        answer_file=file_name,
+    )
+    print({'macro_f1' + title: macro_f1})
+    print({'micro_f1' + title: micro_f1})
+
+def main():
+
+    parser = argparse.ArgumentParser('argument for training')
+    parser.add_argument('--mean', type=str, default="(0.6958, 0.6816, 0.6524)")
+    parser.add_argument('--std', type=str, default="(0.3159, 0.3100, 0.3385)")
+    parser.add_argument('--weights_path', type=str, default="")
+    parser.add_argument('--root_dir', type=str, default="/media0/chris/group4_resize_v2")
+    parser.add_argument('--model_name', type=str, default="vit")
+    parser.add_argument('--valid_path', type=str, default="valid")
+    parser.add_argument('--test_path', type=str, default="test")
+
+    opt = parser.parse_args()
+
+    file_name = opt.weights_path.replace('.pth', '.csv')
+
+    question1_val_dir = f"{opt.root_dir}/{opt.valid_path}/question1"
+    f1_graph(7, opt.model_name, opt.weights_path, opt, question1_val_dir, file_name, "question1_valid")
+
+    question1_tst_dir = f"{opt.root_dir}/{opt.test_path}/question1"
+    f1_graph(7, opt.model_name, opt.weights_path, opt, question1_tst_dir, file_name, "question1_test")
+
+    question2_val_dir = f"{opt.root_dir}/{opt.valid_path}/question2"
+    f1_graph(8, opt.model_name, opt.weights_path, opt, question2_val_dir, file_name, "question2_valid")
+
+    question2_tst_dir = f"{opt.root_dir}/{opt.test_path}/question2"
+    f1_graph(8, opt.model_name, opt.weights_path, opt, question2_tst_dir, file_name, "question2_test")
+
+if __name__ == '__main__':
+    main()
\ No newline at end of file
diff --git a/main.py b/main.py
new file mode 100644
index 00000000..f0beb723
--- /dev/null
+++ b/main.py
@@ -0,0 +1,118 @@
+
+import argparse
+import os.path
+import typing
+from datetime import datetime
+import glob
+
+from encoder_models.vit import RankVit, ClusterVit
+from encoder_models.resnet import Resnet
+from quiz_master import quiz3, quiz1, quiz2, quiz4, group2, group4_question1, group4_question2
+from calculator import calculate1, calculate3, calculate4
+
+MODELS = {
+    'vit': ClusterVit,
+    'resnet': Resnet,
+    'rankvit': RankVit
+}
+
+CALCULATOR_OPTIONS = {
+    1: calculate1,
+    2: calculate1,
+    3: calculate3,
+    4: calculate4,
+    5: calculate1,
+    7: calculate1,
+    8: calculate4,
+}
+
+QUIZ_OPTIONS = {
+    1: quiz1,
+    2: quiz2,
+    3: quiz3,
+    4: quiz4,
+    5: group2,
+    7: group4_question1,
+    8:
group4_question2, +} + + +def get_sorted_weights(weights_dir: str) -> typing.List[typing.Tuple[str, int]]: + """Get weights files.""" + weights_files = list(glob.glob(f"{weights_dir}/*.pth")) + weights_files = [(weights, int(weights.replace('.pth', "").split("_")[-1])) for weights in weights_files] + return sorted(weights_files, + key=lambda x: x[-1]) + + +def main( + model_name: str, + weights_path: str, + question_number: int, + question_dir: str, + mean: typing.List[float], + std: typing.List[float], + **kwargs +): + if not os.path.isdir(weights_path): + model = MODELS[model_name]( + weights_path=weights_path, + mean=mean, + std=std, + question_dir=question_dir, + ) + model.encode_images() + QUIZ_OPTIONS[question_number]( + model, + question_dir + ) + else: + with open(f"question_number{question_number}_{str(datetime.now())[:-7].replace(' ', '')}.csv", 'w+') as fh: + fh.write('epoch,fmacro,fmicro\n') + fh.flush() + for weight, epoch in get_sorted_weights(weights_path): + result_file = weight.replace('pth', 'csv') + model = MODELS[model_name]( + weights_path=weight, + mean=mean, + std=std, + question_dir=question_dir, + ) + model.encode_images() + QUIZ_OPTIONS[question_number]( + model, + question_dir, + file_name=result_file + ) + fmacro, fmicro = CALCULATOR_OPTIONS[question_number](result_file) + fh.write(f'{epoch},{fmacro},{fmicro}\n') + fh.flush() + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--model', type=str, default='vit', help='Choose vit, simclr or resnet') + parser.add_argument('--weights_path', type=str, default="") + parser.add_argument('--question_number', type=int) + parser.add_argument('--question_dir', type=str) + parser.add_argument('--mean', type=str) + parser.add_argument('--std', type=str) + parser.add_argument('--additional_args', type=str, default="", + help="If your model requires additional args") + args = parser.parse_args() + additional_args = {} + if args.additional_args and "=" in args.additional_args: + for arg_group in args.additional_args.split(','): + k, v = args.additional_args.split('=') + additional_args[k] = v + print('args.mean->', args.mean) + main( + model_name=args.model, + weights_path=args.weights_path, + question_number=args.question_number, + question_dir=args.question_dir, + mean=eval(args.mean), + std=eval(args.std), + **additional_args + ) diff --git a/main_supcon.py b/main_supcon.py index 937cf626..d57e2a5d 100644 --- a/main_supcon.py +++ b/main_supcon.py @@ -21,6 +21,12 @@ from timm.data.transforms_factory import create_transform from networks.vit import SupConVit from losses import SupConLoss +import wandb +from calculator import calculate1, calculate3, calculate4 +from quiz_master import quiz3, quiz1, quiz2, quiz4, group2, group4_question1, group4_question2 +from encoder_models.vit import RankVit, ClusterVit +from encoder_models.resnet import Resnet +from tqdm import tqdm try: import apex @@ -28,13 +34,37 @@ except ImportError: pass +MODELS = { + 'vit': ClusterVit, + 'rankvit': RankVit +} + +CALCULATOR = { + 1: calculate1, + 2: calculate1, + 3: calculate3, + 4: calculate4, + 5: calculate1, + 7: calculate1, + 8: calculate4, +} + +QUIZ_OPTIONS = { + 1: quiz1, + 2: quiz2, + 3: quiz3, + 4: quiz4, + 5: group2, + 7: group4_question1, + 8: group4_question2, +} def parse_option(): parser = argparse.ArgumentParser('argument for training') parser.add_argument('--print_freq', type=int, default=10, help='print frequency') - parser.add_argument('--save_freq', type=int, default=50, + 
parser.add_argument('--save_freq', type=int, default=1, help='save frequency') parser.add_argument('--batch_size', type=int, default=256, help='batch_size') @@ -42,7 +72,9 @@ def parse_option(): help='num of workers to use') parser.add_argument('--epochs', type=int, default=1000, help='number of training epochs') - parser.add_argument('--use_parallel', type=bool, default=False, + parser.add_argument('--use_parallel', type=bool, default=True, + help='Use parallel trainer') + parser.add_argument('--gpus', type=str, default="2,3", help='Use parallel trainer') # optimization @@ -67,6 +99,9 @@ def parse_option(): parser.add_argument('--data_folder', type=str, default=None, help='path to custom dataset') parser.add_argument('--size', type=int, default=32, help='parameter for RandomResizedCrop') parser.add_argument('--pretrained', type=bool, default=False) + parser.add_argument('--wandb_id', type=str, default="") + parser.add_argument('--wandb', type=bool, default=False) + parser.add_argument('--wandb_pn', type=str, default="") # method parser.add_argument('--method', type=str, default='SupCon', @@ -291,7 +326,7 @@ def train(train_loader, model, criterion, optimizer, epoch, opt): labels = torch.tensor(labels, dtype=int) # Reshape the images from 5D to 4D tensors. images = [images[0].reshape([images[0].shape[0] * images[0].shape[1], *images[0].shape[2:]]), - images[1].reshape([images[0].shape[0] * images[0].shape[1], *images[0].shape[2:]])] + images[1].reshape([images[0].shape[0] * images[0].shape[1], *images[0].shape[2:]])] if labels.shape[0] != images[0].shape[0]: print(f'Skipping question {labels.shape[0]} != {images[0].shape[0]}') @@ -319,7 +354,7 @@ def train(train_loader, model, criterion, optimizer, epoch, opt): loss = criterion(features) else: raise ValueError('contrastive method not supported: {}'. 
- format(opt.method)) + format(opt.method)) # update metric losses.update(loss.item(), bsz) @@ -335,15 +370,15 @@ def train(train_loader, model, criterion, optimizer, epoch, opt): # print info if (idx + 1) % opt.print_freq == 0: - print(f'{datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")}\t' + - 'Train: [{0}][{1}/{2}]\t' - 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' - 'loss {loss.val:.3f} ({loss.avg:.3f})'.format( - epoch, idx + 1, len(train_loader), batch_time=batch_time, - data_time=data_time, loss=losses)) - sys.stdout.flush() + print('Train: [{0}][{1}/{2}]\t' + 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'loss {loss.val:.3f} ({loss.avg:.3f})'.format( + epoch, idx + 1, len(train_loader), batch_time=batch_time, + data_time=data_time, loss=losses)) + sys.stdout.flush() + return losses.avg @@ -402,10 +437,10 @@ def train_category(train_loader, model, criterion, optimizer, epoch, opt): return losses.avg - def main(): opt = parse_option() - + if opt.wandb: + wandb.init(project=opt.wandb_pn, entity=opt.wandb_id) # build data loader if opt.method == 'SimCLR': train_loader = set_loader_category(opt) @@ -423,6 +458,7 @@ def main(): # training routine for epoch in range(1, opt.epochs + 1): + adjust_learning_rate(opt, optimizer, epoch) # train for one epoch @@ -431,18 +467,25 @@ def main(): loss = train_category(train_loader, model, criterion, optimizer, epoch, opt) else: loss = train(train_loader, model, criterion, optimizer, epoch, opt) + if opt.wandb: + wandb.log({'loss':loss}, step=epoch) time2 = time.time() - print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) - # tensorboard logger logger.log_value('loss', loss, epoch) logger.log_value('learning_rate', optimizer.param_groups[0]['lr'], epoch) - print(f'loss: {loss} epoch {epoch}') + print(f'epoch {epoch}, loss: {loss} ') + weights_name = f'ckpt_{opt.method}_pretrained_{opt.pretrained}_{opt.group_num}_epoch_{epoch}.pth' if epoch % opt.save_freq == 0: - save_file = os.path.join( - opt.save_folder, f'ckpt_{opt.method}_pretrained_{opt.pretrained}_{opt.group_num}_epoch_{epoch}.pth') - save_model(model, optimizer, opt, epoch, save_file) + + weights_path = os.path.join( + opt.save_folder,weights_name) + save_model(model, optimizer, opt, epoch, weights_path) + model_name = 'vit' + valid_path = "valid" + test_path = "test" + + root_dir = "/".join(opt.data_folder.split("/")[:-1]) # save the last model save_file = os.path.join( diff --git a/networks/vit.py b/networks/vit.py index e37b40ae..61a08260 100644 --- a/networks/vit.py +++ b/networks/vit.py @@ -10,6 +10,7 @@ def __init__(self, name='', head='mlp', feat_dim=128, pretrained=False): dim_in = 768 print(f"Setting pretrained to {pretrained}") self.encoder = timm.create_model("vit_base_patch16_224", pretrained=pretrained, num_classes=0) + if head == 'linear': self.head = nn.Linear(dim_in, feat_dim) elif head == 'mlp': diff --git a/question_loader.py b/question_loader.py index 2d423a8b..c30afebc 100644 --- a/question_loader.py +++ b/question_loader.py @@ -6,6 +6,8 @@ from torch.utils.data import Dataset from PIL import Image import numpy as np +from PIL import ImageFile +ImageFile.LOAD_TRUNCATED_IMAGES = True class BaseQuestionLoader(Dataset): """Question loader base model.""" @@ -72,18 +74,19 @@ def __getitem__(self, idx): question_image = info['Questions'][0]['images'] samples1, samples2 = [], [] for im in question_image: + image = Image.open(os.path.join( self.root, 
dir_name, im['image_url'] )).convert('RGB') - # Augment the images twice. sample1, sample2 = self.transform(image) samples1.append(sample1) samples2.append(sample2) + # Augment the images twice. + samples = [torch.squeeze(torch.stack(samples1), dim=0), torch.squeeze(torch.stack(samples2), dim=0)] if samples[0].shape[0] != 3: - # Delete question if it is invalid - #shutil.rmtree(os.path.join(self.root, dir_name)) + shutil.rmtree(os.path.join(self.root, dir_name)) print(f"removed {os.path.join(self.root, dir_name)}") return samples @@ -117,7 +120,7 @@ def __getitem__(self, idx): torch.squeeze(torch.stack(samples2), dim=0)] if samples[0].shape[0] != 6: # # Delete question if it is invalid - # shutil.rmtree(os.path.join(self.root, dir_name)) + shutil.rmtree(os.path.join(self.root, dir_name)) print(f"removed {os.path.join(self.root, dir_name)}") return samples @@ -139,6 +142,7 @@ def __getitem__(self, idx): samples1, samples2 = [], [] for im in positive_samples: + image = Image.open(os.path.join( self.root, dir_name, im['image_url'] )).convert('RGB') @@ -146,11 +150,12 @@ def __getitem__(self, idx): sample1, sample2 = self.transform(image) samples1.append(sample1) samples2.append(sample2) + samples = [torch.squeeze(torch.stack(samples1), dim=0), torch.squeeze(torch.stack(samples2), dim=0)] if samples[0].shape[0] != 5: - # Delete question if it is invalid - #shutil.rmtree(os.path.join(self.root, dir_name)) + + shutil.rmtree(os.path.join(self.root, dir_name)) print(f"removed {os.path.join(self.root, dir_name)}") return samples diff --git a/quiz_master.py b/quiz_master.py new file mode 100644 index 00000000..c21fc438 --- /dev/null +++ b/quiz_master.py @@ -0,0 +1,311 @@ +import json +import os +import traceback +import typing +from datetime import datetime +from encoder_models.base_model import BaseModel + + +def quiz1(model: BaseModel, question_dir: str, file_name: typing.Optional[str] = None): + """ + Code for solving question 1. Given 1 query image and 2 answer groups each containing 3 images, choose which + image group is most similar to the query image. + + :param model: The model to solve the challenge. + :param question_dir: The path to the directory containing the questions. 
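+    :param file_name: Optional output CSV path; one row per question is
+        written in the form question,category,gt,ans.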
+ """ + + question_dirs = os.listdir(question_dir) + file_name = file_name or f'{model.NAME}q1{str(datetime.now()).replace(" ", "").replace(":", "")}.csv' + question_dirs.remove(BaseModel.CACHE_DIR) + incorrect_questions = [] + total = 0 + correct = 0 + with open(file_name, 'w+') as fh: + fh.write('question,category,gt,ans\n') + fh.flush() + for i, question_number in enumerate(question_dirs): + print(f'Answering question {question_number}') + qdir_path = os.path.join(question_dir, question_number) + try: + with open(os.path.join(qdir_path, + f'{question_number}.json')) as q: + question_info = json.load(q) + category = question_info['category'] + assert len(question_info['correct_answer_group_ID']) == 1 + gt, = question_info['correct_answer_group_ID'] + query_images = [ + os.path.join(qdir_path, im_dict['image_url']) + for im_dict in question_info['Questions'][0]["images"] + ] + answer_images = [ + os.path.join(qdir_path, image['image_url']) + for im_dict in question_info['Answers'] + for image in im_dict['images'] + ] + + ans = model.question1(query_images, answer_images) + + except Exception: + traceback.print_exc() + print(f'Invalid question for {question_number}') + incorrect_questions.append(question_number) + else: + fh.write(f'{question_number},{category},{gt},{ans}\n') + fh.flush() + total += 1 + if ans == gt: + correct += 1 + print(f'Accuracy after {i} is {correct / total}') + print('These questions are incorrect') + print(incorrect_questions) + + +def quiz2(model: BaseModel, + question_dir: str, + quiz_name: typing.Optional[str] = 'q2', + file_name: typing.Optional[str] = None, + ): + """ + Code for solving question 2. Given 3 query images and 1 answer image, decide if the query images are similar to + the answer image. + + :param model: The model to solve the challenge. + :param question_dir: The path to the directory containing the questions. + """ + + question_dirs = os.listdir(question_dir) + file_name = file_name or f'{model.NAME}{quiz_name}{str(datetime.now()).replace(" ", "").replace(":", "")}.csv' + question_dirs.remove(BaseModel.CACHE_DIR) + incorrect_questions = [] + total = 0 + correct = 0 + with open(file_name, 'w+') as fh: + fh.write('question,category,gt,ans\n') + fh.flush() + for i, question_number in enumerate(question_dirs): + print(f'Answering question {question_number}') + qdir_path = os.path.join(question_dir, question_number) + try: + with open(os.path.join(qdir_path, + f'{question_number}.json')) as q: + question_info = json.load(q) + category = question_info['category'] + gt = question_info['is_correct'] + query_images = [ + os.path.join(qdir_path, im_dict['image_url']) + for im_dict in question_info['Questions'][0]["images"] + ] + answer_images = [ + os.path.join(qdir_path, image['image_url']) + for im_dict in question_info['Answers'] + for image in im_dict['images'] + ] + + ans = model.question2(query_images, answer_images) + except Exception: + traceback.print_exc() + print(f'Invalid question for {question_number}') + incorrect_questions.append(question_number) + else: + fh.write(f'{question_number},{category},{gt},{ans}\n') + fh.flush() + total += 1 + if ans == gt: + correct += 1 + print(f'Accuracy after {i} is {correct / total}') + print('These questions are incorrect') + print(incorrect_questions) + + +def quiz3(model: BaseModel, question_dir: str, file_name: typing.Optional[str] = None): + """ + Code for solving question 3. Given 2 groups of query image pick which group has similar characteristics. 
Next given + 3 groups of answer images, choose which group is most similar to the chosen query image group. + + + :param model: The model to solve the challenge. + :param question_dir: The path to the directory containing the questions. + """ + + question_dirs = os.listdir(question_dir) + file_name = file_name or f'{model.NAME}q3{str(datetime.now()).replace(" ", "").replace(":", "")}.csv' + question_dirs.remove(BaseModel.CACHE_DIR) + incorrect_questions = [] + total = 0 + correct = 0 + with open(file_name, 'w+') as fh: + fh.write('question,category,qgt,agt,qans,ans\n') + fh.flush() + for i, question_number in enumerate(question_dirs): + print(f'Answering question {question_number}') + qdir_path = os.path.join(question_dir, question_number) + try: + with open(os.path.join(qdir_path, + f'{question_number}.json')) as q: + question_info = json.load(q) + category = question_info['category'] + assert len(question_info['correct_question_group_ID']) == 1 + gt = question_info['correct_answer_group_ID'] + question_gt, = question_info['correct_question_group_ID'] + query1, query2 = [ + [os.path.join(qdir_path, im['image_url']) for im in im_dict["images"]] + for im_dict in question_info['Questions'] + ] + answer_images = [ + os.path.join(qdir_path, image['image_url']) + for im_dict in question_info['Answers'] + for image in im_dict['images'] + ] + + query_answer, ans = model.question3(query1, query2, answer_images) + group_id = question_info['Questions'][query_answer]["group_id"] + + except Exception: + traceback.print_exc() + print(f'Invalid question for {question_number}') + incorrect_questions.append(question_number) + else: + fh.write(f'{question_number},{category},{question_gt},{gt},{group_id},{ans}\n') + fh.flush() + total += 1 + if ans == gt: + correct += 1 + print(f'Accuracy after {i} is {correct / total}') + print('These questions are incorrect') + print(incorrect_questions) + + +def quiz4(model: BaseModel, + question_dir: str, + quiz_name: typing.Optional[str] = 'q4', + file_name: typing.Optional[str] = None): + """ + Code for solving question 4. Given 3 query images and 5 answer images, choose the 2 answer images most similar to + query images. + + :param model: The model to solve the challenge. + :param question_dir: The path to the directory containing the questions. 
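+    :param file_name: Optional output CSV path; one row per question is
+        written in the form question,category,gt1,gt2,ans1,ans2.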
+ """ + + question_dirs = os.listdir(question_dir) + file_name = file_name or f'{model.NAME}{quiz_name}{str(datetime.now()).replace(" ", "").replace(":", "")}.csv' + question_dirs.remove(BaseModel.CACHE_DIR) + incorrect_questions = [] + total = 0 + correct = 0 + with open(file_name, 'w+') as fh: + fh.write('question,category,gt1,gt2,ans1,ans2\n') + fh.flush() + for i, question_number in enumerate(question_dirs): + print(f'Answering question {question_number}') + qdir_path = os.path.join(question_dir, question_number) + try: + with open(os.path.join(qdir_path, + f'{question_number}.json')) as q: + question_info = json.load(q) + category = question_info['category'] + assert len(question_info['correct_answer_group_ID']) == 2 + gt1, gt2 = question_info['correct_answer_group_ID'] + query_images = [ + os.path.join(qdir_path, im_dict['image_url']) + for im_dict in question_info['Questions'][0]["images"] + ] + answer_images = [ + os.path.join(qdir_path, im_dict['images'][0]['image_url']) + for im_dict in question_info['Answers'] + ] + + ans1, ans2 = model.question4(query_images, answer_images) + except Exception: + traceback.print_exc() + print(f'Invalid question for {question_number}') + incorrect_questions.append(question_number) + else: + fh.write(f'{question_number},{category},{gt1},{gt2},{ans1},{ans2}\n') + fh.flush() + total += 2 + if ans1 in [gt1, gt2]: + correct += 1 + if ans2 in [gt1, gt2]: + correct += 1 + print(f'Accuracy after {i} is {correct / total}') + print('These questions are incorrect') + print(incorrect_questions) + + +def group2(model: BaseModel, question_dir: str, file_name: typing.Optional[str] = None): + """ + Code for solving group 2 questions. Given 1 query image and 3 answer images, choose which + answer group is most similar to the query image. + + :param model: The model to solve the challenge. + :param question_dir: The path to the directory containing the questions. + """ + + question_dirs = os.listdir(question_dir) + file_name = file_name or f'{model.NAME}gruop2{str(datetime.now()).replace(" ", "").replace(":", "")}.csv' + question_dirs.remove(BaseModel.CACHE_DIR) + incorrect_questions = [] + total = 0 + correct = 0 + with open(file_name, 'w+') as fh: + fh.write('question,category,gt,ans\n') + fh.flush() + for i, question_number in enumerate(question_dirs): + print(f'Answering question {question_number}') + qdir_path = os.path.join(question_dir, question_number) + try: + with open(os.path.join(qdir_path, + f'{question_number}.json')) as q: + question_info = json.load(q) + category = question_info['category'] + assert len(question_info['correct_answer_group_ID']) == 1 + gt, = question_info['correct_answer_group_ID'] + query_images = [ + os.path.join(qdir_path, im_dict['image_url']) + for im_dict in question_info['Questions'][0]["images"] + ] + answer_images = [ + os.path.join(qdir_path, image['image_url']) + for im_dict in question_info['Answers'] + for image in im_dict['images'] + ] + + ans = model.group2(query_images, answer_images) + + except Exception: + traceback.print_exc() + print(f'Invalid question for {question_number}') + incorrect_questions.append(question_number) + else: + fh.write(f'{question_number},{category},{gt},{ans}\n') + fh.flush() + total += 1 + if ans == gt: + correct += 1 + print(f'Accuracy after {i} is {correct / total}') + print('These questions are incorrect') + print(incorrect_questions) + + +def group4_question1(model: BaseModel, question_dir: str, file_name: typing.Optional[str] = None): + """ + Code for solving quiz 4 question 1. 
Given 3 query images and 1 answer image, decide if the query images are similar to + the answer image. + + :param model: The model to solve the challenge. + :param question_dir: The path to the directory containing the questions. + """ + quiz2(model, question_dir, 'quiz4_q1', file_name=file_name) + + +def group4_question2(model: BaseModel, question_dir: str, file_name: typing.Optional[str] = None): + """ + Code for solving quiz 4 question 2. Given 3 query images and 5 answer images, choose the 2 answer images most similar to + query images. + + :param model: The model to solve the challenge. + :param question_dir: The path to the directory containing the questions. + """ + quiz4(model, question_dir, 'quiz4_q2', file_name=file_name) diff --git a/util.py b/util.py index b6323530..a761b61e 100644 --- a/util.py +++ b/util.py @@ -84,7 +84,7 @@ def set_optimizer(opt, model): def save_model(model, optimizer, opt, epoch, save_file): - print('==> Saving...') + state = { 'opt': opt, 'model': model.state_dict(),