flickering wrapper for DRQN #307

Open · wants to merge 4 commits into base: dev
138 changes: 128 additions & 10 deletions hive/agents/drqn.py
@@ -58,6 +58,8 @@ def __init__(
device="cpu",
logger: Logger = None,
log_frequency: int = 100,
store_hidden: bool = True,
burn_frames: int = 0,
**kwargs,
):
"""
@@ -116,9 +118,10 @@ def __init__(
"""
if replay_buffer is None:
replay_buffer = RecurrentReplayBuffer
replay_buffer = partial(replay_buffer, max_seq_len=max_seq_len)
replay_buffer = partial(
replay_buffer, max_seq_len=max_seq_len, store_hidden=store_hidden
)
self._max_seq_len = max_seq_len

super().__init__(
observation_space=observation_space,
action_space=action_space,
@@ -144,6 +147,8 @@ def __init__(
logger=logger,
log_frequency=log_frequency,
)
self._store_hidden = store_hidden
self._burn_frames = burn_frames

def create_q_networks(self, representation_net):
"""Creates the Q-network and target Q-network.
Expand All @@ -154,7 +159,20 @@ def create_q_networks(self, representation_net):
of the DRQN).
"""
network = representation_net(self._state_size)
network_output_dim = np.prod(calculate_output_dim(network, self._state_size)[0])

if isinstance(network.rnn.core, torch.nn.LSTM):
self._rnn_type = "lstm"
elif isinstance(network.rnn.core, torch.nn.GRU):
self._rnn_type = "gru"
else:
raise ValueError(
f"Unsupported RNN type. Expected either LSTM or GRU, "
f"received {network.rnn.core}."
)

network_output_dim = np.prod(
calculate_output_dim(network, (1,) + self._state_size)[0]
)
self._qnet = DRQNNetwork(network, network_output_dim, self._action_space.n).to(
self._device
)
@@ -164,6 +182,46 @@ def create_q_networks(self, representation_net):
self._target_qnet = copy.deepcopy(self._qnet).requires_grad_(False)
self._hidden_state = self._qnet.init_hidden(batch_size=1)

def preprocess_update_info(self, update_info):
"""Preprocesses the :obj:`update_info` before it goes into the replay buffer.
Clips the reward in update_info and, if store_hidden is enabled, also attaches the hidden state (and cell state for an LSTM) recorded during the preceding call to act().
Args:
update_info: Contains the information from the current timestep that the
agent should use to update itself.
"""
if self._reward_clip is not None:
update_info["reward"] = np.clip(
update_info["reward"], -self._reward_clip, self._reward_clip
)

preprocessed_update_info = {
"observation": update_info["observation"],
"action": update_info["action"],
"reward": update_info["reward"],
"done": update_info["done"],
}

if self._store_hidden:
if self._rnn_type == "lstm":
preprocessed_update_info.update(
{
"hidden_state": self._prev_hidden_state,
"cell_state": self._prev_cell_state,
}
)

elif self._rnn_type == "gru":
preprocessed_update_info.update(
{
"hidden_state": self._prev_hidden_state,
}
)

if "agent_id" in update_info:
preprocessed_update_info["agent_id"] = int(update_info["agent_id"])

return preprocessed_update_info

@torch.no_grad()
def act(self, observation):
"""Returns the action for the agent. If in training mode, follows an epsilon
@@ -201,6 +259,14 @@ def act(self, observation):
# Note: not explicitly handling the ties
action = torch.argmax(qvals).item()

if self._store_hidden:
if self._rnn_type == "lstm":
self._prev_hidden_state = self._hidden_state[0].detach().cpu().numpy()
self._prev_cell_state = self._hidden_state[1].detach().cpu().numpy()

elif self._rnn_type == "gru":
self._prev_hidden_state = self._hidden_state[0].detach().cpu().numpy()

if (
self._training
and self._logger.should_log(self._timescale)
@@ -243,12 +309,51 @@ def update(self, update_info):
batch,
) = self.preprocess_update_batch(batch)

hidden_state = self._qnet.init_hidden(
batch_size=self._batch_size,
)
target_hidden_state = self._target_qnet.init_hidden(
batch_size=self._batch_size,
)
if self._store_hidden:
hidden_state = (
torch.tensor(
batch["hidden_state"][:, 0].squeeze(1).squeeze(1).unsqueeze(0),
device=self._device,
).float(),
)

target_hidden_state = (
torch.tensor(
batch["next_hidden_state"][:, 0]
.squeeze(1)
.squeeze(1)
.unsqueeze(0),
device=self._device,
).float(),
)

if self._rnn_type == "lstm":
hidden_state += (
torch.tensor(
batch["cell_state"][:, 0]
.squeeze(1)
.squeeze(1)
.unsqueeze(0),
device=self._device,
).float(),
)

target_hidden_state += (
torch.tensor(
batch["next_cell_state"][:, 0]
.squeeze(1)
.squeeze(1)
.unsqueeze(0),
device=self._device,
).float(),
)
else:
hidden_state = self._qnet.init_hidden(
batch_size=self._batch_size,
)
target_hidden_state = self._target_qnet.init_hidden(
batch_size=self._batch_size,
)
# Compute predicted Q values
self._optimizer.zero_grad()
pred_qvals, _ = self._qnet(*current_state_inputs, hidden_state)
@@ -265,7 +370,20 @@ def update(self, update_info):
1 - batch["done"]
)

loss = self._loss_fn(pred_qvals, q_targets).mean()
if self._burn_frames > 0:
interm_loss = self._loss_fn(pred_qvals, q_targets)
mask = torch.zeros(
self._replay_buffer._max_seq_len,
device=self._device,
dtype=torch.float,
)
mask[self._burn_frames :] = 1.0
mask = mask.view(1, -1)
interm_loss *= mask
loss = interm_loss.mean()

else:
loss = self._loss_fn(pred_qvals, q_targets).mean()

if self._logger.should_log(self._timescale):
self._logger.log_scalar("train_loss", loss, self._timescale)
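For reference, a minimal sketch of the two pieces of update logic added above: initializing the recurrent state from the stored per-step hidden states, and masking burn-in frames out of the loss. All shapes and names below are illustrative assumptions, not part of the diff.

import torch

# Assumed shapes: the buffer stores one (1, 1, hidden_size) hidden state per
# timestep, so a sampled batch entry looks like (batch, max_seq_len, 1, 1, hidden).
batch_size, max_seq_len, hidden_size, burn_frames = 32, 20, 128, 5
stored_hidden = torch.randn(batch_size, max_seq_len, 1, 1, hidden_size)

# update() keeps only the hidden state at the first step of each sequence and
# reshapes it to the (num_layers, batch, hidden) layout expected by nn.LSTM/nn.GRU.
h0 = stored_hidden[:, 0].squeeze(1).squeeze(1).unsqueeze(0)
assert h0.shape == (1, batch_size, hidden_size)

# Burn-in: the first burn_frames steps only warm up the recurrent state, so their
# per-step losses are zeroed before averaging (as in the diff, the mean is still
# taken over the full sequence length).
per_step_loss = torch.randn(batch_size, max_seq_len).abs()  # stand-in for the TD loss
mask = torch.zeros(max_seq_len)
mask[burn_frames:] = 1.0
loss = (per_step_loss * mask.view(1, -1)).mean()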
3 changes: 3 additions & 0 deletions hive/configs/atari/drqn.yml
@@ -15,6 +15,7 @@ environment:
name: 'AtariEnv'
kwargs:
env_name: 'Asterix'
flicker_prob: 0.5

agent:
name: 'DRQNAgent'
@@ -48,6 +49,8 @@ agent:
kwargs:
capacity: 1000000
gamma: &gamma .99
rnn_type: 'lstm'
rnn_hidden_size: 128
max_seq_len: *max_seq_len
discount_rate: *gamma
reward_clip: 1
9 changes: 8 additions & 1 deletion hive/envs/atari/atari.py
@@ -5,6 +5,7 @@

from hive.envs.env_spec import EnvSpec
from hive.envs.gym_env import GymEnv
from hive.envs.wrappers.gym_wrappers import FlickeringWrapper


class AtariEnv(GymEnv):
@@ -22,6 +23,7 @@ def __init__(
frame_skip=4,
screen_size=84,
sticky_actions=True,
**kwargs,
):
"""
Args:
@@ -47,7 +49,12 @@ def __init__(
self.frame_skip = frame_skip
self.screen_size = screen_size

super().__init__(full_env_name)
super().__init__(full_env_name, **kwargs)

def create_env(self, env_name, flicker_prob=0, **kwargs):
"""Creates the environment, wrapping it in a FlickeringWrapper when
flicker_prob > 0."""
super().create_env(env_name, **kwargs)
if flicker_prob:
self._env = FlickeringWrapper(self._env, flicker_prob=flicker_prob)

def create_env_spec(self, env_name, **kwargs):
observation_shape = self._env.observation_space.shape
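A hypothetical construction of the wrapped environment in code. The keyword names follow the config above; GymEnv is assumed to forward its kwargs to create_env, as the diff implies, and every other argument keeps its default.

from hive.envs.atari.atari import AtariEnv

# flicker_prob travels through AtariEnv.__init__'s **kwargs into create_env(),
# which wraps the underlying gym environment in FlickeringWrapper. With the
# default flicker_prob=0, behaviour is unchanged.
env = AtariEnv(env_name="Asterix", flicker_prob=0.5)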
6 changes: 5 additions & 1 deletion hive/envs/wrappers/__init__.py
@@ -1 +1,5 @@
from hive.envs.wrappers.gym_wrappers import FlattenWrapper, PermuteImageWrapper
from hive.envs.wrappers.gym_wrappers import (
FlattenWrapper,
FlickeringWrapper,
PermuteImageWrapper,
)
30 changes: 30 additions & 0 deletions hive/envs/wrappers/gym_wrappers.py
@@ -76,3 +76,33 @@ def observation(self, obs):
return tuple(np.transpose(o, [2, 1, 0]) for o in obs)
else:
return np.transpose(obs, [2, 1, 0])


class FlickeringWrapper(gym.core.ObservationWrapper):
"""Fully obscures the observation (replaces it with an all-zero array) with
probability flicker_prob; otherwise passes it through unchanged."""

def __init__(self, env, flicker_prob=0.5):
super().__init__(env)

self.flicker_prob = flicker_prob
if isinstance(env.observation_space, gym.spaces.Tuple):
self._is_tuple = True
self.obscured_obs = np.zeros(
shape=env.observation_space[0].shape,
dtype=np.uint8,
)
else:
self._is_tuple = False
self.obscured_obs = np.zeros(
shape=env.observation_space.shape,
dtype=np.uint8,
)

def observation(self, obs):
if not np.random.binomial(n=1, p=self.flicker_prob):
return obs

if self._is_tuple:
return tuple(self.obscured_obs for _ in obs)
else:
return self.obscured_obs
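A small usage sketch of the wrapper on its own, assuming the classic gym reset/step API; the environment id and episode length are only examples.

import gym
import numpy as np

from hive.envs.wrappers.gym_wrappers import FlickeringWrapper

env = FlickeringWrapper(gym.make("PongNoFrameskip-v4"), flicker_prob=0.5)
obs = env.reset()

blank = 0
for _ in range(100):
    obs, reward, done, info = env.step(env.action_space.sample())
    blank += int(np.all(obs == 0))  # obscured frames come back as all zeros
    if done:
        obs = env.reset()
print(f"{blank}/100 observations were obscured")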
55 changes: 55 additions & 0 deletions hive/replays/recurrent_replay.py
@@ -24,6 +24,9 @@ def __init__(
reward_dtype=np.float32,
extra_storage_types=None,
num_players_sharing_buffer: int = None,
rnn_type: str = "lstm",
rnn_hidden_size: int = 0,
store_hidden: bool = False,
):
"""Constructor for RecurrentReplayBuffer.

@@ -53,6 +56,23 @@ def __init__(
num_players_sharing_buffer (int): Number of agents that share their
buffers. It is used for self-play.
rnn_type (str): Type of recurrent cell used by the agent, either
"lstm" or "gru". Determines whether cell states are also stored.
rnn_hidden_size (int): Size of the recurrent hidden state.
store_hidden (bool): Whether to store the agent's recurrent hidden
state alongside each transition.
"""
if extra_storage_types is None:
extra_storage_types = {}
if store_hidden:
extra_storage_types["hidden_state"] = (
np.float32,
(1, 1, rnn_hidden_size),
)
if rnn_type == "lstm":
extra_storage_types["cell_state"] = (
np.float32,
(1, 1, rnn_hidden_size),
)
elif rnn_type != "gru":
raise ValueError(
f"Unsupported rnn_type. Expected either 'lstm' or 'gru', "
f"received {rnn_type}."
)
super().__init__(
capacity=capacity,
stack_size=1,
@@ -68,6 +88,9 @@ def __init__(
num_players_sharing_buffer=num_players_sharing_buffer,
)
self._max_seq_len = max_seq_len
self._rnn_type = rnn_type
self._rnn_hidden_size = rnn_hidden_size
self._store_hidden = store_hidden

def size(self):
"""Returns the number of transitions stored in the buffer."""
@@ -221,6 +244,18 @@ def sample(self, batch_size):
indices - self._max_seq_len + 1,
num_to_access=self._max_seq_len,
)
elif key == "hidden_state":
batch[key] = self._get_from_storage(
"hidden_state",
indices - self._max_seq_len + 1,
num_to_access=self._max_seq_len,
)
elif key == "cell_state":
batch[key] = self._get_from_storage(
"cell_state",
indices - self._max_seq_len + 1,
num_to_access=self._max_seq_len,
)
elif key == "done":
batch["done"] = is_terminal
elif key == "reward":
@@ -259,4 +294,24 @@ def sample(self, batch_size):
indices + trajectory_lengths - self._max_seq_len + 1,
num_to_access=self._max_seq_len,
)

if self._store_hidden:
batch["next_hidden_state"] = self._get_from_storage(
"hidden_state",
batch["indices"]
+ batch["trajectory_lengths"]
- self._max_seq_len
+ 1, # just return batch["indices"]
num_to_access=self._max_seq_len,
)
if self._rnn_type == "lstm":
batch["next_cell_state"] = self._get_from_storage(
"cell_state",
batch["indices"]
+ batch["trajectory_lengths"]
- self._max_seq_len
+ 1,
num_to_access=self._max_seq_len,
)

return batch
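The next_hidden_state / next_cell_state lookups above reuse the same offset as the existing next_observation lookup. A small sketch of how the two windows relate (pure index arithmetic; the value of trajectory_length is an assumption here and is simply whatever the buffer computed for the sampled transition).

import numpy as np

max_seq_len = 4
i = 10                  # sampled transition index
trajectory_length = 1   # assumed: one step for 1-step returns

hidden_steps = np.arange(i - max_seq_len + 1,
                         i - max_seq_len + 1 + max_seq_len)
next_steps = np.arange(i + trajectory_length - max_seq_len + 1,
                       i + trajectory_length - max_seq_len + 1 + max_seq_len)

print(hidden_steps)  # [ 7  8  9 10]
print(next_steps)    # [ 8  9 10 11] -> the hidden_steps window shifted by trajectory_length

# DRQNAgent.update() only uses the first element of each window (steps 7 and 8
# here) to initialize the online and target networks' recurrent state.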