From d845648f43d44ebf96f5ef58813d08d29d7ea199 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <48008469+puyuan1996@users.noreply.github.com> Date: Mon, 30 Oct 2023 17:26:16 +0800 Subject: [PATCH] feature(pu): add random_policy support for continuous env (#118) * polish(pu): add random_policy support for continuous env * polish(pu): polish some variables name and redundant code * fix(pu): fix mcts_ptree/mcts_ctree unittest * fix(pu): fix test_mcts_sampled_ctree --- lzero/entry/utils.py | 2 +- lzero/mcts/tests/cprofile_mcts_ptree.py | 4 +- lzero/mcts/tests/eval_tree_speed.py | 4 +- lzero/mcts/tests/test_mcts_ctree.py | 48 ++++--- lzero/policy/efficientzero.py | 20 +-- lzero/policy/gumbel_muzero.py | 20 +-- lzero/policy/muzero.py | 22 ++-- lzero/policy/random_policy.py | 123 +++++++++++++----- lzero/policy/sampled_efficientzero.py | 43 +++--- lzero/policy/stochastic_muzero.py | 20 +-- lzero/worker/muzero_collector.py | 6 +- ...ander_cont_sampled_efficientzero_config.py | 2 + 12 files changed, 196 insertions(+), 118 deletions(-) diff --git a/lzero/entry/utils.py b/lzero/entry/utils.py index b11e37d0a..8e26bc506 100644 --- a/lzero/entry/utils.py +++ b/lzero/entry/utils.py @@ -17,7 +17,7 @@ def random_collect( ) -> None: # noqa assert policy_cfg.random_collect_episode_num > 0 - random_policy = RandomPolicy(cfg=policy_cfg) + random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space) # set the policy to random policy collector.reset_policy(random_policy.collect_mode) diff --git a/lzero/mcts/tests/cprofile_mcts_ptree.py b/lzero/mcts/tests/cprofile_mcts_ptree.py index 4edf376ec..956ec39fa 100644 --- a/lzero/mcts/tests/cprofile_mcts_ptree.py +++ b/lzero/mcts/tests/cprofile_mcts_ptree.py @@ -27,7 +27,7 @@ def initial_inference(self, observation): reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16))) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, @@ -45,7 +45,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions): policy_logits = torch.zeros(size=(batch_size, self.action_num)) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, diff --git a/lzero/mcts/tests/eval_tree_speed.py b/lzero/mcts/tests/eval_tree_speed.py index b80957dbf..c7134f3b3 100644 --- a/lzero/mcts/tests/eval_tree_speed.py +++ b/lzero/mcts/tests/eval_tree_speed.py @@ -32,7 +32,7 @@ def initial_inference(self, observation): reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16))) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, @@ -50,7 +50,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions): policy_logits = torch.zeros(size=(batch_size, self.action_num)) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, diff --git a/lzero/mcts/tests/test_mcts_ctree.py b/lzero/mcts/tests/test_mcts_ctree.py index f15fe6780..e9a6424f0 100644 --- a/lzero/mcts/tests/test_mcts_ctree.py +++ b/lzero/mcts/tests/test_mcts_ctree.py @@ -5,7 +5,6 @@ from lzero.policy import inverse_scalar_transform, select_action - policy = 'GumbelMuZero' if policy == 'EfficientZero': @@ -14,6 
+13,8 @@ from lzero.mcts.tree_search.mcts_ctree import GumbelMuZeroMCTSCtree as MCTSCtree else: raise KeyError('Only support test for EfficientZero and GumbelMuZero.') + + class MuZeroModelFake(torch.nn.Module): """ Overview: @@ -50,7 +51,7 @@ def initial_inference(self, observation): def recurrent_inference(self, latent_states, reward_hidden_states, actions=None): if policy == 'GumbelMuZero': - assert actions==None + assert actions == None actions = reward_hidden_states batch_size = latent_states.shape[0] latent_state = torch.zeros(size=(batch_size, 12, 3, 3)) @@ -78,7 +79,7 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None) batch_size=16, pb_c_base=1, pb_c_init=1, - max_num_considered_actions = 6, + max_num_considered_actions=6, discount_factor=0.9, root_dirichlet_alpha=0.3, root_noise_weight=0.2, @@ -163,17 +164,18 @@ def test_mcts_vs_bot_to_play(): policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)] ) MCTSCtree(policy_config - ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)]) + ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)]) elif policy == 'GumbelMuZero': roots.prepare( - policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)] + policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, + [0 for _ in range(env_nums)] ) MCTSCtree(policy_config - ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)]) + ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)]) roots_distributions = roots.get_distributions() roots_values = roots.get_values() assert np.array(roots_distributions).shape == (batch_size, action_space_size) - assert np.array(roots_values).shape == (batch_size, ) + assert np.array(roots_values).shape == (batch_size,) @pytest.mark.unittest @@ -220,17 +222,18 @@ def test_mcts_vs_bot_to_play_large(): policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)] ) MCTSCtree(policy_config - ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)]) + ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)]) elif policy == 'GumbelMuZero': roots.prepare( - policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)] + policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, + [0 for _ in range(env_nums)] ) MCTSCtree(policy_config - ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)]) + ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)]) roots_distributions = roots.get_distributions() roots_values = roots.get_values() assert np.array(roots_distributions).shape == (policy_config.batch_size, policy_config.model.action_space_size) - assert np.array(roots_values).shape == (policy_config.batch_size, ) + assert np.array(roots_values).shape == (policy_config.batch_size,) @pytest.mark.unittest @@ -250,13 +253,14 @@ def test_mcts_vs_bot_to_play_legal_action(): policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)] ) MCTSCtree(policy_config - ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ 
in range(env_nums)]) + ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)]) elif policy == 'GumbelMuZero': roots.prepare( - policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)] + policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, + [0 for _ in range(env_nums)] ) MCTSCtree(policy_config - ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)]) + ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)]) roots_distributions = roots.get_distributions() roots_values = roots.get_values() assert len(roots_values) == env_nums @@ -290,17 +294,18 @@ def test_mcts_self_play(): policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)] ) MCTSCtree(policy_config - ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)]) + ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)]) elif policy == 'GumbelMuZero': roots.prepare( - policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)] + policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, + [0 for _ in range(env_nums)] ) MCTSCtree(policy_config - ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)]) + ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)]) roots_distributions = roots.get_distributions() roots_values = roots.get_values() assert np.array(roots_distributions).shape == (batch_size, action_space_size) - assert np.array(roots_values).shape == (batch_size, ) + assert np.array(roots_values).shape == (batch_size,) @pytest.mark.unittest @@ -319,13 +324,14 @@ def test_mcts_self_play_legal_action(): policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)] ) MCTSCtree(policy_config - ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)]) + ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)]) elif policy == 'GumbelMuZero': roots.prepare( - policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)] + policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, + [0 for _ in range(env_nums)] ) MCTSCtree(policy_config - ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)]) + ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)]) roots_distributions = roots.get_distributions() roots_values = roots.get_values() assert len(roots_values) == env_nums diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index b18a9e297..685e473d7 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -155,6 +155,8 @@ class EfficientZeroPolicy(MuZeroPolicy): # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. 
+ use_ture_chance_label_in_chance_encoder=False, # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. @@ -624,11 +626,11 @@ def _forward_collect( action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output @@ -644,7 +646,7 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,): """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -716,11 +718,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output diff --git a/lzero/policy/gumbel_muzero.py b/lzero/policy/gumbel_muzero.py index b4f88a387..218ab99b6 100644 --- a/lzero/policy/gumbel_muzero.py +++ b/lzero/policy/gumbel_muzero.py @@ -486,7 +486,7 @@ def _forward_collect( action_mask: list = None, temperature: float = 1, to_play: List = [-1], - ready_env_id=None + ready_env_id: np.array = None, ) -> Dict: """ Overview: @@ -572,13 +572,13 @@ def _forward_collect( action = np.argmax([v for v in valid_value]) output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, + 'searched_value': value, 'roots_completed_value': roots_completed_value, 'improved_policy_probs': improved_policy_probs, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output @@ -594,7 +594,7 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. 
@@ -674,11 +674,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index b6c056982..b4094e0a0 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -159,7 +159,7 @@ class MuZeroPolicy(Policy): # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. fixed_temperature_value=0.25, - # (bool) Whether to use the true chance in MCTS. + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. use_ture_chance_label_in_chance_encoder=False, # ****** Priority ****** @@ -515,7 +515,7 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id: List = None, + ready_env_id: np.array = None, ) -> Dict: """ Overview: @@ -600,11 +600,11 @@ def _forward_collect( action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output @@ -644,7 +644,7 @@ def _get_target_obs_index_in_step_k(self, step): end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) return beg_index, end_index - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. 
@@ -714,11 +714,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output diff --git a/lzero/policy/random_policy.py b/lzero/policy/random_policy.py index 5a2c62c53..735a4122d 100644 --- a/lzero/policy/random_policy.py +++ b/lzero/policy/random_policy.py @@ -19,7 +19,8 @@ def __init__( self, cfg: dict, model: Optional[Union[type, torch.nn.Module]] = None, - enable_field: Optional[List[str]] = None + enable_field: Optional[List[str]] = None, + action_space: Any = None, ): if cfg.type == 'muzero': from lzero.mcts import MuZeroMCTSCtree as MCTSCtree @@ -27,10 +28,14 @@ def __init__( elif cfg.type == 'efficientzero': from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree + elif cfg.type == 'sampled_efficientzero': + from lzero.mcts import SampledEfficientZeroMCTSCtree as MCTSCtree + from lzero.mcts import SampledEfficientZeroMCTSPtree as MCTSPtree else: raise NotImplementedError("need to implement pipeline: {}".format(cfg.type)) self.MCTSCtree = MCTSCtree self.MCTSPtree = MCTSPtree + self.action_space = action_space super().__init__(cfg, model, enable_field) def default_model(self) -> Tuple[str, List[str]]: @@ -50,6 +55,8 @@ def default_model(self) -> Tuple[str, List[str]]: return 'EfficientZeroModel', ['lzero.model.efficientzero_model'] elif self._cfg.type == 'muzero': return 'MuZeroModel', ['lzero.model.muzero_model'] + elif self._cfg.type == 'sampled_efficientzero': + return 'SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model'] else: raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) elif self._cfg.model.model_type == "mlp": @@ -57,14 +64,16 @@ def default_model(self) -> Tuple[str, List[str]]: return 'EfficientZeroModelMLP', ['lzero.model.efficientzero_model_mlp'] elif self._cfg.type == 'muzero': return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp'] + elif self._cfg.type == 'sampled_efficientzero': + return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp'] else: raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) def _init_collect(self) -> None: """ - Overview: - Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. - """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ self._collect_model = self._model if self._cfg.mcts_ctree: self._mcts_collect = self.MCTSCtree(self._cfg) @@ -83,8 +92,8 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id = None - ): + ready_env_id: np.array = None, + ) -> Dict: """ Overview: The forward function for collecting data in collect mode. Use model to execute MCTS search. @@ -114,7 +123,7 @@ def _forward_collect( with torch.no_grad(): # data shape [B, S x C, W, H], e.g. 
{Tensor:(B, 12, 96, 96)} network_output = self._collect_model.initial_inference(data) - if self._cfg.type == 'efficientzero': + if self._cfg.type in ['efficientzero', 'sampled_efficientzero']: latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack( network_output ) @@ -125,26 +134,56 @@ def _forward_collect( pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() latent_state_roots = latent_state_roots.detach().cpu().numpy() - if self._cfg.type == 'efficientzero': + if self._cfg.type in ['efficientzero', 'sampled_efficientzero']: reward_hidden_state_roots = ( reward_hidden_state_roots[0].detach().cpu().numpy(), reward_hidden_state_roots[1].detach().cpu().numpy() ) policy_logits = policy_logits.detach().cpu().numpy().tolist() - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + if self._cfg.model.continuous_action_space: + # when the action space of the environment is continuous, action_mask[:] is None. + # NOTE: in continuous action space env: we set all legal_actions as -1 + legal_actions = [ + [-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(active_collect_env_num) + ] + else: + legal_actions = [ + [i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num) + ] + # the only difference between collect and eval is the dirichlet noise. - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) - ).astype(np.float32).tolist() for j in range(active_collect_env_num) - ] + if self._cfg.type in ['sampled_efficientzero']: + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(self._cfg.model.num_of_sampled_actions) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + else: + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: # cpp mcts_tree - roots = self.MCTSCtree.roots(active_collect_env_num, legal_actions) + if self._cfg.type in ['sampled_efficientzero']: + roots = self.MCTSCtree.roots( + active_collect_env_num, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + else: + roots = self.MCTSCtree.roots(active_collect_env_num, legal_actions) else: # python mcts_tree - roots = self.MCTSPtree.roots(active_collect_env_num, legal_actions) - if self._cfg.type == 'efficientzero': + if self._cfg.type in ['sampled_efficientzero']: + roots = self.MCTSPtree.roots( + active_collect_env_num, legal_actions, self._cfg.model.action_space_size, + self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space + ) + else: + roots = self.MCTSPtree.roots(active_collect_env_num, legal_actions) + + if self._cfg.type in ['efficientzero', 'sampled_efficientzero']: roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play) self._mcts_collect.search( roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play @@ -157,6 +196,8 @@ def _forward_collect( roots_visit_count_distributions = roots.get_distributions() roots_values = roots.get_values() # shape: {list: batch_size} + if self._cfg.type in ['sampled_efficientzero']: + roots_sampled_actions = roots.get_sampled_actions() data_id = [i for i in range(active_collect_env_num)] output = 
{i: None for i in data_id} @@ -165,26 +206,48 @@ def _forward_collect( for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] + + if self._cfg.type in ['sampled_efficientzero']: + if self._cfg.mcts_ctree: + # In ctree, the method roots.get_sampled_actions() returns a list object. + root_sampled_actions = np.array([action for action in roots_sampled_actions[i]]) + else: + # In ptree, the same method roots.get_sampled_actions() returns an Action object. + root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]]) + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( distributions, temperature=self._collect_mcts_temperature, deterministic=False ) - # ****** sample a random action from the legal action set ******** - # all items except action are formally obtained from MCTS - random_action = int(np.random.choice(legal_actions[env_id], 1)) # **************************************************************** - - # NOTE: The action is randomly selected from the legal action set, the distribution is the real visit count distribution from the MCTS seraech. - output[env_id] = { - 'action': random_action, - 'distributions': distributions, - 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], - } + # NOTE: The action is randomly selected from the legal action set, + # the distribution is the real visit count distribution from the MCTS search. + if self._cfg.type in ['sampled_efficientzero']: + # ****** sample a random action from the legal action set ******** + random_action = self.action_space.sample() + output[env_id] = { + 'action': random_action, + 'visit_count_distributions': distributions, + 'root_sampled_actions': root_sampled_actions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + else: + # ****** sample a random action from the legal action set ******** + random_action = int(np.random.choice(legal_actions[env_id], 1)) + # all items except action are formally obtained from MCTS + output[env_id] = { + 'action': random_action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } return output @@ -206,7 +269,7 @@ def _init_learn(self) -> None: def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: pass - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,): pass def _monitor_vars_learn(self) -> List[str]: diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index fc52387e7..184ab8c42 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -170,6 +170,8 @@ class SampledEfficientZeroPolicy(MuZeroPolicy): # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. 
# The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, # ****** Priority ****** # (bool) Whether to use priority when sampling training data from the buffer. @@ -786,7 +788,7 @@ def _init_collect(self) -> None: def _forward_collect( self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, - epsilon: float = 0.25, ready_env_id=None + epsilon: float = 0.25, ready_env_id: np.array = None, ): """ Overview: @@ -876,22 +878,25 @@ def _forward_collect( for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] - try: - root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]]) - except Exception: - # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + if self._cfg.mcts_ctree: + # In ctree, the method roots.get_sampled_actions() returns a list object. root_sampled_actions = np.array([action for action in roots_sampled_actions[i]]) + else: + # In ptree, the same method roots.get_sampled_actions() returns an Action object. + root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]]) + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. action, visit_count_distribution_entropy = select_action( distributions, temperature=self._collect_mcts_temperature, deterministic=False ) - try: - action = roots_sampled_actions[i][action].value - # logging.warning('ptree_sampled_efficientzero roots.get_sampled_actions() return array') - except Exception: - # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + + if self._cfg.mcts_ctree: + # In ctree, the method roots.get_sampled_actions() returns a list object. action = np.array(roots_sampled_actions[i][action]) + else: + # In ptree, the same method roots.get_sampled_actions() returns an Action object. + action = roots_sampled_actions[i][action].value if not self._cfg.model.continuous_action_space: if len(action.shape) == 0: @@ -901,12 +906,12 @@ def _forward_collect( output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'root_sampled_actions': root_sampled_actions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output @@ -922,7 +927,7 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,): """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. 
@@ -1037,12 +1042,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'root_sampled_actions': root_sampled_actions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 78f66213f..96a9f7ff4 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -575,7 +575,7 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id=None + ready_env_id: np.array = None, ) -> Dict: """ Overview: @@ -652,11 +652,11 @@ def _forward_collect( action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output @@ -672,7 +672,7 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. 
\ @@ -742,11 +742,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 331c72d17..aca581c47 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -411,14 +411,14 @@ def collect(self, policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} - distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} + distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: root_sampled_actions_dict_no_env_id = { k: v['root_sampled_actions'] for k, v in policy_output.items() } - value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} - pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} + value_dict_no_env_id = {k: v['searched_value'] for k, v in policy_output.items()} + pred_value_dict_no_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} visit_entropy_dict_no_env_id = { k: v['visit_count_distribution_entropy'] for k, v in policy_output.items() diff --git a/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py index 74c02b970..5d1ede4b2 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py @@ -30,6 +30,7 @@ manager=dict(shared_memory=False, ), ), policy=dict( + mcts_ctree=True, model=dict( observation_shape=8, action_space_size=2, @@ -53,6 +54,7 @@ grad_clip_value=0.5, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, + random_collect_episode_num=0, # NOTE: for continuous gaussian policy, we use the policy_entropy_loss as in the original Sampled MuZero paper. policy_entropy_loss_weight=5e-3, n_episode=n_episode,
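
The behavioural core of this patch is the continuous-action branch in RandomPolicy._forward_collect: when cfg.model.continuous_action_space is true there is no meaningful action_mask, so the random action is drawn directly from the gym-style action_space that is now passed into RandomPolicy.__init__, while discrete environments keep the existing np.random.choice over the legal action set. The helper below is a standalone sketch of that branch, assuming a gym-like action space object with a .sample() method; the function name and signature are illustrative, not part of the patch.

from typing import List, Optional

import numpy as np


def random_action(continuous_action_space: bool,
                  action_mask: Optional[List[int]] = None,
                  action_space=None):
    # Continuous envs (e.g. the sampled_efficientzero pipeline): action_mask is None,
    # so sample directly from the environment's action space, as the patched
    # RandomPolicy does via the new `action_space` constructor argument.
    if continuous_action_space:
        return action_space.sample()
    # Discrete envs: keep the original behaviour and pick uniformly among legal actions.
    legal_actions = [i for i, x in enumerate(action_mask) if x == 1]
    return int(np.random.choice(legal_actions, 1))

# Example usage (gym assumed available):
#   import gym
#   box = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
#   random_action(True, None, box)          # e.g. array([ 0.13, -0.87], dtype=float32)
#   random_action(False, [1, 0, 1, 1])      # one of 0, 2, 3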
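
For Sampled EfficientZero the patch also replaces the old try/except around roots.get_sampled_actions() with an explicit branch on cfg.mcts_ctree: the C++ tree (ctree) returns plain action values, whereas the Python tree (ptree) returns Action objects that carry the value on their .value attribute. A minimal sketch of that extraction for a single environment, where sampled_actions_for_env stands in for roots_sampled_actions[i] and the helper name is illustrative:

import numpy as np


def extract_root_sampled_actions(sampled_actions_for_env, mcts_ctree: bool) -> np.ndarray:
    if mcts_ctree:
        # ctree: get_sampled_actions() already yields raw action values.
        return np.array([action for action in sampled_actions_for_env])
    # ptree: each entry is an Action-like object; unwrap it via `.value`.
    return np.array([action.value for action in sampled_actions_for_env])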
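
Across all policies the per-environment output dict returned by _forward_collect/_forward_eval is renamed to the more explicit keys that MuZeroCollector now reads ('visit_count_distributions', 'searched_value', 'predicted_value', 'predicted_policy_logits'). The mapping below summarises the rename; the upgrade_policy_output helper is only an illustrative sketch for downstream code that still stores outputs under the old keys, not part of the patch.

# Old key -> new key, as used by the policies and MuZeroCollector after this change.
OUTPUT_KEY_RENAMES = {
    'distributions': 'visit_count_distributions',
    'value': 'searched_value',
    'pred_value': 'predicted_value',
    'policy_logits': 'predicted_policy_logits',
}


def upgrade_policy_output(old_output: dict) -> dict:
    # Translate a pre-rename policy output dict to the new key names; unknown keys pass through.
    return {OUTPUT_KEY_RENAMES.get(k, k): v for k, v in old_output.items()}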