From 1ddb49fed585215f59d6749409233e6982167575 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Thu, 10 Nov 2022 15:42:40 -0500
Subject: [PATCH 1/3] restore hidden states added

---
 hive/agents/drqn.py              | 123 ++++++++++++++++++++++++++++---
 hive/configs/atari/drqn.yml      |   2 +
 hive/replays/recurrent_replay.py |  55 ++++++++++++++
 3 files changed, 171 insertions(+), 9 deletions(-)

diff --git a/hive/agents/drqn.py b/hive/agents/drqn.py
index 28cc318f..988f102d 100644
--- a/hive/agents/drqn.py
+++ b/hive/agents/drqn.py
@@ -58,6 +58,8 @@ def __init__(
         device="cpu",
         logger: Logger = None,
         log_frequency: int = 100,
+        store_hidden: bool = True,
+        burn_frames: int = 0,
         **kwargs,
     ):
         """
@@ -116,9 +118,10 @@ def __init__(
         """
         if replay_buffer is None:
             replay_buffer = RecurrentReplayBuffer
-        replay_buffer = partial(replay_buffer, max_seq_len=max_seq_len)
+        replay_buffer = partial(
+            replay_buffer, max_seq_len=max_seq_len, store_hidden=store_hidden
+        )
         self._max_seq_len = max_seq_len
