forked from HobbitLong/SupContrast
-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* test * pretrained=false * main_supcon.py add wandb * question loader * utils * fscore * calculator.py * dataset-statis.py * mains * gitignore * script * encoder_models/ * ignore * Delete main_f1score.py delete * change * Delete train_supcon_50p_group4.sh train_supcon_50p_grroup4 delete * gitignore * calculator remove comment * compare better way * file_name not using * f1_graph remove * question loader try remove * inference group4 * main_supcon * resolve epsilon * remove continuing try exception block * remove print statement * remove try exception for removing damage jpg,jpeg * jpeg jpg include * remove resnet from main_supcon * spaces after each comma inference_g4.py * delete unused comment * multiple line violation * remove unused import * remove unused try except question_loader
- Loading branch information
Showing
17 changed files
with
1,290 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
|
||
import typing | ||
import sys | ||
|
||
def increment(*args) -> typing.List[int]:
    """Return a new list holding each positional argument plus one."""
    bumped = []
    for value in args:
        bumped.append(value + 1)
    return bumped
|
||
|
||
def _f_score(tp: int, fp: int, fn: int) -> float:
    """Return the F1 score for the given counts, or 0.0 when it is undefined."""
    # Precision / recall have a zero denominator when nothing was predicted or
    # nothing was expected; score those cases as 0 instead of raising.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * ((precision * recall) / (precision + recall))


def print_results(TTP: int, TFP: int, TFN: int, result_dict: typing.Dict) -> typing.Tuple[float, float]:
    """Print per-category, macro and micro F-scores and return the last two.

    Args:
        TTP: Total true positives over all categories.
        TFP: Total false positives over all categories.
        TFN: Total false negatives over all categories.
        result_dict: Maps each category to a dict with 'TP', 'FP' and 'FN'
            counts.

    Returns:
        ``(macro_f, micro_f)``. Any score whose denominator would be zero
        (no true positives, or an empty ``result_dict``) is reported as 0.0
        rather than raising ZeroDivisionError.
    """
    fscores = []
    for k, v in result_dict.items():
        f = _f_score(v['TP'], v['FP'], v['FN'])
        print(f'{k}: {f}')
        fscores.append(f)
    # Macro score: unweighted mean of the per-category scores.
    macro = sum(fscores) / len(fscores) if fscores else 0.0
    print(f"Macro f score: {macro}")

    # Micro score: computed once from the pooled counts.
    micro = _f_score(TTP, TFP, TFN)
    print(f'Micro f score {micro}')
    return macro, micro
|
||
|
||
def _parse_binary_answer(raw: str) -> int:
    """Convert the textual answer column into 0 or 1.

    ``bool()`` of any non-empty string is True, so "0" and "False" must be
    recognised explicitly instead of relying on string truthiness.
    """
    text = raw.strip().lower()
    try:
        return int(float(text) != 0)
    except ValueError:
        # Non-numeric text: empty and "false" mean 0, anything else means 1.
        return int(text not in ('', 'false'))


def calculate1(answer_file: str) -> typing.Tuple[float, float]:
    """Calculate Fscore for problem1.

    Args:
        answer_file: Path to a comma-separated file with one header line
            followed by rows of four fields; the first field is ignored and
            the remaining ones are category, ground truth and answer.

    Returns:
        The ``(macro_f, micro_f)`` tuple produced by ``print_results``.
    """
    TTP, TFP, TFN = 0, 0, 0  # Total true positives, false positives, false negatives
    result_dict = {}
    with open(answer_file) as fh:
        next(fh, None)  # skip the header line
        for line in fh:
            line = line.rstrip('\n')
            if not line.strip():
                continue  # tolerate blank lines, e.g. at the end of the file
            _, cat, gt, ans = line.split(',')
            gt, ans = int(gt), _parse_binary_answer(ans)
            print(f'gt->{gt}, ans->{ans}', gt == ans)
            if cat not in result_dict:
                result_dict[cat] = dict(TP=0, FN=0, FP=0)
            if ans == gt:
                result_dict[cat]['TP'] += 1
                TTP += 1
            else:
                # A wrong answer is counted as both a false positive and a
                # false negative for the row's category.
                result_dict[cat]['FP'] += 1
                result_dict[cat]['FN'] += 1
                TFP += 1
                TFN += 1
    return print_results(TTP, TFP, TFN, result_dict)
