-
Notifications
You must be signed in to change notification settings - Fork 5
/
visuals.py
207 lines (158 loc) · 7.53 KB
/
visuals.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
###########################################
# Suppress matplotlib user warnings
# Necessary for newer version of matplotlib
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
###########################################
#
# Display inline matplotlib plots with IPython.
# get_ipython() returns None when no IPython shell is running (e.g. when this
# module is imported from a plain Python interpreter), so only enable the
# inline backend when a shell actually exists instead of crashing on import.
from IPython import get_ipython
_ipython_shell = get_ipython()
if _ipython_shell is not None:
    _ipython_shell.run_line_magic('matplotlib', 'inline')
###########################################
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os
import ast
def calculate_safety(data):
    """Calculate the safety rating of the smartcab during testing.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per testing trial. Reads the 'good_actions',
        'initial_deadline', 'final_deadline' and 'actions' columns. Each
        'actions' entry is a literal string (parsed with ``ast.literal_eval``)
        indexed 0..4 as: good actions, minor violations, major violations,
        minor accidents, major accidents.

    Returns
    -------
    tuple of (str, str)
        ``(grade, colour)``:
        - "A+" when every action taken was good.
        - "F" for any major accident.
        - "D" for any minor accident.
        - "C" for any major violation.
        - Otherwise "B" or "A", depending on the total number of minor
          violations relative to the number of trials.
    """
    steps = (data['initial_deadline'] - data['final_deadline']).sum()
    good_ratio = data['good_actions'].sum() * 1.0 / steps
    if good_ratio == 1:  # Perfect driving
        return ("A+", "green")

    # Imperfect driving: parse each row's action counts once rather than
    # re-running literal_eval for every category checked below.
    counts = data['actions'].apply(ast.literal_eval)

    def total(index):
        """Sum one action category (see indices above) across all trials."""
        return counts.apply(lambda c: c[index]).sum()

    if total(4) > 0:  # Major accident
        return ("F", "red")
    if total(3) > 0:  # Minor accident
        return ("D", "#EEC700")
    if total(2) > 0:  # Major violation
        return ("C", "#EEC700")
    # Only minor violations remain. Compare 2*minor >= n rather than
    # minor >= n/2: under Python 2 the latter floors n/2, so e.g. 2 minor
    # violations over 5 trials would wrongly count as "at least half".
    if 2 * total(1) >= len(data):
        return ("B", "green")
    return ("A", "green")
def calculate_reliability(data):
    """Grade how often the smartcab reached its destination during testing.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per testing trial. The 'success' column is summed and divided
        by the number of rows to obtain the success ratio.

    Returns
    -------
    tuple of (str, str)
        ``(grade, colour)``. "A+" when the ratio is exactly 1. Otherwise the
        first cut-off in the table below that the ratio reaches, or "F" when
        it reaches none of them.
    """
    # Scale by 1.0 so the ratio is a float even with Python 2 integer division.
    ratio = 1.0 * data['success'].sum() / len(data)
    if ratio == 1:  # Every trial met its deadline
        return ("A+", "green")

    # (minimum ratio, grade, colour), checked from the highest cut-off down.
    cutoffs = (
        (0.90, "A", "green"),
        (0.80, "B", "green"),
        (0.70, "C", "#EEC700"),
        (0.60, "D", "#EEC700"),
    )
    for minimum, grade, colour in cutoffs:
        if ratio >= minimum:
            return (grade, colour)
    return ("F", "red")
def _add_rolling_features(data, window):
    """Add, in place, the derived columns that plot_trials draws.

    Each 'actions' entry is parsed once with ``ast.literal_eval``. Indices
    0..4 hold the counts of good actions, minor violations, major violations,
    minor accidents and major accidents. Each 'parameters' entry is parsed
    once and its 'e' (exploration) and 'a' (learning) values are extracted.
    """
    steps = data['initial_deadline'] - data['final_deadline']  # actions taken per trial

    def rolling(series):
        """Rolling mean over the last `window` trials (NaN until filled)."""
        return series.rolling(window=window).mean()

    data['average_reward'] = rolling(data['net_reward'] / steps)
    # Rolling percentage of trials that reached the destination in time
    data['reliability_rate'] = rolling(data['success'] * 100)

    counts = data['actions'].apply(ast.literal_eval)
    data['good_actions'] = counts.apply(lambda c: c[0])
    # Rolling fraction of each action category per action taken
    for index, column in enumerate(('good', 'minor', 'major', 'minor_acc', 'major_acc')):
        data[column] = rolling(counts.apply(lambda c, i=index: c[i]) * 1.0 / steps)

    params = data['parameters'].apply(ast.literal_eval)
    data['epsilon'] = params.apply(lambda p: p['e'])
    data['alpha'] = params.apply(lambda p: p['a'])


