-
Notifications
You must be signed in to change notification settings - Fork 2
/
eval.py
executable file
·228 lines (184 loc) · 7.71 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
from __future__ import print_function
import argparse
import os
from models.model_transmil import TransMIL
# internal imports
from utils.file_utils import save_pkl
from utils.utils import *
from wsi_datasets.classification import WSI_Classification_Dataset
from data.cls_default import label_dicts
# pytorch imports
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import pandas as pd
import numpy as np
import json
from models.model_clam import CLAM_SB
from utils.core_utils import validate
torch.multiprocessing.set_sharing_strategy('file_system')
def seed_torch(seed=7):
"""
Set deterministic seed
Args:
seed (int): The seed to be set
Returns:
None
"""
import random
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if device.type == 'cuda':
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def read_splits(split_dir, fold_idx = None):
"""
Read the splits dataframes
Args:
split_dir (str): Path to splits directory
fold_idx (int): Read splits for which fold. If None, then not added to name
Returns:
splits: dictionary with splits
"""
splits = {}
for split in ['train', 'val', 'test']:
if fold_idx is not None:
split_path = os.path.join(split_dir, split + f'_{fold_idx}.csv')
else:
split_path = os.path.join(split_dir, split + '.csv')
if os.path.isfile(split_path):
df = pd.read_csv(split_path)
splits[split] = df
return splits
def build_datasets(splits, **kwargs):
"""
Build dataloaders for each split
Args:
splits (dict): dictionary with splits dfs
Returns:
splits: dataloaders for splits
"""
for k, df in splits.items():
dataset = WSI_Classification_Dataset(df, **kwargs)
dataloader = DataLoader(dataset, batch_size=1, shuffle = (k == 'train'), num_workers=2)
splits[k] = dataloader
print(f'split: {k}, n: {len(dataloader)}')
return splits
def merge_dict(main_dict, new_dict):
"""
Merges two dictionaries
Args:
main_dict (dict): The dictionary to which values will be appended.
new_dict (dict): The dictionary from which values will be retrieved and added
to main_dict.
Returns:
dict: The merged dictionary with values appended or added as necessary.
"""
for k, v in new_dict.items():
if k not in main_dict:
main_dict[k] = []
main_dict[k].append(v)
return main_dict
def main(args):
"""
Main function
Args:
args (Namespace): Experiment arguments
Returns:
None
"""
dataset_kwargs = dict(data_source = args.data_source,
label_map = args.label_map,
target_col = args.target_col,
label_map_race = args.label_map_race,
)
splits = read_splits(args.split_dir, fold_idx=None)
print('successfully read splits for: ', list(splits.keys()))
# reading all splits, but only want test as this is eval
loader = build_datasets(splits, **dataset_kwargs)['test']
print('\nInit Model...', end=' ')
model_dict = {"in_dim": args.in_dim,
'n_classes': args.n_classes}
if args.model_size is not None and args.model_type != 'mil':
model_dict.update({"size_arg": args.model_size})
if args.model_type in ['clam_sb', 'clam_mb']:
if args.model_type =='clam_sb':
model = CLAM_SB(**model_dict)
else:
raise NotImplementedError
elif args.model_type == "transmil":
model_dict = {"n_classes": args.n_classes, "in_dim": args.in_dim}
model = TransMIL(**model_dict)
model.cuda()
print_network(model)
if os.path.isdir(args.ckpt_path):
print('looking for ckpts in a directory!')
ckpts = [file for file in os.listdir(args.ckpt_path) if os.path.splitext(file)[-1] == '.pth']
ckpts = [os.path.join(args.ckpt_path, file) for file in ckpts]
else:
ckpts = [args.ckpt_path]
tags = [os.path.splitext(os.path.basename(ckpt))[0] for ckpt in ckpts]
all_test_results = {}
for ckpt, tag in zip(ckpts, tags):
print("Going for checkpoint {}".format(tag))
state_dict = torch.load(ckpt, map_location='cpu')['model']
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
print('missing keys: ', missing_keys)
print('unexpected keys: ', unexpected_keys)
test_results, test_dump = validate(model, loader, loss_fn = nn.CrossEntropyLoss(),
print_every = args.print_every, dump_results = True, eval_mode=True)
print('finished eval ckpt: ', os.path.basename(ckpt))
print(test_results)
all_test_results = merge_dict(all_test_results, test_results)
filename = os.path.join(args.results_dir, f'{tag}_results.pkl')
save_pkl(filename, test_dump)
final_dict = {}
final_dict = {'tags': tags}
final_dict.update({k + '_test':v for k,v in all_test_results.items()})
with open(os.path.join(args.results_dir, 'summary.json'), 'w') as f:
f.write(json.dumps(final_dict, sort_keys=True, indent=4))
final_df = pd.DataFrame(final_dict)
save_name = 'summary.csv'
final_df.to_csv(os.path.join(args.results_dir, save_name))
# Generic eval settings
print("parsing arguments")
parser = argparse.ArgumentParser(description='Configurations for WSI Eval')
parser.add_argument('--results_dir', default='./eval_results', help='results directory (default: ./results)')
parser.add_argument('--split_dir', type=str, default=None,
help='manually specify the set of splits to use')
parser.add_argument('--data_source', type=str, default=None,
help='')
parser.add_argument('--ckpt_path', type=str, default=None,
help='manually specify the path to model checkpoint')
parser.add_argument('--target_col', type=str, default='label')
parser.add_argument('--model_type', type=str,
help='type of model (default: clam_sb, clam w/ single attention branch)')
parser.add_argument('--exp_code', type=str, help='experiment code for saving results')
parser.add_argument('--model_size', type=str, choices=['small', 'big'], default='small', help='size of model')
parser.add_argument('--task', type=str)
parser.add_argument('--in_dim', default=1024, type=int, help='dim of input features')
parser.add_argument('--print_every', default=100, type=int, help='how often to print')
if __name__ == "__main__":
args = parser.parse_args()
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
args.label_map = label_dicts[args.task]
args.label_map_race = label_dicts["race_map"]
args.n_classes_race = len(set(list(args.label_map_race.values())))
print('label map race: ', args.label_map_race)
print('task: ', args.task)
print('label map: ', args.label_map)
args.n_classes = len(set(list(args.label_map.values())))
print('split_dir: ', args.split_dir)
args.results_dir = os.path.join(args.results_dir, str(args.exp_code))
os.makedirs(args.results_dir, exist_ok=True)
print("\n################### Settings ###################")
for key, val in vars(args).items():
print("{}: {}".format(key, val))
with open(os.path.join(args.results_dir, 'config.json'), 'w') as f:
f.write(json.dumps(vars(args), sort_keys=True, indent=4))
results = main(args)
print("finished!")