|
||
|
||
def calculate3(answer_file: str) -> typing.Tuple[float, float]:
    """Compute the F-scores for problem 3 from a comma-separated answer file.

    The first line is treated as a header. Every other row must hold six
    fields; only the category (2nd), the answer ground truth (4th) and the
    answer (6th) are used, compared as raw strings.
    """
    result_dict = {}
    # Running totals: true positives, false positives, false negatives.
    TTP = TFP = TFN = 0
    with open(answer_file) as fh:
        rows = iter(fh)
        next(rows, None)  # drop the header line
        for raw in rows:
            fields = raw.replace('\n', "").split(',')
            _, cat, qgt, agt, qans, ans = fields
            stats = result_dict.setdefault(cat, dict(TP=0, FN=0, FP=0))
            if ans != agt:
                # A miss counts as a false positive and a false negative.
                stats['FP'] += 1
                TFP += 1
                stats['FN'] += 1
                TFN += 1
                continue
            stats['TP'] += 1
            TTP += 1
    return print_results(TTP, TFP, TFN, result_dict)
|
||
|
||
def calculate4(answer_file: str) -> typing.Tuple[float, float]:
    """Compute the F-scores for problem 4, scoring each answer separately.

    The first line is treated as a header. Every other row must hold six
    fields: an ignored field, the category, two ground truths and two
    answers. Each answer is a hit when it equals either ground truth.
    """
    result_dict = {}
    # Running totals: true positives, false positives, false negatives.
    TTP = TFP = TFN = 0
    with open(answer_file) as fh:
        rows = iter(fh)
        next(rows, None)  # drop the header line
        for raw in rows:
            _, cat, gt1, gt2, ans1, ans2 = raw.replace('\n', "").split(',')
            stats = result_dict.setdefault(cat, dict(TP=0, FN=0, FP=0))
            for guess in (ans1, ans2):
                if guess == gt1 or guess == gt2:
                    stats['TP'] += 1
                    TTP += 1
                else:
                    # A miss counts as a false positive and a false negative.
                    stats['FP'] += 1
                    TFP += 1
                    stats['FN'] += 1
                    TFN += 1
    return print_results(TTP, TFP, TFN, result_dict)
|
||
def calculate4_both(answer_file: str) -> typing.Tuple[float, float]:
    """Compute the F-scores for problem 4, requiring both answers to match.

    The first line is treated as a header. Every other row must hold six
    fields: an ignored field, the category, two ground truths and two
    answers. A row is a hit only when the answer pair equals the ground-truth
    pair in either order.
    """
    result_dict = {}
    # Running totals: true positives, false positives, false negatives.
    TTP = TFP = TFN = 0
    with open(answer_file) as fh:
        rows = iter(fh)
        next(rows, None)  # drop the header line
        for raw in rows:
            _, cat, gt1, gt2, ans1, ans2 = raw.replace('\n', "").split(',')
            stats = result_dict.setdefault(cat, dict(TP=0, FN=0, FP=0))
            given = (ans1, ans2)
            both_match = given == (gt1, gt2) or given == (gt2, gt1)
            if both_match:
                stats['TP'] += 1
                TTP += 1
            else:
                # A miss counts as a false positive and a false negative.
                stats['FP'] += 1
                TFP += 1
                stats['FN'] += 1
                TFN += 1
    return print_results(TTP, TFP, TFN, result_dict)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