def plot_trials(csv):
    """Plot the data from logged metrics during a simulation.

    Parameters
    ----------
    csv : str
        Log file name, resolved relative to the ``logs`` directory. The name
        'sim_no-learning.csv' is treated as a run with learning disabled, so
        the parameter panel is replaced by a note.

    Draws one figure showing:
    - rolling bad-action frequencies;
    - rolling reward per action;
    - the epsilon/alpha parameters;
    - the rolling reliability rate;
    - the safety/reliability ratings of the testing trials, if any.

    It then calls ``plt.show()``. If the log has fewer trials than the
    rolling window (10), it prints a message and returns without plotting.
    """
    window = 10  # length, in trials, of every rolling average
    data = pd.read_csv(os.path.join("logs", csv))
    if len(data) < window:
        # The rolling averages need at least `window` trials to yield values.
        print("Not enough data collected to create a visualization.")
        print("At least {} trials are required.".format(window))
        return

    # Create additional features
    _add_rolling_features(data, window)

    # Create training and testing subsets
    training_data = data[data['testing'] == False]
    testing_data = data[data['testing'] == True]

    plt.figure(figsize=(12, 8))

    ###############
    ### Average step reward plot
    ###############
    ax = plt.subplot2grid((6, 6), (0, 3), colspan=3, rowspan=2)
    ax.set_title("{}-Trial Rolling Average Reward per Action".format(window))
    ax.set_ylabel("Reward per Action")
    ax.set_xlabel("Trial Number")
    ax.set_xlim((window, len(training_data)))
    step = training_data[['trial', 'average_reward']].dropna()
    ax.axhline(xmin=0, xmax=1, y=0, color='black', linestyle='dashed')
    ax.plot(step['trial'], step['average_reward'])

    ###############
    ### Parameters Plot
    ###############
    ax = plt.subplot2grid((6, 6), (2, 3), colspan=3, rowspan=2)
    # Check whether the agent was expected to learn
    if csv != 'sim_no-learning.csv':
        ax.set_ylabel("Parameter Value")
        ax.set_xlabel("Trial Number")
        ax.set_xlim((1, len(training_data)))
        ax.set_ylim((0, 1.05))
        ax.plot(training_data['trial'], training_data['epsilon'], color='blue', label='Exploration factor')
        ax.plot(training_data['trial'], training_data['alpha'], color='green', label='Learning factor')
        ax.legend(bbox_to_anchor=(0.5, 1.19), fancybox=True, ncol=2, loc='upper center', fontsize=10)
    else:
        ax.axis('off')
        ax.text(0.52, 0.30, "Simulation completed\nwith learning disabled.", fontsize=24, ha='center', style='italic')

    ###############
    ### Bad Actions Plot
    ###############
    actions = training_data[['trial', 'good', 'minor', 'major', 'minor_acc', 'major_acc']].dropna()
    bad = 1 - actions['good']
    # With fewer than `window` training trials every rolling value is NaN, so
    # `actions` is empty and has no maximum; fall back to 0 instead of raising.
    maximum = bad.values.max() if len(bad) else 0.0
    ax = plt.subplot2grid((6, 6), (0, 0), colspan=3, rowspan=4)
    ax.set_title("{}-Trial Rolling Relative Frequency of Bad Actions".format(window))
    ax.set_ylabel("Relative Frequency")
    ax.set_xlabel("Trial Number")
    ax.set_ylim((0, maximum + 0.01))
    ax.set_xlim((window, len(training_data)))
    ax.set_yticks(np.linspace(0, maximum + 0.01, 10))
    ax.plot(actions['trial'], bad, color='black', label='Total Bad Actions', linestyle='dotted', linewidth=3)
    ax.plot(actions['trial'], actions['minor'], color='orange', label='Minor Violation', linestyle='dashed')
    ax.plot(actions['trial'], actions['major'], color='orange', label='Major Violation', linewidth=2)
    ax.plot(actions['trial'], actions['minor_acc'], color='red', label='Minor Accident', linestyle='dashed')
    ax.plot(actions['trial'], actions['major_acc'], color='red', label='Major Accident', linewidth=2)
    ax.legend(loc='upper right', fancybox=True, fontsize=10)

    ###############
    ### Rolling Success-Rate plot
    ###############
    ax = plt.subplot2grid((6, 6), (4, 0), colspan=4, rowspan=2)
    ax.set_title("{}-Trial Rolling Rate of Reliability".format(window))
    ax.set_ylabel("Rate of Reliability")
    ax.set_xlabel("Trial Number")
    ax.set_xlim((window, len(training_data)))
    ax.set_ylim((-5, 105))
    ax.set_yticks(np.arange(0, 101, 20))
    ax.set_yticklabels(['0%', '20%', '40%', '60%', '80%', '100%'])
    # Drop NaNs only in the columns this panel plots, so gaps in unrelated
    # columns do not remove reliability points (matches the other panels).
    reliability = training_data[['trial', 'reliability_rate']].dropna()
    ax.plot(reliability['trial'], reliability['reliability_rate'], label="Reliability Rate", color='blue')

    ###############
    ### Test results
    ###############
    ax = plt.subplot2grid((6, 6), (4, 4), colspan=2, rowspan=2)
    ax.axis('off')
    if len(testing_data) > 0:
        safety_rating, safety_color = calculate_safety(testing_data)
        reliability_rating, reliability_color = calculate_reliability(testing_data)
        # Write success rate
        ax.text(0.40, .9, "{} testing trials simulated.".format(len(testing_data)), fontsize=14, ha='center')
        ax.text(0.40, 0.7, "Safety Rating:", fontsize=16, ha='center')
        ax.text(0.40, 0.42, "{}".format(safety_rating), fontsize=40, ha='center', color=safety_color)
        ax.text(0.40, 0.27, "Reliability Rating:", fontsize=16, ha='center')
        ax.text(0.40, 0, "{}".format(reliability_rating), fontsize=40, ha='center', color=reliability_color)
    else:
        ax.text(0.36, 0.30, "Simulation completed\nwith testing disabled.", fontsize=20, ha='center', style='italic')

    plt.tight_layout()
    plt.show()