sarsa.py
import numpy as np

from frozenlake import FrozenLake
from chooseAction import choose_action


def sarsa(env, max_episodes, eta, gamma, epsilon, seed=None):
    random_state = np.random.RandomState(seed)
    # Decay the learning rate linearly: eta[i] is the learning rate for episode i.
    eta = np.linspace(eta, 0, max_episodes)
    # Decay epsilon linearly: epsilon[i] is the exploration rate for episode i.
    epsilon = np.linspace(epsilon, 0, max_episodes)
    q = np.zeros((env.n_states, env.n_actions))
    for i in range(max_episodes):
        s = env.reset()
        # Select the first action with an epsilon-greedy policy.
        action = choose_action(epsilon[i], random_state, q[s])
        done = False
        # Iterate until a terminal state is reached.
        while not done:
            # Observe the reward and next state for the chosen action.
            next_state, reward, done = env.step(action)
            # Select the next action epsilon-greedily from the *next* state's values
            # (the original used q[s] here, which breaks the SARSA update).
            next_action = choose_action(epsilon[i], random_state, q[next_state])
            # SARSA update: move q[s, action] toward the one-step TD target.
            q[s, action] += eta[i] * (reward + gamma * q[next_state, next_action] - q[s, action])
            s = next_state
            action = next_action
    policy = q.argmax(axis=1)
    value = q.max(axis=1)
    return policy, value
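
For reference, the update implemented in the inner loop is the standard SARSA rule with a per-episode learning rate:

    q(s, a) <- q(s, a) + eta[i] * (r + gamma * q(s', a') - q(s, a))

where a' is itself chosen epsilon-greedily from q(s', .), which is what makes SARSA on-policy.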
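
The `choose_action` helper is imported from `chooseAction` but not shown in this file. A minimal sketch of an epsilon-greedy selector consistent with the call sites above (it receives the exploration rate, the random state, and one row of the Q-table); the body is an assumption, not the repository's actual implementation:

import numpy as np

def choose_action(epsilon, random_state, q_row):
    # Hypothetical sketch: the real chooseAction module is not shown here,
    # so this body is inferred from how the function is called in sarsa().
    if random_state.rand() < epsilon:
        # Explore: pick a uniformly random action.
        return random_state.randint(len(q_row))
    # Exploit: break ties uniformly among the highest-valued actions.
    best = np.flatnonzero(q_row == q_row.max())
    return random_state.choice(best)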
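
A hedged usage sketch: `frozenlake.py` is not part of this file, so the `FrozenLake` constructor arguments below (grid layout, slip probability, step limit, seed) are assumptions for illustration only.

# Hypothetical usage; the FrozenLake constructor signature is assumed.
lake = [['&', '.', '.', '.'],
        ['.', '#', '.', '#'],
        ['.', '.', '.', '#'],
        ['#', '.', '.', '$']]
env = FrozenLake(lake, slip=0.1, max_steps=16, seed=0)
policy, value = sarsa(env, max_episodes=2000, eta=0.5, gamma=0.9,
                      epsilon=0.5, seed=0)
print('Policy:', policy)
print('Value:', value)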