import argparse | ||
import torch | ||
import torchvision.transforms as transforms | ||
import torchvision.datasets as datasets | ||
from tqdm import tqdm | ||
from PIL import ImageFile | ||
ImageFile.LOAD_TRUNCATED_IMAGES = True | ||
|
||
def _channel_stats(channels_sum, channels_squared_sum, num_samples):
    """Return (mean, std) per channel from sample-weighted running sums."""
    mean = channels_sum / num_samples
    # std = sqrt(E[X^2] - (E[X])^2); rounding can push the variance slightly
    # below zero, which would turn the square root into NaN, so clamp it.
    variance = (channels_squared_sum / num_samples - mean ** 2).clamp(min=0)
    return mean, variance ** 0.5


def main(image_dir: str):
    """Get the per-channel mean and std of the images under ``image_dir``.

    Images are loaded with ``datasets.ImageFolder``, resized to 224x224 and
    converted to tensors. Running statistics are printed every 100 batches
    and the final values are printed and returned.

    Args:
        image_dir: Root directory handed to ``datasets.ImageFolder``.

    Returns:
        ``(mean, std)``, each reduced over the batch, height and width
        dimensions of the loaded batches.
    """
    dataloader = torch.utils.data.DataLoader(
        datasets.ImageFolder(root=image_dir, transform=transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])), batch_size=4
    )
    channels_sum, channels_squared_sum = 0, 0
    num_samples, num_batches = 0, 0
    for data, _ in tqdm(dataloader):
        batch_len = data.shape[0]
        # Mean over batch, height and width, but not over the channels.
        # Weight each batch by its sample count so that a shorter final batch
        # does not get the same influence as a full one.
        channels_sum += torch.mean(data, dim=[0, 2, 3]) * batch_len
        channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3]) * batch_len
        num_samples += batch_len
        num_batches += 1
        if num_batches % 100 == 0:
            mean, std = _channel_stats(channels_sum, channels_squared_sum, num_samples)
            print(f'count->{num_batches}, mean->{mean}, std->{std}')

    mean, std = _channel_stats(channels_sum, channels_squared_sum, num_samples)
    print(f"Mean: {mean}")
    print(f"Std: {std}")
    return mean, std
