-
Notifications
You must be signed in to change notification settings - Fork 0
/
show_classifier_graphs.py
87 lines (67 loc) · 2.77 KB
/
show_classifier_graphs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# -*- coding: utf-8 -*-
"""
Created on Wed Nov 9 17:03:26 2022
@author: Antonio Vispi
"""
import argparse
import os
import json
import matplotlib.pyplot as plt
import numpy as np
def load_json_arr(json_path):
    """Load a JSON Lines file into a list of decoded records.

    Every non-blank line of the file is decoded independently with
    ``json.loads``. Blank or whitespace-only lines (for example a trailing
    empty line at the end of the log) are skipped instead of aborting the
    whole load with a decode error.

    Args:
        json_path: Path of the file to read; one JSON document per line.

    Returns:
        A list with the decoded value of each non-blank line, in file order.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Read explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open(json_path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]
def _metric_series(records, mode, key):
    """Return the values stored under *key* in every record whose 'mode' is *mode*."""
    return [
        rec[key]
        for rec in records
        # .get() tolerates records that carry no 'mode' entry at all.
        if isinstance(rec, dict) and rec.get('mode') == mode and key in rec
    ]


def _epoch_axis(n_points, n_epochs):
    """Return *n_points* abscissa values spread evenly over [1, n_epochs]."""
    return np.linspace(1, n_epochs, num=n_points)


def training_graphs(path_in_json, path_out):
    """Plot training/validation loss and accuracy curves and save them as a PNG.

    The log at *path_in_json* is read with ``load_json_arr``. From every
    record after the first one, the function collects:

    * ``'loss'`` of records with ``mode == 'val'`` and ``mode == 'train'``;
    * ``'accuracy_top-1'`` of ``'val'`` records and ``'top-1'`` of
      ``'train'`` records.

    The two loss curves and the two accuracy curves are drawn side by side
    and written to ``Graphs.png`` inside *path_out*, which is created if it
    does not exist. A confirmation message is printed afterwards.

    Args:
        path_in_json: Path of the line-delimited JSON training log.
        path_out: Directory in which ``Graphs.png`` is saved.

    Raises:
        OSError: If the log cannot be read or the output cannot be written.
        json.JSONDecodeError: If the log contains an invalid JSON line.
    """
    os.makedirs(path_out, exist_ok=True)

    # The first record is deliberately skipped, as before; presumably it is a
    # header/environment line rather than a metrics sample -- TODO confirm
    # against the tool that writes the log.
    records = load_json_arr(path_in_json)[1:]

    val_losses = _metric_series(records, 'val', 'loss')
    train_losses = _metric_series(records, 'train', 'loss')
    val_accuracy = _metric_series(records, 'val', 'accuracy_top-1')
    train_accuracy = _metric_series(records, 'train', 'top-1')

    # Definition of the abscissa axes. The epoch count is taken from the
    # number of validation-loss samples; the fallbacks only apply when the
    # log holds no validation loss, a case that used to crash.
    n_epochs = (
        len(val_losses)
        or len(val_accuracy)
        or max(len(train_losses), len(train_accuracy), 1)
    )

    fig, axs = plt.subplots(1, 2, figsize=(20, 15))
    try:
        # Set after the figure is created, exactly as before; note that this
        # mutates Matplotlib's global rcParams.
        plt.rcParams['font.size'] = 18
        loss_ax, acc_ax = axs

        # --- Losses -------------------------------------------------------
        loss_top = max(val_losses + train_losses, default=0)
        if loss_top > 0:
            # Aim for ~50 ticks. Rounding to 2 decimals yields 0.0 for small
            # losses (< 0.25), and a zero step makes np.arange fail, so fall
            # back to the unrounded step in that case.
            step = round(loss_top / 50, 2) or loss_top / 50
            loss_ax.set_yticks(np.arange(0, loss_top, step))
        # Each series gets an axis of its own length so that series of
        # different lengths can still be drawn.
        loss_ax.plot(
            _epoch_axis(len(val_losses), n_epochs), val_losses,
            _epoch_axis(len(train_losses), n_epochs), train_losses,
        )
        loss_ax.set_title('Model Losses')
        loss_ax.set(xlabel='Epoch')
        loss_ax.set(ylabel='Loss')
        loss_ax.legend(('Validation', 'Training'), loc='upper right', shadow=True)

        # --- Accuracy -----------------------------------------------------
        acc_top = max(val_accuracy + train_accuracy, default=0)
        if acc_top > 0:
            acc_ax.set_yticks(np.arange(0, acc_top, 5))
        acc_ax.plot(
            _epoch_axis(len(val_accuracy), n_epochs), val_accuracy,
            _epoch_axis(len(train_accuracy), n_epochs), train_accuracy,
        )
        acc_ax.set_title('Model Accuracy')
        acc_ax.set(xlabel='Epoch')
        acc_ax.set(ylabel='Accuracy(%)')
        acc_ax.legend(('Validation', 'Training'), loc='lower right', shadow=True)

        fig.savefig(os.path.join(path_out, 'Graphs.png'))
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)

    print('\n')
    print('The graph has been saved in the destination path.')
def main(argv=None):
    """Parse the command-line options and generate the training graphs.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            the arguments are taken from ``sys.argv``, which is what happens
            when the script is run from the command line.

    Both ``--path_in_json`` and ``--path_out`` are required: when either is
    missing argparse prints a usage message and exits, instead of ``None``
    being passed on to ``training_graphs``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--path_in_json', required=True, help='Enter the path of the .json file containing all the annotations of the model training data')
    parser.add_argument('--path_out', required=True, help='The path in which the graph summarizing the training data will be saved.')
    args = parser.parse_args(argv)
    training_graphs(args.path_in_json, args.path_out)


# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()