-
Notifications
You must be signed in to change notification settings - Fork 0
/
environment.py
33 lines (22 loc) · 915 Bytes
/
environment.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from environementModel import EnvironmentModel
import numpy as np
class Environment(EnvironmentModel):
    """Episodic environment layered on top of an ``EnvironmentModel``.

    Each episode starts from a state sampled from ``pi`` and is flagged as
    done once ``max_steps`` transitions have been taken since the last
    ``reset``. Next states and rewards are produced by ``self.draw``
    (presumably provided by ``EnvironmentModel`` — confirm in the base class).

    Args:
        n_states: Number of states; forwarded to ``EnvironmentModel``.
        n_actions: Number of actions; forwarded to ``EnvironmentModel``.
        max_steps: Episode horizon; ``step`` reports ``done`` once this many
            steps have been taken since the last ``reset``.
        pi: Initial-state probabilities, one per state. ``None`` (the
            default) means a uniform distribution over all ``n_states``.
        seed: Forwarded to ``EnvironmentModel``; presumably used there to
            build ``self.random_state`` — TODO confirm.
    """

    def __init__(self, n_states, n_actions, max_steps, pi=None, seed=None):
        super().__init__(n_states, n_actions, seed)
        self.max_steps = max_steps
        # No initial distribution supplied: start uniformly at random.
        self.pi = np.full(n_states, 1.0 / n_states) if pi is None else pi

    def reset(self):
        """Begin a new episode and return its initial state.

        Clears the step counter and samples the initial state from
        ``self.pi`` using ``self.random_state``.
        """
        self.n_steps = 0
        self.state = self.random_state.choice(self.n_states, p=self.pi)
        return self.state

    def step(self, action):
        """Apply ``action`` and return ``(next_state, reward, done)``.

        ``done`` becomes True once ``max_steps`` steps have been taken since
        the last ``reset``. Stepping past that point is not prevented.

        Raises:
            ValueError: If ``action`` is outside ``[0, n_actions)``.
            RuntimeError: If called before the first ``reset``.
        """
        if not 0 <= action < self.n_actions:
            # ValueError subclasses Exception, so callers that previously
            # caught the bare Exception still catch this.
            raise ValueError('Invalid action.')
        if getattr(self, 'n_steps', None) is None:
            # n_steps and state are only created by reset(); fail clearly
            # instead of with an AttributeError.
            raise RuntimeError('Call reset() before step().')
        self.n_steps += 1
        done = self.n_steps >= self.max_steps
        self.state, reward = self.draw(self.state, action)
        return self.state, reward, done

    def render(self, policy=None, value=None):
        """Render the environment; not implemented for this class."""
        raise NotImplementedError()