|
||
|
||
if __name__ == "__main__":
    # Command-line entry point: read --image_dir and hand it to main().
    cli = argparse.ArgumentParser()
    cli.add_argument('--image_dir', type=str)
    main(image_dir=cli.parse_args().image_dir)
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
import os.path | ||
import shutil | ||
import typing | ||
import numpy as np | ||
import faiss | ||
import torch | ||
import glob | ||
from cached_property import cached_property | ||
from PIL import Image | ||
import time | ||
|
||
|
||
class BaseModel:
    """Base class for image encoders used to answer nearest-neighbour questions.

    Subclasses override ``transform`` and ``encode`` and implement the
    ``question*`` / ``group2`` methods. ``cache_dir`` and ``encode_images``
    read ``self.question_dir``, which this class never assigns, so it has to
    be provided by the subclass or the caller before they are used.
    """

    NAME = "Please override in subclass"
    CACHE_DIR = 'cache'

    @cached_property
    def transform(self):
        """Preprocessing callable applied to a PIL image (subclass hook)."""
        raise NotImplementedError("Data preprocessing must be implemented in the subclass.")

    def encode(self, preprocessed_image) -> typing.List[float]:
        """Turn preprocessed image data into feature vectors (subclass hook)."""
        raise NotImplementedError("Encoding must be implemented in the subclass")

    @property
    def cache_dir(self):
        """Directory holding the cached ``.npy`` vectors, inside ``question_dir``."""
        return os.path.join(self.question_dir, self.CACHE_DIR)

    def __init__(self, weights_path: str = ""):
        self.weights_path = weights_path
        # Base names of the images whose vectors are stored in cache_dir.
        self.vector_cache = set()

    def _preprocess_image(self, image_path: str):
        """Open the image as RGB and apply ``self.transform`` to it."""
        # Use a context manager so the underlying file handle is released.
        with Image.open(image_path) as image:
            return self.transform(image.convert('RGB'))

    def _encode_and_cache(self, image_vectors, image_names):
        """Encode one batch and save each resulting row as ``<name>.npy``."""
        encoded = self.encode(torch.stack(image_vectors))
        for i, image_name in enumerate(image_names):
            np.save(os.path.join(self.cache_dir, f"{image_name}.npy"), encoded[i, :])

    def encode_images(self, batch_size: int = 128):
        """Encode every ``<question_dir>/*/*.jp*`` image into the cache directory.

        The cache directory is recreated from scratch, images are encoded in
        batches of ``batch_size`` and each vector is saved as
        ``<file name>.npy``. Files sharing a base name are encoded only once.
        """
        if os.path.exists(self.cache_dir):
            shutil.rmtree(self.cache_dir)
        os.mkdir(self.cache_dir)
        # The cached files were just deleted, so the remembered names are
        # stale; keeping them would skip images whose vectors no longer exist.
        self.vector_cache.clear()

        pattern = f'{self.question_dir}/*/*.jp*'
        image_paths = glob.glob(pattern)
        num_images = len(image_paths)
        print(pattern, num_images)

        image_vectors, image_names = [], []
        for counter, image_path in enumerate(image_paths, start=1):
            file_name = os.path.basename(image_path)
            if file_name in self.vector_cache:
                continue
            image_vectors.append(self._preprocess_image(image_path))
            image_names.append(file_name)
            self.vector_cache.add(file_name)
            if len(image_vectors) == batch_size:
                self._encode_and_cache(image_vectors, image_names)
                image_vectors, image_names = [], []
                print(f"Preprocessed {counter}/{num_images} images.")
        # Encode the final partial batch; it must go through encode() itself
        # rather than reuse rows from the previous batch.
        if image_vectors:
            self._encode_and_cache(image_vectors, image_names)

    def _preprocess_image_list(self,
                               image_path_list: typing.List[str],
                               expected_length: int) -> typing.List[np.ndarray]:
        """Feature extraction for image list.

        Cached vectors are loaded from ``cache_dir``; other images are
        preprocessed and encoded on the fly.
        """
        image_vectors = []
        for query_image in image_path_list:
            file_name = os.path.basename(query_image)
            if file_name in self.vector_cache:
                query_vector = np.load(os.path.join(self.cache_dir, f"{file_name}.npy"))
            else:
                preprocessed_image = self._preprocess_image(query_image)
                # NOTE(review): encode() receives a single image here but a
                # stacked batch in encode_images -- confirm subclasses accept both.
                query_vector = np.squeeze(self.encode(preprocessed_image))
            image_vectors.append(query_vector)
        assert len(image_vectors) == expected_length
        return image_vectors

    @classmethod
    def get_k_nearest_neighbors(cls, query_vectors, answer_vectors, k) -> typing.Tuple[np.ndarray, np.ndarray]:
        """Gets numpy arrays representing distance and number of k nearest answer vectors to query vectors"""
        start = time.time()
        index = faiss.IndexFlatL2(query_vectors.shape[-1])  # build the index
        index.add(np.stack(answer_vectors))
        distance, indices = index.search(query_vectors, k)
        print(f"Neighbour search took {time.time() - start}")
        return distance, indices

    def question1(self,
                  query_image_paths: typing.List[str],
                  answer_image_paths: typing.List[str],
                  ):
        raise NotImplementedError("Implement in subclass.")

    def question2(self,
                  query_image_paths: typing.List[str],
                  answer_image_paths: typing.List[str],
                  ):
        raise NotImplementedError("Implement in subclass.")

    def question3(self,
                  query_image_paths1: typing.List[str],
                  query_image_paths2: typing.List[str],
                  answer_image_paths: typing.List[str],
                  ):
        raise NotImplementedError("Implement in subclass.")

    def question4(self,
                  query_image_paths: typing.List[str],
                  answer_image_paths: typing.List[str],
                  ):
        raise NotImplementedError("Implement in subclass.")

    def group2(self,
               query_image_paths: typing.List[str],
               answer_image_paths: typing.List[str],
               ):
        raise NotImplementedError("Implement in subclass.")
Oops, something went wrong.