Commit

pull request (#7)

* test

* pretrained=false

* main_supcon.py add wandb

* question loader

* utils

* fscore

* calculator.py

* dataset-statis.py

* mains

* gitignore

* script

* encoder_models/

* ignore

* Delete main_f1score.py

delete

* change

* Delete train_supcon_50p_group4.sh

train_supcon_50p_group4 delete

* gitignore

* calculator remove comment

* compare better way

* file_name not used

* f1_graph remove

* question loader try remove

* inference group4

* main_supcon

* resolve epsilon

* remove continuing try/except block

* remove print statement

* remove try/except for removing damaged jpg, jpeg

* jpeg jpg include

* remove resnet from main_supcon

* spaces after each comma inference_g4.py

* delete unused comment

* multiple line violation

* remove unused import

* remove unused try/except in question_loader
SlowMonk authored Nov 21, 2022
1 parent 402cb2d commit 898ab02
Showing 17 changed files with 1,290 additions and 28 deletions.
16 changes: 15 additions & 1 deletion .gitignore
@@ -130,4 +130,18 @@ dmypy.json

# Pyre type checker
.pyre/
save/
save/
wandb/
#encoder_models/
score_utils/

# files
./notebooks/
./sh_files/
./make_resize.py
notebooks/
sh_files/
utils/
main_f1score.py
train_supcon50_group4.sh
inference_group4.sh
118 changes: 118 additions & 0 deletions calculator.py
@@ -0,0 +1,118 @@

import typing
import sys


def increment(*args) -> typing.List[int]:
    """Increment all given numbers."""
    return [arg + 1 for arg in args]


def print_results(TTP: int, TFP: int, TFN: int, result_dict: typing.Dict) -> typing.Tuple[float, float]:
    """Print per-category and aggregate F-scores."""
    fscores = []
    for k, v in result_dict.items():
        TP = v['TP']
        FP = v['FP']
        FN = v['FN']
        try:
            precision = TP / (TP + FP)
            recall = TP / (TP + FN)
            f = 2 * ((precision * recall) / (precision + recall))
        except ZeroDivisionError:
            f = 0
        print(f'{k}: {f}')
        fscores.append(f)
    print(f"Macro f score: {sum(fscores) / len(fscores)}")

    precision = TTP / (TTP + TFP)
    recall = TTP / (TTP + TFN)

    f = 2 * ((precision * recall) / (precision + recall))
    print(f'Micro f score {f}')
    return sum(fscores) / len(fscores), f


def calculate1(answer_file: str) -> typing.Tuple[float, float]:
    """Calculate F-score for problem 1."""
    TTP, TFP, TFN = 0, 0, 0  # Total true positives, false positives, false negatives
    result_dict = {}
    header_skipped = False
    with open(answer_file) as fh:
        for line in fh:
            if not header_skipped:
                header_skipped = True
                continue
            _, cat, gt, ans = line.replace('\n', "").split(',')
            gt, ans = int(gt), int(bool(ans))
            print(f'gt->{gt}, ans->{ans}', gt == ans)
            if cat not in result_dict:
                result_dict[cat] = dict(TP=0, FN=0, FP=0)
            if ans == gt:
                result_dict[cat]['TP'], TTP = increment(result_dict[cat]['TP'], TTP)
            else:
                result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN = \
                    increment(result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN)
    return print_results(TTP, TFP, TFN, result_dict)


def calculate3(answer_file: str) -> typing.Tuple[float, float]:
    """Calculate F-score for problem 3."""
    TTP, TFP, TFN = 0, 0, 0  # Total true positives, false positives, false negatives
    result_dict = {}
    header_skipped = False
    with open(answer_file) as fh:
        for line in fh:
            if not header_skipped:
                header_skipped = True
                continue
            _, cat, qgt, agt, qans, ans = line.replace('\n', "").split(',')
            if cat not in result_dict:
                result_dict[cat] = dict(TP=0, FN=0, FP=0)
            if ans == agt:
                result_dict[cat]['TP'], TTP = increment(result_dict[cat]['TP'], TTP)
            else:
                result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN = \
                    increment(result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN)
    return print_results(TTP, TFP, TFN, result_dict)


def calculate4(answer_file: str) -> typing.Tuple[float, float]:
    """Calculate F-score for problem 4 (each answer scored independently)."""
    TTP, TFP, TFN = 0, 0, 0  # Total true positives, false positives, false negatives
    result_dict = {}
    header_skipped = False
    with open(answer_file) as fh:
        for line in fh:
            if not header_skipped:
                header_skipped = True
                continue
            _, cat, gt1, gt2, ans1, ans2 = line.replace('\n', "").split(',')
            if cat not in result_dict:
                result_dict[cat] = dict(TP=0, FN=0, FP=0)
            for ans in ans1, ans2:
                if ans in [gt1, gt2]:
                    result_dict[cat]['TP'], TTP = increment(result_dict[cat]['TP'], TTP)
                else:
                    result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN = \
                        increment(result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN)
    return print_results(TTP, TFP, TFN, result_dict)


def calculate4_both(answer_file: str) -> typing.Tuple[float, float]:
    """Calculate F-score for problem 4 (both answers must match the ground-truth pair)."""
    TTP, TFP, TFN = 0, 0, 0  # Total true positives, false positives, false negatives
    result_dict = {}
    header_skipped = False
    with open(answer_file) as fh:
        for line in fh:
            if not header_skipped:
                header_skipped = True
                continue
            _, cat, gt1, gt2, ans1, ans2 = line.replace('\n', "").split(',')
            if cat not in result_dict:
                result_dict[cat] = dict(TP=0, FN=0, FP=0)
            if sorted([ans1, ans2]) == sorted([gt1, gt2]):
                result_dict[cat]['TP'], TTP = increment(result_dict[cat]['TP'], TTP)
            else:
                result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN = \
                    increment(result_dict[cat]['FP'], TFP, result_dict[cat]['FN'], TFN)
    return print_results(TTP, TFP, TFN, result_dict)
46 changes: 46 additions & 0 deletions dataset_statistics.py
@@ -0,0 +1,46 @@
import argparse
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from tqdm import tqdm
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True


def main(image_dir: str):
    """Get mean and std."""
    dataloader = torch.utils.data.DataLoader(
        datasets.ImageFolder(root=image_dir, transform=transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])), batch_size=4
    )
    channels_sum, channels_squared_sum, num_batches = 0, 0, 0
    count = 0
    for data, _ in tqdm(dataloader):

        # Mean over batch, height and width, but not over the channels
        channels_sum += torch.mean(data, dim=[0, 2, 3])
        channels_squared_sum += torch.mean(data ** 2, dim=[0, 2, 3])
        num_batches += 1
        count += 1
        if count % 100 == 0:
            mean = channels_sum / num_batches
            std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5
            print(f'count->{count}, mean->{mean}, std->{std}')
    mean = channels_sum / num_batches

    # std = sqrt(E[X^2] - (E[X])^2)
    std = (channels_squared_sum / num_batches - mean ** 2) ** 0.5
    print(f"Mean: {mean}")
    print(f"Std: {std}")
    return mean, std


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_dir', type=str)
    args = parser.parse_args()
    main(
        image_dir=args.image_dir,
    )
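
The mean and std returned by dataset_statistics.py would typically feed a normalization transform in the training pipeline; a small sketch, assuming an image directory at data/train and a torchvision preprocessing pipeline (both assumptions, not part of this commit):

# Hypothetical follow-up; the directory path and the Normalize step are assumptions.
import torchvision.transforms as transforms
from dataset_statistics import main

mean, std = main(image_dir='data/train')  # per-channel statistics over resized 224x224 images

train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean.tolist(), std=std.tolist()),
])
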
Empty file added encoder_models/__init__.py
Empty file.
122 changes: 122 additions & 0 deletions encoder_models/base_model.py
@@ -0,0 +1,122 @@
import os.path
import shutil
import typing
import numpy as np
import faiss
import torch
import glob
from cached_property import cached_property
from PIL import Image
import time


class BaseModel:

    NAME = "Please override in subclass"
    CACHE_DIR = 'cache'

    @cached_property
    def transform(self):
        raise NotImplementedError("Data preprocessing must be implemented in the subclass.")

    def encode(self, preprocessed_image) -> typing.List[float]:
        raise NotImplementedError("Encoding must be implemented in the subclass")

    @property
    def cache_dir(self):
        return os.path.join(self.question_dir, self.CACHE_DIR)

    def __init__(self, weights_path: str = ""):
        self.weights_path = weights_path
        self.vector_cache = set()

    def _preprocess_image(self, image_path: str):
        """Applies transform to image."""
        image_tensor = Image.open(image_path).convert('RGB')
        return self.transform(image_tensor)

    def encode_images(self, batch_size: int = 128):
        """Encodes all question images in batches and caches the vectors as .npy files."""
        counter = 0
        if os.path.exists(self.cache_dir):
            shutil.rmtree(self.cache_dir)
        os.mkdir(self.cache_dir)
        image_vectors = []
        image_names = []
        pattern = f'{self.question_dir}/*/*.jp*'
        num_images = len(glob.glob(pattern))
        print(pattern, num_images)

        for image_path in glob.glob(pattern):
            file_name = os.path.basename(image_path)
            counter += 1
            if file_name in self.vector_cache:
                continue
            image_vectors.append(self._preprocess_image(image_path))
            image_names.append(file_name)
            self.vector_cache.add(file_name)
            if len(image_vectors) == batch_size:
                preprocessed_images = self.encode(torch.stack(image_vectors))
                for i in range(batch_size):
                    np.save(os.path.join(self.cache_dir, f"{image_names[i]}.npy"), preprocessed_images[i, :])
                image_vectors, image_names = [], []
                print(f"Preprocessed {counter}/{num_images} images.")
        # Encode and cache the final partial batch, if any.
        if image_vectors:
            preprocessed_images = self.encode(torch.stack(image_vectors))
            for i, image_name in enumerate(image_names):
                np.save(os.path.join(self.cache_dir, f"{image_name}.npy"), preprocessed_images[i, :])

    def _preprocess_image_list(self,
                               image_path_list: typing.List[str],
                               expected_length: int) -> typing.List[np.ndarray]:
        """Feature extraction for image list."""
        image_vectors = []
        for query_image in image_path_list:
            file_name = os.path.basename(query_image)
            if file_name in self.vector_cache:
                query_vector = np.load(os.path.join(self.cache_dir, f"{file_name}.npy"))
            else:
                preprocessed_image = self._preprocess_image(query_image)
                query_vector = np.squeeze(self.encode(preprocessed_image))
            image_vectors.append(query_vector)
        assert len(image_vectors) == expected_length
        return image_vectors

    @classmethod
    def get_k_nearest_neighbors(cls, query_vectors, answer_vectors, k) -> typing.Tuple[np.ndarray, np.ndarray]:
        """Returns distances and indices of the k nearest answer vectors for each query vector."""
        start = time.time()
        index = faiss.IndexFlatL2(query_vectors.shape[-1])  # build the index
        index.add(np.stack(answer_vectors))
        distance, indices = index.search(query_vectors, k)
        print(f"Neighbour search took {time.time() - start}")
        return distance, indices

    def question1(self,
                  query_image_paths: typing.List[str],
                  answer_image_paths: typing.List[str],
                  ):
        raise NotImplementedError("Implement in subclass.")

    def question2(self,
                  query_image_paths: typing.List[str],
                  answer_image_paths: typing.List[str],
                  ):
        raise NotImplementedError("Implement in subclass.")

    def question3(self,
                  query_image_paths1: typing.List[str],
                  query_image_paths2: typing.List[str],
                  answer_image_paths: typing.List[str],
                  ):
        raise NotImplementedError("Implement in subclass.")

    def question4(self,
                  query_image_paths: typing.List[str],
                  answer_image_paths: typing.List[str],
                  ):
        raise NotImplementedError("Implement in subclass.")

    def group2(self,
               query_image_paths: typing.List[str],
               answer_image_paths: typing.List[str],
               ):
        raise NotImplementedError("Implement in subclass.")
