# uaas.py
import torch


def compute_policy_loss_reinforce(logps, returns):
    """
    Computes the policy loss for the REINFORCE algorithm. See 4.2 of the
    lecture notes.
    logps: log probabilities for each time step. Shape: (T,)
    returns: total return for each time step. Shape: (T,)
    ----
    return: policy loss summed over all timesteps. torch.float scalar
    """
    policy_loss = torch.tensor(0.0)
    #### TODO: complete policy loss (10 pts) ###
    # HINT: Recall that we want to perform gradient ASCENT to maximize returns,
    # so we minimize the negative return-weighted log-likelihood.
    policy_loss = -torch.sum(logps * returns)
    ############################################
    return policy_loss
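

# Illustrative sanity check (hypothetical numbers, not part of the assignment):
# with logps = log([0.5, 0.25]) and returns = [1.0, 2.0], the loss above is
# -(log(0.5) * 1.0 + log(0.25) * 2.0) ≈ 3.47; minimizing it with gradient descent
# is gradient ascent on the return-weighted log-likelihood of the taken actions.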


def compute_policy_loss_with_baseline(logps, advantages):
    """
    Computes the policy loss with a baseline term. Refer to 4.3 in the lecture
    notes.
    logps: computed log probabilities. Shape: (T,)
    advantages: computed advantages. Shape: (T,)
    ---
    return: policy loss computed with the baseline term (refer to 4.3 - Baseline
    in the lecture notes). torch.float scalar
    """
    policy_loss = 0
    ### TODO: implement the policy loss (5 pts) ##############
    # The baseline variant reuses the REINFORCE sum, weighting each log
    # probability by the advantage instead of the raw return.
    policy_loss = compute_policy_loss_reinforce(logps, advantages)
    ##################################################
    return policy_loss
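

# Note: the advantages are assumed to be precomputed without gradients (as in
# the experience buffers used below), so the baseline only reduces the variance
# of the gradient estimate; the policy gradient is unchanged in expectation.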


class UAASParameterUpdate:
    def __init__(self, alpha, epsilon):
        self.q_j = 0
        self.j = 1
        self.alpha = alpha
        self.epsilon = epsilon

    def step_size(self):
        # Step size j^(-1/2 + epsilon) for the running quantile update
        # (decays over time when epsilon < 0.5).
        return self.j ** (-0.5 + self.epsilon)

    def __call__(self, optimizer, acmodel, sb, args):
        """
        optimizer: Optimizer used to perform gradient updates to the model. torch.optim.Optimizer
        acmodel: Network used to compute the policy. torch.nn.Module
        sb: stores experience data. Refer to "collect_experiences". dict
        args: Config arguments. Config
        return: output logs. dict
        """
        dist, vals = acmodel(sb["obs"])
        logps = dist.log_prob(sb["action"])
        val_nograd = sb["value"]
        reward = sb["discounted_reward"]
        # Bootstrapped one-step target: r_t + gamma * V(s_{t+1}), with V = 0
        # after the final step.
        val_t1 = torch.roll(val_nograd, shifts=-1, dims=0)
        val_t1[-1] = 0
        reduced_reward = sb["reward"] + args.discount * val_t1
        # Squared error between the value estimate and the discounted return,
        # used as the per-timestep score for the quantile test.
        score = (val_nograd - reward) * (val_nograd - reward)
        indices = []
        for x in score[1:]:
            s = x.item()
            self.j += 1
            # Running quantile-style update of the threshold q_j: in expectation
            # it settles where P(score >= q_j) ≈ alpha.
            self.q_j += self.step_size() * ((1 if self.q_j <= s else 0) - self.alpha)
            indices.append(1 if self.q_j <= s else 0)
        indices.append(0)
        # Per-timestep target selection: keep the Monte Carlo return (index 0)
        # or switch to the bootstrapped target (index 1).
        mask = torch.tensor(indices, dtype=torch.bool)
        reward_prime = torch.where(mask, reduced_reward, reward)
        advantage = reward_prime - val_nograd
        # Computes the policy loss
        policy_loss = compute_policy_loss_with_baseline(logps, advantage)
        update_policy_loss = policy_loss.item()
        value_loss = torch.norm(reward - vals, p=2)
        update_value_loss = value_loss.item()
        loss = value_loss + policy_loss
        # Update actor-critic
        optimizer.zero_grad()
        loss.backward()
        # Gradient norm, logged for monitoring
        update_grad_norm = (
            sum(p.grad.data.norm(2) ** 2 for p in acmodel.parameters()) ** 0.5
        )
        # Perform gradient clipping for stability
        torch.nn.utils.clip_grad_norm_(acmodel.parameters(), args.max_grad_norm)
        optimizer.step()
        # Log some values
        logs = {
            "policy_loss": update_policy_loss,
            "grad_norm": update_grad_norm,
            "value_loss": update_value_loss,
        }
        return logs
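

# Hypothetical usage sketch (the optimizer, model, experience dict `sb`, and
# `args` below are placeholders assumed to be provided by the surrounding
# training loop; they are not defined in this file):
#
#     update_fn = UAASParameterUpdate(alpha=0.1, epsilon=0.01)
#     for _ in range(num_updates):
#         sb = collect_experiences(...)   # obs, action, value, reward, ...
#         logs = update_fn(optimizer, acmodel, sb, args)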


def update_parameters_with_baseline(optimizer, acmodel, sb, args):
    """
    Updates model parameters using the value and policy functions.
    optimizer: Optimizer used to perform gradient updates to the model. torch.optim.Optimizer
    acmodel: Network used to compute the policy. torch.nn.Module
    sb: stores experience data. Refer to "collect_experiences". dict
    args: Config arguments
    """

    def _compute_value_loss(values, returns):
        """
        Computes the value loss of the critic model. See 4.3 of the lecture notes.
        values: computed values from the critic model. Shape: (T,)
        returns: discounted rewards. Shape: (T,)
        ---
        return: value loss (see 4.3, eq. 11 in the lecture notes). torch.float scalar
        """
        value_loss = 0
        ### TODO: implement the value loss (5 pts) ###############
        value_loss = torch.norm(returns - values, p=2)
        ##################################################
        return value_loss

    logps, advantage, values, reward = None, None, None, None
    dist, values = acmodel(sb["obs"])
    logps = dist.log_prob(sb["action"])
    advantage = sb["advantage_gae"] if args.use_gae else sb["advantage"]
    reward = sb["discounted_reward"]
    policy_loss = compute_policy_loss_with_baseline(logps, advantage)
    value_loss = _compute_value_loss(values, reward)
    loss = policy_loss + value_loss
    update_policy_loss = policy_loss.item()
    update_value_loss = value_loss.item()
    # Update actor-critic
    optimizer.zero_grad()
    loss.backward()
    # Gradient norm, logged for monitoring
    update_grad_norm = (
        sum(p.grad.data.norm(2) ** 2 for p in acmodel.parameters()) ** 0.5
    )
    # Perform gradient clipping for stability
    torch.nn.utils.clip_grad_norm_(acmodel.parameters(), args.max_grad_norm)
    optimizer.step()
    # Log some values
    logs = {
        "policy_loss": update_policy_loss,
        "value_loss": update_value_loss,
        "grad_norm": update_grad_norm,
    }
    return logs
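

# Note: the advantages used above are precomputed in "collect_experiences";
# args.use_gae selects between the GAE estimate (sb["advantage_gae"]) and the
# simpler advantage estimate stored in sb["advantage"].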


def update_parameters_reinforce(optimizer, acmodel, sb, args):
    """
    optimizer: Optimizer used to perform gradient updates to the model. torch.optim.Optimizer
    acmodel: Network used to compute the policy. torch.nn.Module
    sb: stores experience data. Refer to "collect_experiences". dict
    args: Config arguments. Config
    return: output logs. dict
    """
    # logps is the log probability of the taken action at each time step. Shape: (T,)
    logps, reward = None, None
    ### TODO: compute logps and reward from acmodel, sb['obs'], sb['action'], and sb['reward'] ###
    ### If args.use_discounted_reward is True, use sb['discounted_reward'] instead. ##############
    ### (10 pts) #########################################
    dist, _ = acmodel(sb["obs"])
    logps = dist.log_prob(sb["action"])
    reward = sb["discounted_reward"] if args.use_discounted_reward else sb["reward"]
    # Standardize the returns (a common variance-reduction trick).
    reward = (reward - reward.mean()) / (reward.std() + 1e-10)
    ##############################################################################################
    # Computes the policy loss
    policy_loss = compute_policy_loss_reinforce(logps, reward)
    # Update the policy network (plain REINFORCE has no critic term)
    optimizer.zero_grad()
    policy_loss.backward()
    optimizer.step()
    # Log some values
    logs = {"policy_loss": policy_loss.item()}
    return logs
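

if __name__ == "__main__":
    # Minimal smoke test on random tensors (illustrative only; it does not touch
    # a real environment, model, or optimizer).
    T = 8
    logps = torch.log(torch.rand(T))
    returns = torch.rand(T)
    print("REINFORCE loss:", compute_policy_loss_reinforce(logps, returns).item())
    print(
        "baseline loss:",
        compute_policy_loss_with_baseline(logps, returns - returns.mean()).item(),
    )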