-
         super().__init__(
             observation_space=observation_space,
             action_space=action_space,
@@ -144,6 +147,8 @@ def __init__(
             logger=logger,
             log_frequency=log_frequency,
         )
+        self._store_hidden = store_hidden
+        self._burn_frames = burn_frames
 
     def create_q_networks(self, representation_net):
         """Creates the Q-network and target Q-network.
@@ -154,7 +159,20 @@ def create_q_networks(self, representation_net):
                 of the DRQN).
         """
         network = representation_net(self._state_size)
-        network_output_dim = np.prod(calculate_output_dim(network, self._state_size)[0])
+
+        if isinstance(network.rnn.core, torch.nn.LSTM):
+            self._rnn_type = "lstm"
+        elif isinstance(network.rnn.core, torch.nn.GRU):
+            self._rnn_type = "gru"
+        else:
+            raise ValueError(
+                f"rnn_type is wrong. Expected either lstm or gru, "
+                f"received {network.rnn.core}."
+            )
+
+        network_output_dim = np.prod(
+            calculate_output_dim(network, (1,) + self._state_size)[0]
+        )
         self._qnet = DRQNNetwork(network, network_output_dim, self._action_space.n).to(
             self._device
         )
@@ -164,6 +182,46 @@ def create_q_networks(self, representation_net):
         self._target_qnet = copy.deepcopy(self._qnet).requires_grad_(False)
         self._hidden_state = self._qnet.init_hidden(batch_size=1)
 
+    def preprocess_update_info(self, update_info):
+        """Preprocesses the :obj:`update_info` before it goes into the replay buffer.
+        Clips the reward in update_info.
+        Args:
+            update_info: Contains the information from the current timestep that the
+                agent should use to update itself.
+        """
+        if self._reward_clip is not None:
+            update_info["reward"] = np.clip(
+                update_info["reward"], -self._reward_clip, self._reward_clip
+            )
+
+        preprocessed_update_info = {
+            "observation": update_info["observation"],
+            "action": update_info["action"],
+            "reward": update_info["reward"],
+            "done": update_info["done"],
+        }
+
+        if self._store_hidden == True:
+            if self._rnn_type == "lstm":
+                preprocessed_update_info.update(
+                    {
+                        "hidden_state": self._prev_hidden_state,
+                        "cell_state": self._prev_cell_state,
+                    }
+                )
+
+            elif self._rnn_type == "gru":
+                preprocessed_update_info.update(
+                    {
+                        "hidden_state": self._prev_hidden_state,
+                    }
+                )
+
+        if "agent_id" in update_info:
+            preprocessed_update_info["agent_id"] = int(update_info["agent_id"])
+
+        return preprocessed_update_info
+
     @torch.no_grad()
     def act(self, observation):
         """Returns the action for the agent. If in training mode, follows an epsilon
@@ -201,6 +259,14 @@ def act(self, observation):
             # Note: not explicitly handling the ties
             action = torch.argmax(qvals).item()
 
+        if self._store_hidden == True:
+            if self._rnn_type == "lstm":
+                self._prev_hidden_state = self._hidden_state[0].detach().cpu().numpy()
+                self._prev_cell_state = self._hidden_state[1].detach().cpu().numpy()
+
+            elif self._rnn_type == "gru":
+                self._prev_hidden_state = self._hidden_state[0].detach().cpu().numpy()
+
         if (
             self._training
             and self._logger.should_log(self._timescale)
@@ -243,12 +309,51 @@ def update(self, update_info):
                 batch,
             ) = self.preprocess_update_batch(batch)
 
-            hidden_state = self._qnet.init_hidden(
-                batch_size=self._batch_size,
-            )
-            target_hidden_state = self._target_qnet.init_hidden(
-                batch_size=self._batch_size,
-            )
+            if self._store_hidden == True:
+                hidden_state = (
+                    torch.tensor(
+                        batch["hidden_state"][:, 0].squeeze(1).squeeze(1).unsqueeze(0),
+                        device=self._device,
+                    ).float(),
+                )
+
+                target_hidden_state = (
+                    torch.tensor(
+                        batch["next_hidden_state"][:, 0]
+                        .squeeze(1)
+                        .squeeze(1)
+                        .unsqueeze(0),
+                        device=self._device,
+                    ).float(),
+                )
+
+                if self._rnn_type == "lstm":
+                    hidden_state += (
+                        torch.tensor(
+                            batch["cell_state"][:, 0]
+                            .squeeze(1)
+                            .squeeze(1)
+                            .unsqueeze(0),
+                            device=self._device,
+                        ).float(),
+                    )
+
+                    target_hidden_state += (
+                        torch.tensor(
+                            batch["next_cell_state"][:, 0]
+                            .squeeze(1)
+                            .squeeze(1)
+                            .unsqueeze(0),
+                            device=self._device,
+                        ).float(),
+                    )
+            else:
+                hidden_state = self._qnet.init_hidden(
+                    batch_size=self._batch_size,
+                )
+                target_hidden_state = self._target_qnet.init_hidden(
+                    batch_size=self._batch_size,
+                )
 
             # Compute predicted Q values
             self._optimizer.zero_grad()
             pred_qvals, _ = self._qnet(*current_state_inputs, hidden_state)
diff --git a/hive/configs/atari/drqn.yml b/hive/configs/atari/drqn.yml
index 77f67f9e..e5237502 100644
--- a/hive/configs/atari/drqn.yml
+++ b/hive/configs/atari/drqn.yml
@@ -48,6 +48,8 @@ agent:
       kwargs:
         capacity: 1000000
        gamma: &gamma .99
+        rnn_type: 'lstm'
+        rnn_hidden_size: 128
         max_seq_len: *max_seq_len
         discount_rate: *gamma
         reward_clip: 1
diff --git a/hive/replays/recurrent_replay.py b/hive/replays/recurrent_replay.py
index e7beb20c..fe323789 100644
--- a/hive/replays/recurrent_replay.py
+++ b/hive/replays/recurrent_replay.py
@@ -24,6 +24,9 @@ def __init__(
         reward_dtype=np.float32,
         extra_storage_types=None,
         num_players_sharing_buffer: int = None,
+        rnn_type: str = "lstm",
+        rnn_hidden_size: int = 0,
+        store_hidden: bool = False,
     ):
         """Constructor for RecurrentReplayBuffer.
 
@@ -53,6 +56,23 @@ def __init__(
             num_players_sharing_buffer (int): Number of agents that share their
                 buffers. It is used for self-play.
         """
+        if extra_storage_types is None:
+            extra_storage_types = {}
+        if store_hidden == True:
+            extra_storage_types["hidden_state"] = (
+                np.float32,
+                (1, 1, rnn_hidden_size),
+            )
+            if rnn_type == "lstm":
+                extra_storage_types["cell_state"] = (
+                    np.float32,
+                    (1, 1, rnn_hidden_size),
+                )
+            elif rnn_type != "gru":
+                raise ValueError(
+                    f"rnn_type is wrong. Expected either lstm or gru, "
+                    f"received {rnn_type}."
+                )
         super().__init__(
             capacity=capacity,
             stack_size=1,
@@ -68,6 +88,9 @@ def __init__(
             num_players_sharing_buffer=num_players_sharing_buffer,
         )
         self._max_seq_len = max_seq_len
+        self._rnn_type = rnn_type
+        self._rnn_hidden_size = rnn_hidden_size
+        self._store_hidden = store_hidden
 
     def size(self):
         """Returns the number of transitions stored in the buffer."""
@@ -221,6 +244,18 @@ def sample(self, batch_size):
                     indices - self._max_seq_len + 1,
                     num_to_access=self._max_seq_len,
                 )
+            elif key == "hidden_state":
+                batch[key] = self._get_from_storage(
+                    "hidden_state",
+                    indices - self._max_seq_len + 1,
+                    num_to_access=self._max_seq_len,
+                )
+            elif key == "cell_state":
+                batch[key] = self._get_from_storage(
+                    "cell_state",
+                    indices - self._max_seq_len + 1,
+                    num_to_access=self._max_seq_len,
+                )
             elif key == "done":
                 batch["done"] = is_terminal
             elif key == "reward":
@@ -259,4 +294,24 @@ def sample(self, batch_size):
             indices + trajectory_lengths - self._max_seq_len + 1,
             num_to_access=self._max_seq_len,
         )
+
+        if self._store_hidden == True:
+            batch["next_hidden_state"] = self._get_from_storage(
+                "hidden_state",
+                batch["indices"]
+                + batch["trajectory_lengths"]
+                - self._max_seq_len
+                + 1,  # just return batch["indices"]
+                num_to_access=self._max_seq_len,
+            )
+            if self._rnn_type == "lstm":
+                batch["next_cell_state"] = self._get_from_storage(
+                    "cell_state",
+                    batch["indices"]
+                    + batch["trajectory_lengths"]
+                    - self._max_seq_len
+                    + 1,
+                    num_to_access=self._max_seq_len,
+                )
+
         return batch

From 419bca4b7300accd8dae482861e5cdae5d42c857 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Thu, 10 Nov 2022 16:10:52 -0500
Subject: [PATCH 2/3] burn in frames feature added

---
 hive/agents/drqn.py | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/hive/agents/drqn.py b/hive/agents/drqn.py
index 988f102d..83fdcc00 100644
--- a/hive/agents/drqn.py
+++ b/hive/agents/drqn.py
@@ -370,7 +370,20 @@ def update(self, update_info):
                 1 - batch["done"]
             )
 
-            loss = self._loss_fn(pred_qvals, q_targets).mean()
+            if self._burn_frames > 0:
+                interm_loss = self._loss_fn(pred_qvals, q_targets)
+                mask = torch.zeros(
+                    self._replay_buffer._max_seq_len,
+                    device=self._device,
+                    dtype=torch.float,
+                )
+                mask[self._burn_frames :] = 1.0
+                mask = mask.view(1, -1)
+                interm_loss *= mask
+                loss = interm_loss.mean()
+
+            else:
+                loss = self._loss_fn(pred_qvals, q_targets).mean()
 
             if self._logger.should_log(self._timescale):
                 self._logger.log_scalar("train_loss", loss, self._timescale)

From 5effcc89745328d635b14374a6fcd3b89faafbcf Mon Sep 17 00:00:00 2001
From: javinator48
Date: Mon, 14 Nov 2022 19:27:24 -0500
Subject: [PATCH 3/3] flickering atari wrapper

---
 hive/agents/drqn.py                |  4 +++-
 hive/configs/atari/drqn.yml        |  1 +
 hive/envs/atari/atari.py           |  9 ++++++++-
 hive/envs/wrappers/__init__.py     |  6 +++++-
 hive/envs/wrappers/gym_wrappers.py | 30 ++++++++++++++++++++++++++++++
 5 files changed, 47 insertions(+), 3 deletions(-)

diff --git a/hive/agents/drqn.py b/hive/agents/drqn.py
index 28cc318f..39cac297 100644
--- a/hive/agents/drqn.py
+++ b/hive/agents/drqn.py
@@ -154,7 +154,9 @@ def create_q_networks(self, representation_net):
                 of the DRQN).
""" network = representation_net(self._state_size) - network_output_dim = np.prod(calculate_output_dim(network, self._state_size)[0]) + network_output_dim = np.prod( + calculate_output_dim(network, (1,) + self._state_size)[0] + ) self._qnet = DRQNNetwork(network, network_output_dim, self._action_space.n).to( self._device ) diff --git a/hive/configs/atari/drqn.yml b/hive/configs/atari/drqn.yml index 77f67f9e..0f0cf5d0 100644 --- a/hive/configs/atari/drqn.yml +++ b/hive/configs/atari/drqn.yml @@ -15,6 +15,7 @@ environment: name: 'AtariEnv' kwargs: env_name: 'Asterix' + flicker_prob: 0.5 agent: name: 'DRQNAgent' diff --git a/hive/envs/atari/atari.py b/hive/envs/atari/atari.py index 186a64b9..69621986 100644 --- a/hive/envs/atari/atari.py +++ b/hive/envs/atari/atari.py @@ -5,6 +5,7 @@ from hive.envs.env_spec import EnvSpec from hive.envs.gym_env import GymEnv +from hive.envs.wrappers.gym_wrappers import FlickeringWrapper class AtariEnv(GymEnv): @@ -22,6 +23,7 @@ def __init__( frame_skip=4, screen_size=84, sticky_actions=True, + **kwargs, ): """ Args: @@ -47,7 +49,12 @@ def __init__( self.frame_skip = frame_skip self.screen_size = screen_size - super().__init__(full_env_name) + super().__init__(full_env_name, **kwargs) + + def create_env(self, env_name, flicker_prob=0, **kwargs): + super().create_env(env_name, **kwargs) + if flicker_prob: + self._env = FlickeringWrapper(self._env, flicker_prob=flicker_prob) def create_env_spec(self, env_name, **kwargs): observation_shape = self._env.observation_space.shape diff --git a/hive/envs/wrappers/__init__.py b/hive/envs/wrappers/__init__.py index b6f050cb..0de2c2f7 100644 --- a/hive/envs/wrappers/__init__.py +++ b/hive/envs/wrappers/__init__.py @@ -1 +1,5 @@ -from hive.envs.wrappers.gym_wrappers import FlattenWrapper, PermuteImageWrapper +from hive.envs.wrappers.gym_wrappers import ( + FlattenWrapper, + FlickeringWrapper, + PermuteImageWrapper, +) diff --git a/hive/envs/wrappers/gym_wrappers.py b/hive/envs/wrappers/gym_wrappers.py index 3a117015..a161e22c 100644 --- a/hive/envs/wrappers/gym_wrappers.py +++ b/hive/envs/wrappers/gym_wrappers.py @@ -76,3 +76,33 @@ def observation(self, obs): return tuple(np.transpose(o, [2, 1, 0]) for o in obs) else: return np.transpose(obs, [2, 1, 0]) + + +class FlickeringWrapper(gym.core.ObservationWrapper): + """Fully obscure the image with certain probablity.""" + + def __init__(self, env, flicker_prob=0.5): + super().__init__(env) + + self.flicker_prob = flicker_prob + if isinstance(env.observation_space, gym.spaces.Tuple): + self._is_tuple = True + self.obscured_obs = np.zeros( + shape=env.observation_space[0].shape, + dtype=np.uint8, + ) + else: + self._is_tuple = False + self.obscured_obs = np.zeros( + shape=env.observation_space.shape, + dtype=np.uint8, + ) + + def observation(self, obs): + if not np.random.binomial(n=1, p=self.flicker_prob): + return obs + + if self._is_tuple: + return tuple(self.obscured_obs for _ in obs) + else: + return self.obscured_obs