-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_mean_reward_my.py
34 lines (28 loc) · 1.16 KB
/
get_mean_reward_my.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator
import numpy as np
import os
def get_mean_last10(tfevent_file, teacher=True):
    """Summarize the test-reward curve stored in a TensorBoard event file.

    Args:
        tfevent_file: Path handed to ``EventAccumulator`` to locate the
            event data.
        teacher: When true, read the ``'test/rew'`` scalar tag; otherwise
            read ``'test_student/rew'``.

    Returns:
        A ``(mean_last10, best)`` tuple: the mean of the last ten logged
        values (of all values when fewer than ten exist) and the maximum
        over every logged value.

    Raises:
        KeyError: Propagated from the accumulator when the selected tag was
            never logged in the event file.
        ValueError: If the tag exists but holds no scalar values.
    """
    tag = 'test/rew' if teacher else 'test_student/rew'

    # The default size guidance keeps at most 10000 scalars per tag by
    # reservoir sampling, so on long runs the "last 10" and the maximum would
    # be computed over a subsample.  A guidance of 0 keeps every event.
    ea = event_accumulator.EventAccumulator(
        tfevent_file,
        size_guidance={event_accumulator.SCALARS: 0},
    )
    ea.Reload()

    rewards = [event.value for event in ea.Scalars(tag)]
    if not rewards:
        # Fail with an actionable message instead of a NumPy RuntimeWarning
        # followed by "max() arg is an empty sequence".
        raise ValueError(
            'no scalar values logged under tag {!r} in {}'.format(
                tag, tfevent_file))

    return np.mean(rewards[-10:]), max(rewards)
# Root directory that is searched recursively for TensorBoard event files.
dir_path = '/home/zhoufu/drl_iot/RL_IoT_distillation/advance/log'

# Maps an environment name to the path of a matching event file.  When several
# files map to the same name, the one visited last by os.walk wins.
logs = {}
for root, _subdirs, filenames in os.walk(dir_path):
    for filename in filenames:
        file_path = os.path.join(root, filename)
        # Keep only paths that mention both 'events' and 'net1'.
        if 'events' not in file_path or 'net1' not in file_path:
            continue
        print(file_path)
        # The environment name is taken from the fourth path component
        # counted from the end of the '/'-separated path.
        env_name = file_path.split('/')[-4]
        logs[env_name] = file_path

# Report the teacher statistics for every environment first, then the student
# statistics, reading the corresponding reward tag from the same event files.
for header, is_teacher in (('Teacher: ', True), ('Student: ', False)):
    print(header)
    for name, path in logs.items():
        mean_rew, max_rew = get_mean_last10(path, teacher=is_teacher)
        print('env_name:{}, mean reward:{}, max:{}'.format(name, mean_rew, max_rew))