From 8825756588e65a05fc824cc8d19e1a80256c589a Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 26 Apr 2023 18:24:35 +0800 Subject: [PATCH 1/5] feature(nyz): add basic task pipeline and polish code --- lzero/entry/train_muzero.py | 2 +- lzero/entry/train_muzero_with_gym_env.py | 2 +- lzero/mcts/buffer/game_buffer.py | 204 ++++++------ .../mcts/buffer/game_buffer_efficientzero.py | 29 +- lzero/mcts/buffer/game_buffer_muzero.py | 7 +- .../game_buffer_sampled_efficientzero.py | 7 +- lzero/middleware/__init__.py | 4 + lzero/middleware/collector.py | 290 ++++++++++++++++++ lzero/middleware/evaluator.py | 105 +++++++ lzero/middleware/helper.py | 29 ++ lzero/policy/__init__.py | 8 +- lzero/policy/alphazero.py | 9 +- lzero/policy/efficientzero.py | 131 ++++---- lzero/policy/muzero.py | 13 +- lzero/policy/sampled_efficientzero.py | 13 +- lzero/policy/scaling_transform.py | 25 +- lzero/policy/utils.py | 63 +++- lzero/worker/muzero_collector.py | 8 +- lzero/worker/muzero_evaluator.py | 8 +- .../lunarlander_disc_efficientzero_config.py | 2 +- .../cartpole_efficientzero_task_config.py | 133 ++++++++ 21 files changed, 846 insertions(+), 246 deletions(-) create mode 100644 lzero/middleware/__init__.py create mode 100644 lzero/middleware/collector.py create mode 100644 lzero/middleware/evaluator.py create mode 100644 lzero/middleware/helper.py create mode 100644 zoo/classic_control/cartpole/config/cartpole_efficientzero_task_config.py diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index c4e3fd0fa..c981813e9 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -150,7 +150,7 @@ def train_muzero( log_vars = learner.train(train_data, collector.envstep) if cfg.policy.use_priority: - replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + replay_buffer.update_priority(train_data, log_vars[0]['priority']) if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: break diff --git a/lzero/entry/train_muzero_with_gym_env.py b/lzero/entry/train_muzero_with_gym_env.py index d61818948..4c65ce09a 100644 --- a/lzero/entry/train_muzero_with_gym_env.py +++ b/lzero/entry/train_muzero_with_gym_env.py @@ -156,7 +156,7 @@ def train_muzero_with_gym_env( log_vars = learner.train(train_data, collector.envstep) if cfg.policy.use_priority: - replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + replay_buffer.update_priority(train_data, log_vars[0]['priority']) if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: break diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index 385bd5ab3..a1a88e02c 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -51,12 +51,11 @@ def __init__(self, cfg: dict): default_config = self.default_config() default_config.update(cfg) self._cfg = default_config - self._cfg = cfg assert self._cfg.env_type in ['not_board_games', 'board_games'] self.replay_buffer_size = self._cfg.replay_buffer_size self.batch_size = self._cfg.batch_size - self._alpha = self._cfg.priority_prob_alpha - self._beta = self._cfg.priority_prob_beta + self.alpha = self._cfg.priority_prob_alpha + self.beta = self._cfg.priority_prob_beta self.game_segment_buffer = [] self.game_pos_priorities = [] @@ -80,6 +79,7 @@ def sample( Returns: - train_data (:obj:`List`): List of train data, including current_batch and target_batch. 
""" + raise NotImplementedError @abstractmethod def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]: @@ -96,98 +96,7 @@ def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]: Returns: - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch """ - pass - - def _sample_orig_data(self, batch_size: int) -> Tuple: - """ - Overview: - sample orig_data that contains: - game_segment_list: a list of game segments - pos_in_game_segment_list: transition index in game (relative index) - batch_index_list: the index of start transition of sampled minibatch in replay buffer - weights_list: the weight concerning the priority - make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) - Arguments: - - batch_size (:obj:`int`): batch size - - beta: float the parameter in PER for calculating the priority - """ - assert self._beta > 0 - num_of_transitions = self.get_num_of_transitions() - if self._cfg.use_priority is False: - self.game_pos_priorities = np.ones_like(self.game_pos_priorities) - - # +1e-6 for numerical stability - probs = self.game_pos_priorities ** self._alpha + 1e-6 - probs /= probs.sum() - - # sample according to transition index - # TODO(pu): replace=True - batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) - - if self._cfg.reanalyze_outdated is True: - # NOTE: used in reanalyze part - batch_index_list.sort() - - weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) - weights_list /= weights_list.max() - - game_segment_list = [] - pos_in_game_segment_list = [] - - for idx in batch_index_list: - game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] - game_segment_idx -= self.base_idx - game_segment = self.game_segment_buffer[game_segment_idx] - - game_segment_list.append(game_segment) - pos_in_game_segment_list.append(pos_in_game_segment) - - make_time = [time.time() for _ in range(len(batch_index_list))] - - orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) - return orig_data - - def _preprocess_to_play_and_action_mask( - self, game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ): - """ - Overview: - prepare the to_play and action_mask for the target obs in ``value_obs_list`` - - to_play: {list: game_segment_batch_size * (num_unroll_steps+1)} - - action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)} - """ - to_play = [] - for bs in range(game_segment_batch_size): - to_play_tmp = list( - to_play_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] + - self._cfg.num_unroll_steps + 1] - ) - if len(to_play_tmp) < self._cfg.num_unroll_steps + 1: - # NOTE: the effective to play index is {1,2}, for null padding data, we set to_play=-1 - to_play_tmp += [-1 for _ in range(self._cfg.num_unroll_steps + 1 - len(to_play_tmp))] - to_play.append(to_play_tmp) - to_play = sum(to_play, []) - - if self._cfg.model.continuous_action_space is True: - # when the action space of the environment is continuous, action_mask[:] is None. 
- return to_play, None - - action_mask = [] - for bs in range(game_segment_batch_size): - action_mask_tmp = list( - action_mask_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] + - self._cfg.num_unroll_steps + 1] - ) - if len(action_mask_tmp) < self._cfg.num_unroll_steps + 1: - action_mask_tmp += [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) - for _ in range(self._cfg.num_unroll_steps + 1 - len(action_mask_tmp)) - ] - action_mask.append(action_mask_tmp) - action_mask = to_list(action_mask) - action_mask = sum(action_mask, []) - - return to_play, action_mask + raise NotImplementedError @abstractmethod def _prepare_reward_value_context( @@ -206,7 +115,7 @@ def _prepare_reward_value_context( - reward_value_context (:obj:`list`): value_obs_lst, value_mask, state_index_lst, rewards_lst, game_segment_lens, td_steps_lst, action_mask_segment, to_play_segment """ - pass + raise NotImplementedError @abstractmethod def _prepare_policy_non_reanalyzed_context( @@ -222,7 +131,7 @@ def _prepare_policy_non_reanalyzed_context( Returns: - policy_non_re_context (:obj:`list`): state_index_lst, child_visits, game_segment_lens, action_mask_segment, to_play_segment """ - pass + raise NotImplementedError @abstractmethod def _prepare_policy_reanalyzed_context( @@ -239,7 +148,7 @@ def _prepare_policy_reanalyzed_context( - policy_re_context (:obj:`list`): policy_obs_lst, policy_mask, state_index_lst, indices, child_visits, game_segment_lens, action_mask_segment, to_play_segment """ - pass + raise NotImplementedError @abstractmethod def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]: @@ -253,7 +162,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix - batch_target_values (:obj:'np.ndarray): batch of value estimation """ - pass + raise NotImplementedError @abstractmethod def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray: @@ -265,7 +174,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Returns: - batch_target_policies_re """ - pass + raise NotImplementedError @abstractmethod def _compute_target_policy_non_reanalyzed( @@ -284,7 +193,7 @@ def _compute_target_policy_non_reanalyzed( Returns: - batch_target_policies_non_re """ - pass + raise NotImplementedError @abstractmethod def update_priority( @@ -297,7 +206,98 @@ def update_priority( - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data to be updated priority. - batch_priorities (:obj:`batch_priorities`): priorities to update to. 
""" - pass + raise NotImplementedError + + def _sample_orig_data(self, batch_size: int) -> Tuple: + """ + Overview: + sample orig_data that contains: + game_segment_list: a list of game segments + pos_in_game_segment_list: transition index in game (relative index) + batch_index_list: the index of start transition of sampled minibatch in replay buffer + weights_list: the weight concerning the priority + make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) + Arguments: + - batch_size (:obj:`int`): batch size + - beta: float the parameter in PER for calculating the priority + """ + assert self.beta > 0, self.beta + num_of_transitions = self.get_num_of_transitions() + if self._cfg.use_priority is False: + self.game_pos_priorities = np.ones_like(self.game_pos_priorities) + + # +1e-6 for numerical stability + probs = self.game_pos_priorities ** self.alpha + 1e-6 + probs /= probs.sum() + + # sample according to transition index + # TODO(pu): replace=True + batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) + + if self._cfg.reanalyze_outdated is True: + # NOTE: used in reanalyze part + batch_index_list.sort() + + weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self.beta) + weights_list /= weights_list.max() + + game_segment_list = [] + pos_in_game_segment_list = [] + + for idx in batch_index_list: + game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] + game_segment_idx -= self.base_idx + game_segment = self.game_segment_buffer[game_segment_idx] + + game_segment_list.append(game_segment) + pos_in_game_segment_list.append(pos_in_game_segment) + + make_time = [time.time() for _ in range(len(batch_index_list))] + + orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + return orig_data + + def _preprocess_to_play_and_action_mask( + self, game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list + ): + """ + Overview: + prepare the to_play and action_mask for the target obs in ``value_obs_list`` + - to_play: {list: game_segment_batch_size * (num_unroll_steps+1)} + - action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)} + """ + to_play = [] + for bs in range(game_segment_batch_size): + to_play_tmp = list( + to_play_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] + + self._cfg.num_unroll_steps + 1] + ) + if len(to_play_tmp) < self._cfg.num_unroll_steps + 1: + # NOTE: the effective to play index is {1,2}, for null padding data, we set to_play=-1 + to_play_tmp += [-1 for _ in range(self._cfg.num_unroll_steps + 1 - len(to_play_tmp))] + to_play.append(to_play_tmp) + to_play = sum(to_play, []) + + if self._cfg.model.continuous_action_space is True: + # when the action space of the environment is continuous, action_mask[:] is None. 
+ return to_play, None + + action_mask = [] + for bs in range(game_segment_batch_size): + action_mask_tmp = list( + action_mask_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] + + self._cfg.num_unroll_steps + 1] + ) + if len(action_mask_tmp) < self._cfg.num_unroll_steps + 1: + action_mask_tmp += [ + list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) + for _ in range(self._cfg.num_unroll_steps + 1 - len(action_mask_tmp)) + ] + action_mask.append(action_mask_tmp) + action_mask = to_list(action_mask) + action_mask = sum(action_mask, []) + + return to_play, action_mask def push_game_segments(self, data_and_meta: Any) -> None: """ diff --git a/lzero/mcts/buffer/game_buffer_efficientzero.py b/lzero/mcts/buffer/game_buffer_efficientzero.py index 4ab12259e..cf8f44391 100644 --- a/lzero/mcts/buffer/game_buffer_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_efficientzero.py @@ -1,4 +1,5 @@ -from typing import Any, List +from typing import Any, List, TYPE_CHECKING +from easydict import EasyDict import numpy as np import torch @@ -10,6 +11,9 @@ from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform from .game_buffer_muzero import MuZeroGameBuffer +if TYPE_CHECKING: + from ding.policy import Policy + @BUFFER_REGISTRY.register('game_buffer_efficientzero') class EfficientZeroGameBuffer(MuZeroGameBuffer): @@ -18,22 +22,17 @@ class EfficientZeroGameBuffer(MuZeroGameBuffer): The specific game buffer for EfficientZero policy. """ - def __init__(self, cfg: dict): - super().__init__(cfg) + def __init__(self, cfg: EasyDict) -> None: """ Overview: Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key in the default configuration, the user-provided value will override the default configuration. Otherwise, the default configuration will be used. """ - default_config = self.default_config() - default_config.update(cfg) - self._cfg = default_config - assert self._cfg.env_type in ['not_board_games', 'board_games'] + super().__init__(cfg) + assert self._cfg.env_type in ['not_board_games', 'board_games'], self._cfg.env_type self.replay_buffer_size = self._cfg.replay_buffer_size self.batch_size = self._cfg.batch_size - self._alpha = self._cfg.priority_prob_alpha - self._beta = self._cfg.priority_prob_beta self.game_segment_buffer = [] self.game_pos_priorities = [] @@ -44,15 +43,16 @@ def __init__(self, cfg: dict): self.base_idx = 0 self.clear_time = 0 - def sample(self, batch_size: int, policy: Any) -> List[Any]: + def sample(self, batch_size: int, policy: 'Policy') -> List[Any]: """ Overview: - sample data from ``GameBuffer`` and prepare the current and target batch for training + Sample a mini-batch of data for training, mainly including random sampling and preparing the current and \ + target batch with/without reanalyzing operation mentioned in MuZero. Arguments: - - batch_size (:obj:`int`): batch size - - policy (:obj:`torch.tensor`): model of policy + - batch_size (:obj:`int`): The number of samples in a mini-batch. + - policy (:obj:`Policy`): The policy instance used to execute reanalyzing operation. Returns: - - train_data (:obj:`List`): List of train data + - train_data (:obj:`List`): List of sampled training data. 
""" policy._target_model.to(self._cfg.device) policy._target_model.eval() @@ -180,7 +180,6 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A to_play, action_mask = self._preprocess_to_play_and_action_mask( game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list ) - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] # ============================================================== diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 3a4ca14c5..e020707f4 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -22,21 +22,16 @@ class MuZeroGameBuffer(GameBuffer): """ def __init__(self, cfg: dict): - super().__init__(cfg) """ Overview: Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key in the default configuration, the user-provided value will override the default configuration. Otherwise, the default configuration will be used. """ - default_config = self.default_config() - default_config.update(cfg) - self._cfg = default_config + super().__init__(cfg) assert self._cfg.env_type in ['not_board_games', 'board_games'] self.replay_buffer_size = self._cfg.replay_buffer_size self.batch_size = self._cfg.batch_size - self._alpha = self._cfg.priority_prob_alpha - self._beta = self._cfg.priority_prob_beta self.keep_ratio = 1 self.model_update_interval = 10 diff --git a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py index 9291ab345..3332a0c9d 100644 --- a/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py +++ b/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py @@ -19,21 +19,16 @@ class SampledEfficientZeroGameBuffer(EfficientZeroGameBuffer): """ def __init__(self, cfg: dict): - super().__init__(cfg) """ Overview: Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key in the default configuration, the user-provided value will override the default configuration. Otherwise, the default configuration will be used. 
""" - default_config = self.default_config() - default_config.update(cfg) - self._cfg = default_config + super().__init__(cfg) assert self._cfg.env_type in ['not_board_games', 'board_games'] self.replay_buffer_size = self._cfg.replay_buffer_size self.batch_size = self._cfg.batch_size - self._alpha = self._cfg.priority_prob_alpha - self._beta = self._cfg.priority_prob_beta self.game_segment_buffer = [] self.game_pos_priorities = [] diff --git a/lzero/middleware/__init__.py b/lzero/middleware/__init__.py new file mode 100644 index 000000000..0554906b5 --- /dev/null +++ b/lzero/middleware/__init__.py @@ -0,0 +1,4 @@ +from .collector import MuZeroCollector +from .evaluator import MuZeroEvaluator +from .data_processor import data_pusher, data_reanalyze_fetcher +from .helper import lr_scheduler, temperature_handler diff --git a/lzero/middleware/collector.py b/lzero/middleware/collector.py new file mode 100644 index 000000000..7fd26fe5e --- /dev/null +++ b/lzero/middleware/collector.py @@ -0,0 +1,290 @@ +import numpy as np +import torch +from ding.torch_utils import to_ndarray, to_tensor, to_device +from ding.utils import EasyTimer +from lzero.mcts.buffer.game_segment import GameSegment +from lzero.mcts.utils import prepare_observation + + +class MuZeroCollector: + + def __init__(self, cfg, policy, env): + self._cfg = cfg.policy + self._env = env + self._env.seed(cfg.seed) + self._policy = policy + + self._timer = EasyTimer() + self._trajectory_pool = [] + self._default_n_episode = self._cfg.n_episode + self._unroll_plus_td_steps = self._cfg.num_unroll_steps + self._cfg.td_steps + self._last_collect_iter = 0 + + def __call__(self, ctx): + trained_iter = ctx.train_iter - self._last_collect_iter + if ctx.train_iter != 0 and trained_iter < self._cfg.update_per_collect: + return + elif trained_iter == self._cfg.update_per_collect: + self._last_collect_iter = ctx.train_iter + n_episode = self._default_n_episode + temperature = ctx.collect_kwargs['temperature'] + collected_episode = 0 + env_nums = self._env.env_num + if self._env.closed: + self._env.launch() + else: + self._env.reset() + self._policy.reset() + + init_obs = self._env.ready_obs + action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} + + dones = np.array([False for _ in range(env_nums)]) + game_segments = [ + GameSegment(self._env.action_space, game_segment_length=self._cfg.game_segment_length, config=self._cfg) + for _ in range(env_nums) + ] + + last_game_segments = [None for _ in range(env_nums)] + last_game_priorities = [None for _ in range(env_nums)] + + # stacked observation windows in reset stage for init game_segments + stack_obs_windows = [[] for _ in range(env_nums)] + for i in range(env_nums): + stack_obs_windows[i] = [ + to_ndarray(init_obs[i]['observation']) for _ in range(self._cfg.model.frame_stack_num) + ] + game_segments[i].reset(stack_obs_windows[i]) + + # for priorities in self-play + search_values_lst = [[] for _ in range(env_nums)] + pred_values_lst = [[] for _ in range(env_nums)] + + # some logs + eps_ori_reward_lst, eps_reward_lst, eps_steps_lst, visit_entropies_lst = np.zeros(env_nums), np.zeros( + env_nums + ), np.zeros(env_nums), np.zeros(env_nums) + + ready_env_id = set() + remain_episode = n_episode + + return_data = [] + while True: + with self._timer: + obs = self._env.ready_obs + new_available_env_id = set(obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= 
min(len(new_available_env_id), remain_episode) + + action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} + action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] + + stack_obs = [game_segments[env_id].get_obs() for env_id in ready_env_id] + stack_obs = to_ndarray(stack_obs) + stack_obs = prepare_observation(stack_obs, self._cfg.model.model_type) + stack_obs = to_tensor(stack_obs) + stack_obs = to_device(stack_obs, self._cfg.device) + + policy_output = self._policy.forward(stack_obs, action_mask, temperature) + + actions = {k: v['action'] for k, v in zip(ready_env_id, policy_output)} + distributions_dict = {k: v['distributions'] for k, v in zip(ready_env_id, policy_output)} + value_dict = {k: v['value'] for k, v in zip(ready_env_id, policy_output)} + pred_value_dict = {k: v['pred_value'] for k, v in zip(ready_env_id, policy_output)} + visit_entropy_dict = { + k: v['visit_count_distribution_entropy'] + for k, v in zip(ready_env_id, policy_output) + } + + timesteps = self._env.step(actions) + ctx.env_step += len(ready_env_id) + + for env_id, timestep in timesteps.items(): + with self._timer: + i = env_id + obs, rew, done = timestep.obs, timestep.reward, timestep.done + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) + game_segments[env_id].append( + actions[env_id], to_ndarray(obs['observation']), rew, action_mask_dict[env_id] + ) + + action_mask_dict[env_id] = to_ndarray(obs['action_mask']) + eps_reward_lst[env_id] += rew + dones[env_id] = done + visit_entropies_lst[env_id] += visit_entropy_dict[env_id] + + eps_steps_lst[env_id] += 1 + + if self._cfg.use_priority and not self._cfg.use_max_priority_for_new_data: + pred_values_lst[env_id].append(pred_value_dict[env_id]) + search_values_lst[env_id].append(value_dict[env_id]) + + del stack_obs_windows[env_id][0] + stack_obs_windows[env_id].append(to_ndarray(obs['observation'])) + + ######### + # we will save a game history if it is the end of the game or the next game history is finished. 
+ ######### + + ######### + # if game history is full, we will save the last game history + ######### + if game_segments[env_id].is_full(): + # pad over last block trajectory + if last_game_segments[env_id] is not None: + # TODO(pu): return the one game history + self.pad_and_save_last_trajectory( + i, last_game_segments, last_game_priorities, game_segments, dones + ) + + # calculate priority + priorities = self.get_priorities(i, pred_values_lst, search_values_lst) + pred_values_lst[env_id] = [] + search_values_lst[env_id] = [] + + # the current game_segments become last_game_segment + last_game_segments[env_id] = game_segments[env_id] + last_game_priorities[env_id] = priorities + + # create new GameSegment + game_segments[env_id] = GameSegment( + self._env.action_space, game_segment_length=self._cfg.game_segment_length, config=self._cfg + ) + game_segments[env_id].reset(stack_obs_windows[env_id]) + + if timestep.done: + collected_episode += 1 + + ######### + # if it is the end of the game, we will save the game history + ######### + + # NOTE: put the penultimate game history in one episode into the _trajectory_pool + # pad over 2th last game_segment using the last game_segment + if last_game_segments[env_id] is not None: + self.pad_and_save_last_trajectory( + i, last_game_segments, last_game_priorities, game_segments, dones + ) + + # store current block trajectory + priorities = self.get_priorities(i, pred_values_lst, search_values_lst) + + # NOTE: put the last game history in one episode into the _trajectory_pool + game_segments[env_id].game_segment_to_array() + + # assert len(game_segments[env_id]) == len(priorities) + # NOTE: save the last game history in one episode into the _trajectory_pool if it's not null + if len(game_segments[env_id].reward_segment) != 0: + self._trajectory_pool.append((game_segments[env_id], priorities, dones[env_id])) + + # reset the finished env and init game_segments + if n_episode > env_nums: + init_obs = self._env.ready_obs + + if len(init_obs.keys()) != env_nums: + while env_id not in init_obs.keys(): + init_obs = self._env.ready_obs + print(f'wait the {env_id} env to reset') + + init_obs = init_obs[env_id]['observation'] + init_obs = to_ndarray(init_obs) + action_mask_dict[env_id] = to_ndarray(init_obs[env_id]['action_mask']) + + game_segments[env_id] = GameSegment( + self._env.action_space, game_segment_length=self._cfg.game_segment_length, config=self._cfg + ) + stack_obs_windows[env_id] = [init_obs for _ in range(self._cfg.model.frame_stack_num)] + game_segments[env_id].init(stack_obs_windows[env_id]) + last_game_segments[env_id] = None + last_game_priorities[env_id] = None + + pred_values_lst[env_id] = [] + search_values_lst[env_id] = [] + eps_steps_lst[env_id] = 0 + eps_reward_lst[env_id] = 0 + eps_ori_reward_lst[env_id] = 0 + visit_entropies_lst[env_id] = 0 + + self._policy.reset([env_id]) + ready_env_id.remove(env_id) + + if collected_episode >= n_episode: + L = len(self._trajectory_pool) + return_data = [self._trajectory_pool[i][0] for i in range(L)], [ + { + 'priorities': self._trajectory_pool[i][1], + 'done': self._trajectory_pool[i][2], + '_unroll_plus_td_steps': self._unroll_plus_td_steps + } for i in range(L) + ] + + del self._trajectory_pool[:] + break + ctx.trajectories = return_data + + def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_priorities, game_segments, done): + """ + Overview: + put the last game history into the pool if the current game is finished + Arguments: + - last_game_segments (:obj:`list`): list of 
the last game histories + - last_game_priorities (:obj:`list`): list of the last game priorities + - game_segments (:obj:`list`): list of the current game histories + Note: + (last_game_segments[i].obs_history[-4:][j] == game_segments[i].obs_history[:4][j]).all() is True + """ + # pad over last block trajectory + beg_index = self._cfg.model.frame_stack_num + end_index = beg_index + self._cfg.num_unroll_steps + + # the start obs is init zero obs, so we take the [ : +] obs as the pad obs + # e.g. the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs + pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] + pad_child_visits_lst = game_segments[i].child_visit_segment[:self._cfg.num_unroll_steps] + # pad_child_visits_lst = game_segments[i].child_visit_history[beg_index:end_index] + + beg_index = 0 + # self._unroll_plus_td_steps = self._cfg.num_unroll_steps + self._cfg.td_steps + end_index = beg_index + self._unroll_plus_td_steps - 1 + + pad_reward_lst = game_segments[i].reward_segment[beg_index:end_index] + + beg_index = 0 + end_index = beg_index + self._unroll_plus_td_steps + + pad_root_values_lst = game_segments[i].root_value_segment[beg_index:end_index] + + # pad over and save + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) + """ + Note: + game_segment element shape: + obs: game_segment_length + stack + num_unroll_steps, 20+4 +5 + rew: game_segment_length + stack + num_unroll_steps + td_steps -1 20 +5+3-1 + action: game_segment_length -> 20 + root_values: game_segment_length + num_unroll_steps + td_steps -> 20 +5+3 + child_visits: game_segment_length + num_unroll_steps -> 20 +5 + to_play: game_segment_length -> 20 + action_mask: game_segment_length -> 20 + """ + + last_game_segments[i].game_segment_to_array() + + # put the game history into the pool + self._trajectory_pool.append((last_game_segments[i], last_game_priorities[i], done[i])) + + # reset last game_segments + last_game_segments[i] = None + last_game_priorities[i] = None + + def get_priorities(self, i, pred_values_lst, search_values_lst): + if self._cfg.use_priority and not self._cfg.use_max_priority_for_new_data: + pred_values = torch.from_numpy(np.array(pred_values_lst[i])).to(self._cfg.device).float().view(-1) + search_values = torch.from_numpy(np.array(search_values_lst[i])).to(self._cfg.device).float().view(-1) + priorities = torch.abs(pred_values - search_values).cpu().numpy() + priorities += self._cfg.prioritized_replay_eps + else: + # priorities is None -> use the max priority for all newly collected data + priorities = None + + return priorities diff --git a/lzero/middleware/evaluator.py b/lzero/middleware/evaluator.py new file mode 100644 index 000000000..db7aad087 --- /dev/null +++ b/lzero/middleware/evaluator.py @@ -0,0 +1,105 @@ +import numpy as np +from ditk import logging +from ding.framework import task +from ding.utils import EasyTimer +from ding.torch_utils import to_ndarray, to_tensor, to_device +from ding.framework.middleware.functional.evaluator import VectorEvalMonitor +from lzero.mcts.buffer.game_segment import GameSegment +from lzero.mcts.utils import prepare_observation + + +class MuZeroEvaluator: + + def __init__( + self, + cfg, + policy, + env, + eval_freq: int = 100, + ) -> None: + self._cfg = cfg.policy + self._env = env + self._env.seed(cfg.seed, dynamic_seed=False) + self._n_episode = cfg.env.n_evaluator_episode + self._policy = policy + self._eval_freq = eval_freq + self._max_eval_reward 
= float("-inf") + self._last_eval_iter = 0 + + self._timer = EasyTimer() + self._stop_value = cfg.env.stop_value + + def __call__(self, ctx): + if ctx.last_eval_iter != -1 and \ + (ctx.train_iter - ctx.last_eval_iter < self._eval_freq): + return + ctx.last_eval_iter = ctx.train_iter + if self._env.closed: + self._env.launch() + else: + self._env.reset() + self._policy.reset() + env_nums = self._env.env_num + n_episode = self._n_episode + eval_monitor = VectorEvalMonitor(env_nums, n_episode) + assert env_nums == n_episode + + init_obs = self._env.ready_obs + action_mask = [init_obs[i]['action_mask'] for i in range(env_nums)] + action_mask_dict = {i: to_ndarray(init_obs[i]['action_mask']) for i in range(env_nums)} + + game_segments = [ + GameSegment(self._env.action_space, game_segment_length=self._cfg.game_segment_length, config=self._cfg) + for _ in range(env_nums) + ] + for i in range(env_nums): + game_segments[i].reset( + [to_ndarray(init_obs[i]['observation']) for _ in range(self._cfg.model.frame_stack_num)] + ) + + ready_env_id = set() + remain_episode = n_episode + + with self._timer: + while not eval_monitor.is_finished(): + obs = self._env.ready_obs + new_available_env_id = set(obs.keys()).difference(ready_env_id) + ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode])) + remain_episode -= min(len(new_available_env_id), remain_episode) + + action_mask_dict = {env_id: action_mask_dict[env_id] for env_id in ready_env_id} + action_mask = [action_mask_dict[env_id] for env_id in ready_env_id] + + stack_obs = [game_segments[env_id].get_obs() for env_id in ready_env_id] + stack_obs = to_ndarray(stack_obs) + stack_obs = prepare_observation(stack_obs, self._cfg.model.model_type) + stack_obs = to_tensor(stack_obs) + stack_obs = to_device(stack_obs, self._cfg.device) + + policy_output = self._policy.forward(stack_obs, action_mask) + + actions = {i: v['action'] for i, v in zip(ready_env_id, policy_output)} + timesteps = self._env.step(actions) + + for env_id, t in timesteps.items(): + i = env_id + game_segments[i].append(actions[i], t.obs['observation'], t.reward) + + if t.done: + # Env reset is done by env_manager automatically. 
+ self._policy.reset([env_id]) + reward = t.info['eval_episode_return'] + if 'episode_info' in t.info: + eval_monitor.update_info(env_id, t.info['episode_info']) + eval_monitor.update_reward(env_id, reward) + logging.info( + "[EVALUATOR]env {} finish episode, final episode_return: {}, current episode: {}".format( + env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode() + ) + ) + ready_env_id.remove(env_id) + episode_reward = eval_monitor.get_episode_return() + eval_reward = np.mean(episode_reward) + stop_flag = eval_reward >= self._stop_value and ctx.train_iter > 0 + if stop_flag: + task.finish = True diff --git a/lzero/middleware/helper.py b/lzero/middleware/helper.py new file mode 100644 index 000000000..23097962f --- /dev/null +++ b/lzero/middleware/helper.py @@ -0,0 +1,29 @@ +from lzero.policy import visit_count_temperature + + +def lr_scheduler(cfg, policy): + max_step = cfg.policy.threshold_training_steps_for_final_lr + + def _schedule(ctx): + if cfg.policy.lr_piecewise_constant_decay: + step = ctx.train_iter * cfg.policy.update_per_collect + if step < 0.5 * max_step: + policy._optimizer.lr = 0.2 + elif step < 0.75 * max_step: + policy._optimizer.lr = 0.02 + else: + policy._optimizer.lr = 0.002 + + return _schedule + + +def temperature_handler(cfg, env): + + def _handle(ctx): + step = ctx.train_iter * cfg.policy.update_per_collect + temperature = visit_count_temperature( + cfg.policy.manual_temperature_decay, 0.25, cfg.policy.threshold_training_steps_for_final_temperature, step + ) + ctx.collect_kwargs['temperature'] = temperature + + return _handle diff --git a/lzero/policy/__init__.py b/lzero/policy/__init__.py index a34930bd7..cfbaec6df 100644 --- a/lzero/policy/__init__.py +++ b/lzero/policy/__init__.py @@ -1,2 +1,6 @@ -from .scaling_transform import * -from .utils import * +from .scaling_transform import InverseScalarTransform, inverse_scalar_transform, scalar_transform, phi_transform +from .utils import to_detach_cpu_numpy, concat_output, concat_output_value, configure_optimizers, cross_entropy_loss, \ + visit_count_temperature +from .alphazero import AlphaZeroPolicy +from .muzero import MuZeroPolicy +from .efficientzero import EfficientZeroPolicy diff --git a/lzero/policy/alphazero.py b/lzero/policy/alphazero.py index 1418c3c25..0b6cd697f 100644 --- a/lzero/policy/alphazero.py +++ b/lzero/policy/alphazero.py @@ -12,7 +12,7 @@ from ding.utils.data import default_collate from lzero.mcts.ptree.ptree_az import MCTS -from lzero.policy import configure_optimizers +from .utils import configure_optimizers @POLICY_REGISTRY.register('alphazero') @@ -109,7 +109,12 @@ def _init_learn(self) -> None: self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) elif self._cfg.optim_type == 'AdamW': - self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + self._optimizer = configure_optimizers( + model=self._model, + weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, + device_type=self._cfg.device + ) if self._cfg.lr_piecewise_constant_decay: from torch.optim.lr_scheduler import LambdaLR diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 3440c9b17..82063ecd8 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -14,9 +14,9 @@ from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree 
from lzero.model import ImageTransforms -from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, select_action, to_torch_float_tensor, ez_network_output_unpack, negative_cosine_similarity, prepare_obs, \ - configure_optimizers +from .utils import to_torch_float_tensor, ez_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ + configure_optimizers, cross_entropy_loss +from .scaling_transform import scalar_transform, InverseScalarTransform, phi_transform, DiscreteSupport @POLICY_REGISTRY.register('efficientzero') @@ -193,7 +193,12 @@ def _init_learn(self) -> None: weight_decay=self._cfg.weight_decay, ) elif self._cfg.optim_type == 'AdamW': - self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + self._optimizer = configure_optimizers( + model=self._model, + weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, + device_type=self._cfg.device + ) if self._cfg.lr_piecewise_constant_decay: from torch.optim.lr_scheduler import LambdaLR @@ -465,7 +470,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # priority related # ============================================================== 'value_priority': td_data[0].flatten().mean().item(), - 'value_priority_orig': value_priority, + 'priority': value_priority, # this key must be priority to update replay buffer 'target_value_prefix': td_data[1].flatten().mean().item(), 'target_value': td_data[2].flatten().mean().item(), 'transformed_target_value_prefix': td_data[3].flatten().mean().item(), @@ -488,23 +493,23 @@ def _init_collect(self) -> None: self.collect_mcts_temperature = 1 def _forward_collect( - self, - data: torch.Tensor, - action_mask: list = None, - temperature: float = 1, - to_play: List = [-1], - ready_env_id=None - ): + self, + data: torch.Tensor, + action_mask: List[int], + temperature: float = 1, + to_play: List[int] = None, + ready_env_id: List[int] = None, + ) -> List[Dict]: """ Overview: The forward function for collecting data in collect mode. Use model to execute MCTS search. Choosing the action through sampling during the collect mode. Arguments: - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - action_mask (:obj:`List[int]`): The action mask, i.e. the action that cannot be selected. - temperature (:obj:`float`): The temperature of the policy. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - to_play (:obj:`List[int]`): The player to play. + - ready_env_id (:obj:`List[int]`): The id of the env that is ready to collect. Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ @@ -512,15 +517,20 @@ def _forward_collect( - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - temperature: :math:`(1, )`. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None + - to_play: :math:`(N, )`, where N is the number of collect_env. 
+ - ready_env_id: :math:`(N, )` Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + - output (:obj:`List[Dict]`): Each element is a dict-type data, the keys including ``action``, \ + ``distributions``, ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. """ self._collect_model.eval() self.collect_mcts_temperature = temperature active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + if to_play is None: + to_play = [-1 for _ in ready_env_id] + with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} network_output = self._collect_model.initial_inference(data) @@ -555,15 +565,11 @@ def _forward_collect( roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play ) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} - data_id = [i for i in range(active_collect_env_num)] - output = {i: None for i in data_id} - if ready_env_id is None: - ready_env_id = np.arange(active_collect_env_num) - + output = [] for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents @@ -573,14 +579,16 @@ def _forward_collect( ) # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[env_id] = { - 'action': action, - 'distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], - } + output.append( + { + 'action': action, + 'distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'value': value, + 'pred_value': pred_values[i], + 'policy_logits': policy_logits[i], + } + ) return output @@ -595,30 +603,41 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): + def _forward_eval( + self, + data: torch.Tensor, + action_mask: List[int], + to_play: List[int] = None, + ready_env_id: List[int] = None + ) -> List[Dict]: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. Choosing the action with the highest value (argmax) rather than sampling during the eval mode. Arguments: - data (:obj:`torch.Tensor`): The input data, i.e. the observation. - - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - - to_play (:obj:`int`): The player to play. - - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - action_mask (:obj:`List[int]`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`List[int]`): The player to play. + - ready_env_id (:obj:`List[int]`): The id of the env that is ready to collect. 
Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ S is the number of stacked frames, H is the height of the image, W is the width of the image. - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. - - to_play: :math:`(N, 1)`, where N is the number of collect_env. - - ready_env_id: None + - to_play: :math:`(N, )`, where N is the number of collect_env. + - ready_env_id: :math:`(N, )` Returns: - - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ - ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. + - output (:obj:`List[Dict]`): Each element is a dict-type data, the keys including ``action``, \ + ``distributions``, ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. """ self._eval_model.eval() active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + if to_play is None: + to_play = [-1 for _ in ready_env_id] + with torch.no_grad(): # data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)} network_output = self._eval_model.initial_inference(data) @@ -646,15 +665,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read roots.prepare_no_noise(value_prefix_roots, policy_logits, to_play) self._mcts_eval.search(roots, self._eval_model, latent_state_roots, reward_hidden_state_roots, to_play) - roots_visit_count_distributions = roots.get_distributions( - ) # shape: ``{list: batch_size} ->{list: action_space_size}`` + # shape: ``{list: batch_size} ->{list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} - data_id = [i for i in range(active_eval_env_num)] - output = {i: None for i in data_id} - - if ready_env_id is None: - ready_env_id = np.arange(active_eval_env_num) + output = [] for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents @@ -665,14 +680,16 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read ) # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
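            # Illustrative example (not part of the patch): if ``action_mask[i] == [0, 1, 0, 1]``,
            # the legal action set is {1, 3}, so an ``action_index_in_legal_action_set`` of 1
            # maps to action 3 via ``np.where(action_mask[i] == 1.0)[0][1]``.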
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] - output[env_id] = { - 'action': action, - 'distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], - } + output.append( + { + 'action': action, + 'distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'value': value, + 'pred_value': pred_values[i], + 'policy_logits': policy_logits[i], + } + ) return output diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 212691f1e..2df5300cf 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -13,9 +13,9 @@ from lzero.mcts import MuZeroMCTSCtree as MCTSCtree from lzero.mcts import MuZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms -from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ - configure_optimizers +from .utils import to_torch_float_tensor, mz_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ + configure_optimizers, cross_entropy_loss +from .scaling_transform import scalar_transform, InverseScalarTransform, phi_transform, DiscreteSupport @POLICY_REGISTRY.register('muzero') @@ -189,7 +189,12 @@ def _init_learn(self) -> None: self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) elif self._cfg.optim_type == 'AdamW': - self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + self._optimizer = configure_optimizers( + model=self._model, + weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, + device_type=self._cfg.device + ) if self._cfg.lr_piecewise_constant_decay: from torch.optim.lr_scheduler import LambdaLR diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 0fe9e690e..9abb58ff6 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -15,9 +15,9 @@ from lzero.mcts import SampledEfficientZeroMCTSCtree as MCTSCtree from lzero.mcts import SampledEfficientZeroMCTSPtree as MCTSPtree from lzero.model import ImageTransforms -from lzero.policy import scalar_transform, InverseScalarTransform, cross_entropy_loss, phi_transform, \ - DiscreteSupport, to_torch_float_tensor, ez_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ - configure_optimizers +from .utils import to_torch_float_tensor, ez_network_output_unpack, select_action, negative_cosine_similarity, prepare_obs, \ + configure_optimizers, cross_entropy_loss +from .scaling_transform import scalar_transform, InverseScalarTransform, phi_transform, DiscreteSupport @POLICY_REGISTRY.register('sampled_efficientzero') @@ -218,7 +218,12 @@ def _init_learn(self) -> None: self._model.parameters(), lr=self._cfg.learning_rate, weight_decay=self._cfg.weight_decay ) elif self._cfg.optim_type == 'AdamW': - self._optimizer = configure_optimizers(model=self._model, weight_decay=self._cfg.weight_decay, learning_rate=self._cfg.learning_rate, device_type=self._cfg.device) + self._optimizer = configure_optimizers( + model=self._model, + weight_decay=self._cfg.weight_decay, + learning_rate=self._cfg.learning_rate, + 
device_type=self._cfg.device + ) if self._cfg.cos_lr_scheduler is True: from torch.optim.lr_scheduler import CosineAnnealingLR diff --git a/lzero/policy/scaling_transform.py b/lzero/policy/scaling_transform.py index 75b170612..9e6d86adb 100644 --- a/lzero/policy/scaling_transform.py +++ b/lzero/policy/scaling_transform.py @@ -103,26 +103,17 @@ def __call__(self, logits: torch.Tensor, epsilon: float = 0.001) -> torch.Tensor return output -def visit_count_temperature( - manual_temperature_decay: bool, fixed_temperature_value: float, - threshold_training_steps_for_final_lr_temperature: int, trained_steps: int -) -> float: - if manual_temperature_decay: - if trained_steps < 0.5 * threshold_training_steps_for_final_lr_temperature: - return 1.0 - elif trained_steps < 0.75 * threshold_training_steps_for_final_lr_temperature: - return 0.5 - else: - return 0.25 - else: - return fixed_temperature_value - - def phi_transform(discrete_support: DiscreteSupport, x: torch.Tensor) -> torch.Tensor: """ Overview: We then apply a transformation ``phi`` to the scalar in order to obtain equivalent categorical representations. After this transformation, each scalar is represented as the linear combination of its two adjacent supports. + Arguments: + - discrete_support (:obj:`DiscreteSupport`): The discrete support used in categorical representations of \ + reward, value or value_prefix. + - x (:obj:`torch.Tensor`): The input tensor. + Returns: + - target (:obj:`torch.Tensor`): The output transformed tensor. Reference: - MuZero paper Appendix F: Network Architecture. """ @@ -143,7 +134,3 @@ def phi_transform(discrete_support: DiscreteSupport, x: torch.Tensor) -> torch.T target.scatter_(2, x_low_idx.long().unsqueeze(-1), p_low.unsqueeze(-1)) return target - - -def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: - return -(torch.log_softmax(prediction, dim=1) * target).sum(1) diff --git a/lzero/policy/utils.py b/lzero/policy/utils.py index 0315312ad..e81f2091e 100644 --- a/lzero/policy/utils.py +++ b/lzero/policy/utils.py @@ -1,16 +1,13 @@ -import logging from typing import List, Tuple, Dict from easydict import EasyDict -import numpy as np -import torch -from scipy.stats import entropy -import math +from ditk import logging import inspect - +import numpy as np import torch import torch.nn as nn -from torch.nn import functional as F +import torch.nn.functional as F +from scipy.stats import entropy class LayerNorm(nn.Module): @@ -25,9 +22,13 @@ def forward(self, input): return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) -def configure_optimizers(model: nn.Module, weight_decay: float = 0, learning_rate: float = 3e-3, - betas: tuple = (0.9, 0.999), device_type: str = "cuda"): - +def configure_optimizers( + model: nn.Module, + weight_decay: float = 0, + learning_rate: float = 3e-3, + betas: tuple = (0.9, 0.999), + device_type: str = "cuda" +): """ Overview: This function is adapted from https://github.com/karpathy/nanoGPT/blob/master/model.py @@ -49,7 +50,9 @@ def configure_optimizers(model: nn.Module, weight_decay: float = 0, learning_rat decay = set() no_decay = set() whitelist_weight_modules = (torch.nn.Linear, torch.nn.LSTM, nn.Conv2d) - blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d) + blacklist_weight_modules = ( + torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding, torch.nn.BatchNorm1d, torch.nn.BatchNorm2d + ) for mn, m in model.named_modules(): for pn, p in 
m.named_parameters(): fpn = '%s.%s' % (mn, pn) if mn else pn # full param name @@ -62,7 +65,8 @@ def configure_optimizers(model: nn.Module, weight_decay: float = 0, learning_rat elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): # weights of whitelist modules will be weight decayed decay.add(fpn) - elif (pn.endswith('weight_ih_l0') or pn.endswith('weight_hh_l0')) and isinstance(m, whitelist_weight_modules): + elif (pn.endswith('weight_ih_l0') or pn.endswith('weight_hh_l0')) and isinstance(m, + whitelist_weight_modules): # some special weights of whitelist modules will be weight decayed decay.add(fpn) elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): @@ -83,15 +87,21 @@ def configure_optimizers(model: nn.Module, weight_decay: float = 0, learning_rat param_dict = {pn: p for pn, p in model.named_parameters()} inter_params = decay & no_decay union_params = decay | no_decay - assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params),) + assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) assert len( param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ % (str(param_dict.keys() - union_params),) # create the pytorch optimizer object optim_groups = [ - {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, - {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, + { + "params": [param_dict[pn] for pn in sorted(list(decay))], + "weight_decay": weight_decay + }, + { + "params": [param_dict[pn] for pn in sorted(list(no_decay))], + "weight_decay": 0.0 + }, ] # new PyTorch nightly has a new 'fused' option for AdamW that is much faster use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters) @@ -202,14 +212,14 @@ def get_max_entropy(action_shape: int) -> np.float32: return -action_shape * p * np.log2(p) -def select_action(visit_counts: np.ndarray, +def select_action(visit_counts: List, temperature: float = 1, deterministic: bool = True) -> Tuple[np.int64, np.ndarray]: """ Overview: Select action from visit counts of the root node. Arguments: - - visit_counts (:obj:`np.ndarray`): The visit counts of the root node. + - visit_counts (:obj:`List`): The visit counts of the root node. - temperature (:obj:`float`): The temperature used to adjust the sampling distribution. - deterministic (:obj:`bool`): Whether to enable deterministic mode in action selection. True means to \ select the argmax result, False indicates to sample action from the distribution. 
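
The ``select_action`` helper above maps root visit counts to an action. As an illustration only (not the exact implementation in lzero/policy/utils.py), a minimal sketch of the standard MuZero recipe it follows: argmax over visit counts in eval mode, temperature-shaped sampling in collect mode.

import numpy as np

def sample_action_from_visit_counts(visit_counts, temperature=1.0, deterministic=False):
    counts = np.asarray(visit_counts, dtype=np.float64)
    if deterministic:
        # eval mode: pick the most-visited child of the root
        return int(np.argmax(counts))
    # collect mode: sample from visit counts sharpened/flattened by the temperature
    probs = counts ** (1.0 / temperature)
    probs /= probs.sum()
    return int(np.random.choice(len(counts), p=probs))

# e.g. visit_counts = [10, 30, 60]: temperature=1.0 keeps the MCTS proportions (0.1/0.3/0.6),
# while temperature=0.25 concentrates almost all probability mass on the most-visited action.
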
@@ -347,3 +357,22 @@ def mz_network_output_unpack(network_output: Dict) -> Tuple: value = network_output.value # shape: (batch_size, support_support_size) policy_logits = network_output.policy_logits # shape: (batch_size, action_space_size) return latent_state, reward, value, policy_logits + + +def visit_count_temperature( + manual_temperature_decay: bool, fixed_temperature_value: float, + threshold_training_steps_for_final_lr_temperature: int, trained_steps: int +) -> float: + if manual_temperature_decay: + if trained_steps < 0.5 * threshold_training_steps_for_final_lr_temperature: + return 1.0 + elif trained_steps < 0.75 * threshold_training_steps_for_final_lr_temperature: + return 0.5 + else: + return 0.25 + else: + return fixed_temperature_value + + +def cross_entropy_loss(prediction: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + return -(torch.log_softmax(prediction, dim=1) * target).sum(1) diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 5cb5be377..07bb7de30 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -1,14 +1,14 @@ -import time -from collections import deque, namedtuple from typing import Optional, Any, List +from collections import deque, namedtuple +import time import numpy as np import torch +from torch.nn import L1Loss from ding.envs import BaseEnvManager from ding.torch_utils import to_ndarray from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY from ding.worker.collector.base_serial_collector import ISerialCollector -from torch.nn import L1Loss from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation @@ -261,8 +261,6 @@ def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_prioriti last_game_segments[i] = None last_game_priorities[i] = None - return None - def collect(self, n_episode: Optional[int] = None, train_iter: int = 0, diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 04d6fece9..7213e14d0 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -1,11 +1,11 @@ -import time -import copy -from collections import namedtuple from typing import Optional, Callable, Tuple +from collections import namedtuple +from easydict import EasyDict +import time +import copy import numpy as np import torch -from easydict import EasyDict from ding.envs import BaseEnvManager from ding.torch_utils import to_ndarray diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py index 17147b374..7a933af3c 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py @@ -9,7 +9,7 @@ num_simulations = 50 update_per_collect = 200 batch_size = 256 -max_env_step = int(5e6) +max_env_step = int(5e5) reanalyze_ratio = 0. 
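# Note (illustrative, not part of the original config): ``reanalyze_ratio`` is the fraction of
# sampled targets whose policy is recomputed by re-running MCTS with the latest target model
# (MuZero Reanalyze); 0. keeps the targets stored at collection time.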
# ============================================================== # end of the most frequently changed config specified by the user diff --git a/zoo/classic_control/cartpole/config/cartpole_efficientzero_task_config.py b/zoo/classic_control/cartpole/config/cartpole_efficientzero_task_config.py new file mode 100644 index 000000000..474af7372 --- /dev/null +++ b/zoo/classic_control/cartpole/config/cartpole_efficientzero_task_config.py @@ -0,0 +1,133 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(1e5) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cartpole_efficientzero_config = dict( + exp_name='data_ez_ctree/cartpole_efficientzero_task_seed0', + env=dict( + env_name='CartPole-v0', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=4, + action_space_size=2, + model_type='mlp', # options={'mlp', 'conv'} + lstm_hidden_size=128, + latent_state_dim=128, + ), + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e2), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) + +cartpole_efficientzero_config = EasyDict(cartpole_efficientzero_config) +main_config = cartpole_efficientzero_config + +cartpole_efficientzero_create_config = dict( + env=dict( + type='cartpole_lightzero', + import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'], + ), + env_manager=dict(type='base'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +cartpole_efficientzero_create_config = EasyDict(cartpole_efficientzero_create_config) +create_config = cartpole_efficientzero_create_config + +if __name__ == "__main__": + from functools import partial + from ditk import logging + from ding.config import compile_config + from ding.envs import create_env_manager, get_vec_env_setting + from ding.framework import task, ding_init + from ding.framework.context import OnlineRLContext + from ding.framework.middleware import ContextExchanger, ModelExchanger, CkptSaver, trainer, \ + termination_checker, online_logger + from ding.utils import set_pkg_seed + from lzero.policy import EfficientZeroPolicy + from lzero.mcts import EfficientZeroGameBuffer + from lzero.middleware import MuZeroEvaluator, MuZeroCollector, temperature_handler, data_reanalyze_fetcher, \ + lr_scheduler, data_pusher + + logging.getLogger().setLevel(logging.INFO) + main_config.policy.device = 'cpu' # ['cpu', 'cuda'] + cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) + ding_init(cfg) + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = EfficientZeroPolicy(cfg.policy, enable_field=['learn', 'collect', 'eval']) + replay_buffer = EfficientZeroGameBuffer(cfg.policy) + + with task.start(ctx=OnlineRLContext()): + + # Consider the case with multiple processes + if task.router.is_active: + # You can use labels to distinguish between workers with different roles, + # here we use node_id to distinguish. + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + elif task.router.node_id == 1: + task.add_role(task.role.EVALUATOR) + elif task.router.node_id == 2: + task.add_role(task.role.REANALYZER) + else: + task.add_role(task.role.COLLECTOR) + + # Sync their context and model between each worker. + task.use(ContextExchanger(skip_n_iter=1)) + task.use(ModelExchanger(policy.model)) + + # Here is the part of single process pipeline. 
+        task.use(MuZeroEvaluator(cfg, policy.eval_mode, evaluator_env, eval_freq=100))
+        task.use(temperature_handler(cfg, collector_env))
+        task.use(MuZeroCollector(cfg, policy.collect_mode, collector_env))
+        task.use(data_pusher(replay_buffer))
+        task.use(data_reanalyze_fetcher(cfg, policy, replay_buffer))
+        task.use(trainer(cfg, policy.learn_mode))
+        task.use(lr_scheduler(cfg, policy))
+        task.use(online_logger(train_show_freq=10))
+        task.use(CkptSaver(policy, cfg.exp_name, train_freq=int(1e4)))
+        task.use(termination_checker(max_env_step=int(max_env_step)))
+        task.run()

From f8b3830c4c6669474d25261d11acb3f898055ed1 Mon Sep 17 00:00:00 2001
From: niuyazhe
Date: Wed, 26 Apr 2023 19:47:33 +0800
Subject: [PATCH 2/5] fix(nyz): add data processor which was previously ignored

---
 lzero/middleware/data_processor.py | 40 ++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)
 create mode 100644 lzero/middleware/data_processor.py

diff --git a/lzero/middleware/data_processor.py b/lzero/middleware/data_processor.py
new file mode 100644
index 000000000..88284bd6a
--- /dev/null
+++ b/lzero/middleware/data_processor.py
@@ -0,0 +1,40 @@
+from typing import Callable, TYPE_CHECKING
+from easydict import EasyDict
+from ding.utils import one_time_warning
+
+if TYPE_CHECKING:
+    from ding.policy import Policy
+    from lzero.mcts import GameBuffer
+
+
+def data_pusher(replay_buffer: 'GameBuffer') -> Callable:
+
+    def _push(ctx):
+        if ctx.trajectories is not None:  # the collector will skip when update_per_collect is not reached
+            new_data = ctx.trajectories
+            replay_buffer.push_game_segments(new_data)
+            replay_buffer.remove_oldest_data_to_fit()
+
+    return _push
+
+
+def data_reanalyze_fetcher(cfg: EasyDict, policy: 'Policy', replay_buffer: 'GameBuffer') -> Callable:
+    B = cfg.policy.batch_size
+
+    def _fetch(ctx):
+        if replay_buffer.get_num_of_transitions() > B:
+            ctx.train_data = replay_buffer.sample(B, policy)
+
+            yield
+
+            if cfg.policy.use_priority:
+                replay_buffer.update_priority(ctx.train_data, ctx.train_output['priority'])
+        else:
+            one_time_warning(
+                f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+                f'batch_size: {B}, '
+                f'{replay_buffer} '
+                f'continue to collect now ...'
+ ) + + return _fetch From 78fa13e59366f0ad95d70ba5b06a5c859c8daa03 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Wed, 26 Apr 2023 20:12:02 +0800 Subject: [PATCH 3/5] fix(nyz): fix metadata name bug --- lzero/middleware/collector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lzero/middleware/collector.py b/lzero/middleware/collector.py index 7fd26fe5e..8442732c2 100644 --- a/lzero/middleware/collector.py +++ b/lzero/middleware/collector.py @@ -214,7 +214,7 @@ def __call__(self, ctx): { 'priorities': self._trajectory_pool[i][1], 'done': self._trajectory_pool[i][2], - '_unroll_plus_td_steps': self._unroll_plus_td_steps + 'unroll_plus_td_steps': self._unroll_plus_td_steps } for i in range(L) ] From b610842ba8ddb38105d87e81600bc8ca7df278e6 Mon Sep 17 00:00:00 2001 From: niuyazhe Date: Sat, 6 May 2023 22:30:28 +0800 Subject: [PATCH 4/5] fix(nyz): fix old pipeline compatibility bug for list output --- lzero/entry/train_muzero.py | 2 +- lzero/middleware/data_processor.py | 2 +- lzero/policy/efficientzero.py | 2 +- lzero/worker/muzero_collector.py | 37 ++++++++---------------------- lzero/worker/muzero_evaluator.py | 35 +++++----------------------- 5 files changed, 19 insertions(+), 59 deletions(-) diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py index c981813e9..97eff0e58 100644 --- a/lzero/entry/train_muzero.py +++ b/lzero/entry/train_muzero.py @@ -150,7 +150,7 @@ def train_muzero( log_vars = learner.train(train_data, collector.envstep) if cfg.policy.use_priority: - replay_buffer.update_priority(train_data, log_vars[0]['priority']) + replay_buffer.update_priority(train_data, log_vars[0]['td_error_priority']) if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: break diff --git a/lzero/middleware/data_processor.py b/lzero/middleware/data_processor.py index 88284bd6a..4649cfa36 100644 --- a/lzero/middleware/data_processor.py +++ b/lzero/middleware/data_processor.py @@ -28,7 +28,7 @@ def _fetch(ctx): yield if cfg.policy.use_priority: - replay_buffer.update_priority(ctx.train_data, ctx.train_output['priority']) + replay_buffer.update_priority(ctx.train_data, ctx.train_output['td_error_priority']) else: one_time_warning( f'The data in replay_buffer is not sufficient to sample a mini-batch: ' diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index 82063ecd8..6788c30e3 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -470,7 +470,7 @@ def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: # priority related # ============================================================== 'value_priority': td_data[0].flatten().mean().item(), - 'priority': value_priority, # this key must be priority to update replay buffer + 'td_error_priority': value_priority, # this key must be priority to update replay buffer 'target_value_prefix': td_data[1].flatten().mean().item(), 'target_value': td_data[2].flatten().mean().item(), 'transformed_target_value_prefix': td_data[3].flatten().mean().item(), diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 07bb7de30..0274e1534 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -369,36 +369,19 @@ def collect(self, # ============================================================== policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play) - actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} - distributions_dict_no_env_id = {k: 
v['distributions'] for k, v in policy_output.items()} - if self.policy_config.sampled_algo: - root_sampled_actions_dict_no_env_id = { - k: v['root_sampled_actions'] - for k, v in policy_output.items() - } - value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} - pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} - visit_entropy_dict_no_env_id = { + actions = {k: v['action'] for k, v in zip(ready_env_id, policy_output)} + distributions_dict = {k: v['distributions'] for k, v in zip(ready_env_id, policy_output)} + value_dict = {k: v['value'] for k, v in zip(ready_env_id, policy_output)} + pred_value_dict = {k: v['pred_value'] for k, v in zip(ready_env_id, policy_output)} + visit_entropy_dict = { k: v['visit_count_distribution_entropy'] - for k, v in policy_output.items() + for k, v in zip(ready_env_id, policy_output) } - - # TODO(pu): subprocess - actions = {} - distributions_dict = {} if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - value_dict = {} - pred_value_dict = {} - visit_entropy_dict = {} - for index, env_id in enumerate(ready_env_id): - actions[env_id] = actions_no_env_id.pop(index) - distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) - if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = root_sampled_actions_dict_no_env_id.pop(index) - value_dict[env_id] = value_dict_no_env_id.pop(index) - pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index) - visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index) + root_sampled_actions_dict = { + k: v['root_sampled_actions'] + for k, v in zip(ready_env_id, policy_output) + } # ============================================================== # Interact with env. diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 7213e14d0..89c599f67 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -274,37 +274,14 @@ def eval( # ============================================================== policy_output = self._policy.forward(stack_obs, action_mask, to_play) - actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} - distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} - if self.policy_config.sampled_algo: - root_sampled_actions_dict_no_env_id = { - k: v['root_sampled_actions'] - for k, v in policy_output.items() - } - - value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} - pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} - visit_entropy_dict_no_env_id = { + actions = {k: v['action'] for k, v in zip(ready_env_id, policy_output)} + distributions_dict = {k: v['distributions'] for k, v in zip(ready_env_id, policy_output)} + value_dict = {k: v['value'] for k, v in zip(ready_env_id, policy_output)} + pred_value_dict = {k: v['pred_value'] for k, v in zip(ready_env_id, policy_output)} + visit_entropy_dict = { k: v['visit_count_distribution_entropy'] - for k, v in policy_output.items() + for k, v in zip(ready_env_id, policy_output) } - - actions = {} - distributions_dict = {} - if self.policy_config.sampled_algo: - root_sampled_actions_dict = {} - value_dict = {} - pred_value_dict = {} - visit_entropy_dict = {} - for index, env_id in enumerate(ready_env_id): - actions[env_id] = actions_no_env_id.pop(index) - distributions_dict[env_id] = distributions_dict_no_env_id.pop(index) - if self.policy_config.sampled_algo: - root_sampled_actions_dict[env_id] = 
root_sampled_actions_dict_no_env_id.pop(index) - value_dict[env_id] = value_dict_no_env_id.pop(index) - pred_value_dict[env_id] = pred_value_dict_no_env_id.pop(index) - visit_entropy_dict[env_id] = visit_entropy_dict_no_env_id.pop(index) - # ============================================================== # Interact with env. # ============================================================== From 4b2a99daa497d3460259fc8c86036fabb173a47b Mon Sep 17 00:00:00 2001 From: yangzhenjie Date: Tue, 9 May 2023 18:18:22 +0800 Subject: [PATCH 5/5] polish(yzj):polish atari pong config on new pipeline --- .../config/atari_efficientzero_task_config.py | 146 ++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 zoo/atari/config/atari_efficientzero_task_config.py diff --git a/zoo/atari/config/atari_efficientzero_task_config.py b/zoo/atari/config/atari_efficientzero_task_config.py new file mode 100644 index 000000000..4d8bf4c88 --- /dev/null +++ b/zoo/atari/config/atari_efficientzero_task_config.py @@ -0,0 +1,146 @@ +from easydict import EasyDict + +# options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} +env_name = 'PongNoFrameskip-v4' + +if env_name == 'PongNoFrameskip-v4': + action_space_size = 6 +elif env_name == 'QbertNoFrameskip-v4': + action_space_size = 6 +elif env_name == 'MsPacmanNoFrameskip-v4': + action_space_size = 9 +elif env_name == 'SpaceInvadersNoFrameskip-v4': + action_space_size = 6 +elif env_name == 'BreakoutNoFrameskip-v4': + action_space_size = 4 + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +atari_efficientzero_config = dict( + exp_name= + f'data_ez_ctree/{env_name[:-14]}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + env=dict( + env_name=env_name, + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=action_space_size, + downsample=True, + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), +) +atari_efficientzero_config = EasyDict(atari_efficientzero_config) +main_config = atari_efficientzero_config + +atari_efficientzero_create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), + collector=dict( + type='episode_muzero', + import_names=['lzero.worker.muzero_collector'], + ) +) +atari_efficientzero_create_config = EasyDict(atari_efficientzero_create_config) +create_config = atari_efficientzero_create_config + +if __name__ == "__main__": + from functools import partial + from ditk import logging + from ding.config import compile_config + from ding.envs import create_env_manager, get_vec_env_setting + from ding.framework import task, ding_init + from ding.framework.context import OnlineRLContext + from ding.framework.middleware import ContextExchanger, ModelExchanger, CkptSaver, trainer, \ + termination_checker, online_logger + from ding.utils import set_pkg_seed + from lzero.policy import EfficientZeroPolicy + from lzero.mcts import EfficientZeroGameBuffer + from lzero.middleware import MuZeroEvaluator, MuZeroCollector, temperature_handler, data_reanalyze_fetcher, \ + lr_scheduler, data_pusher + + logging.getLogger().setLevel(logging.INFO) + main_config.policy.device = 'cuda' # ['cpu', 'cuda'] + cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0) + ding_init(cfg) + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = EfficientZeroPolicy(cfg.policy, enable_field=['learn', 'collect', 'eval']) + replay_buffer = EfficientZeroGameBuffer(cfg.policy) + + with task.start(ctx=OnlineRLContext()): + + # Consider the case with multiple processes + if task.router.is_active: + # You can use labels to distinguish between workers with different roles, + # here we use node_id to distinguish. + if task.router.node_id == 0: + task.add_role(task.role.LEARNER) + elif task.router.node_id == 1: + task.add_role(task.role.EVALUATOR) + elif task.router.node_id == 2: + task.add_role(task.role.REANALYZER) + else: + task.add_role(task.role.COLLECTOR) + + # Sync their context and model between each worker. + task.use(ContextExchanger(skip_n_iter=1)) + task.use(ModelExchanger(policy.model)) + + # Here is the part of single process pipeline. + task.use(MuZeroEvaluator(cfg, policy.eval_mode, evaluator_env, eval_freq=100)) + task.use(temperature_handler(cfg, collector_env)) + task.use(MuZeroCollector(cfg, policy.collect_mode, collector_env)) + task.use(data_pusher(replay_buffer)) + task.use(data_reanalyze_fetcher(cfg, policy, replay_buffer)) + task.use(trainer(cfg, policy.learn_mode)) + task.use(lr_scheduler(cfg, policy)) + task.use(online_logger(train_show_freq=10)) + task.use(CkptSaver(policy, cfg.exp_name, train_freq=int(1e4))) + task.use(termination_checker(max_env_step=int(max_env_step))) + task.run()
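
Note on the middleware interface used by the two task-pipeline configs above: every object passed to task.use
is a callable that receives the run context, either a class instance such as MuZeroCollector (its
__call__(self, ctx) appears in lzero/middleware/collector.py) or a factory function returning such a callable,
as data_pusher and data_reanalyze_fetcher do in lzero/middleware/data_processor.py. A minimal custom middleware
in the same style might look like the sketch below; it is illustrative only: the log_buffer_size name is
hypothetical, and ctx.train_iter is assumed to be provided by DI-engine's OnlineRLContext.

    from typing import Callable


    def log_buffer_size(replay_buffer) -> Callable:
        # Illustrative middleware: report replay buffer occupancy once per pipeline iteration.

        def _log(ctx):
            # get_num_of_transitions() is the same GameBuffer method used by data_reanalyze_fetcher above.
            print(f'train_iter={ctx.train_iter}, transitions={replay_buffer.get_num_of_transitions()}')

        return _log

    # Usage, inside the "with task.start(...)" block of either config:
    #   task.use(log_buffer_size(replay_buffer))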