From 4969d878ac2e0e3b21f68a33e766570cd0704411 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Mon, 30 Oct 2023 12:48:21 +0800 Subject: [PATCH] polish(pu): polish some variables name and redundant code --- lzero/mcts/tests/cprofile_mcts_ptree.py | 4 +- lzero/mcts/tests/eval_tree_speed.py | 4 +- lzero/mcts/tests/test_mcts_ctree.py | 4 +- lzero/mcts/tests/test_mcts_ptree.py | 4 +- lzero/mcts/tests/test_mcts_sampled_ctree.py | 4 +- lzero/policy/efficientzero.py | 20 +++---- lzero/policy/gumbel_muzero.py | 20 +++---- lzero/policy/muzero.py | 22 +++---- lzero/policy/random_policy.py | 57 ++++++++++--------- lzero/policy/sampled_efficientzero.py | 42 +++++++------- lzero/policy/stochastic_muzero.py | 20 +++---- lzero/worker/muzero_collector.py | 6 +- ...ander_cont_sampled_efficientzero_config.py | 3 +- 13 files changed, 107 insertions(+), 103 deletions(-) diff --git a/lzero/mcts/tests/cprofile_mcts_ptree.py b/lzero/mcts/tests/cprofile_mcts_ptree.py index 4edf376ec..956ec39fa 100644 --- a/lzero/mcts/tests/cprofile_mcts_ptree.py +++ b/lzero/mcts/tests/cprofile_mcts_ptree.py @@ -27,7 +27,7 @@ def initial_inference(self, observation): reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16))) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, @@ -45,7 +45,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions): policy_logits = torch.zeros(size=(batch_size, self.action_num)) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, diff --git a/lzero/mcts/tests/eval_tree_speed.py b/lzero/mcts/tests/eval_tree_speed.py index b80957dbf..c7134f3b3 100644 --- a/lzero/mcts/tests/eval_tree_speed.py +++ b/lzero/mcts/tests/eval_tree_speed.py @@ -32,7 +32,7 @@ def initial_inference(self, observation): reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16))) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, @@ -50,7 +50,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions): policy_logits = torch.zeros(size=(batch_size, self.action_num)) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, diff --git a/lzero/mcts/tests/test_mcts_ctree.py b/lzero/mcts/tests/test_mcts_ctree.py index f15fe6780..0e569d329 100644 --- a/lzero/mcts/tests/test_mcts_ctree.py +++ b/lzero/mcts/tests/test_mcts_ctree.py @@ -37,7 +37,7 @@ def initial_inference(self, observation): reward_hidden_state_roots = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16))) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, @@ -60,7 +60,7 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None) policy_logits = torch.zeros(size=(batch_size, self.action_num)) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, diff --git a/lzero/mcts/tests/test_mcts_ptree.py b/lzero/mcts/tests/test_mcts_ptree.py index 613c0a205..75ec347ca 100644 --- a/lzero/mcts/tests/test_mcts_ptree.py +++ b/lzero/mcts/tests/test_mcts_ptree.py @@ -29,7 +29,7 @@ def initial_inference(self, observation): reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16))) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, @@ -47,7 +47,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions): policy_logits = torch.zeros(size=(batch_size, self.action_num)) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, diff --git a/lzero/mcts/tests/test_mcts_sampled_ctree.py b/lzero/mcts/tests/test_mcts_sampled_ctree.py index 8b0a74b21..2e4c15277 100644 --- a/lzero/mcts/tests/test_mcts_sampled_ctree.py +++ b/lzero/mcts/tests/test_mcts_sampled_ctree.py @@ -29,7 +29,7 @@ def initial_inference(self, observation): reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16))) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, @@ -48,7 +48,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions): # policy_logits = torch.zeros(size=(batch_size, self.action_num)) output = { - 'value': value, + 'searched_value': value, 'value_prefix': value_prefix, 'policy_logits': policy_logits, 'latent_state': latent_state, diff --git a/lzero/policy/efficientzero.py b/lzero/policy/efficientzero.py index cd69bba86..685e473d7 100644 --- a/lzero/policy/efficientzero.py +++ b/lzero/policy/efficientzero.py @@ -155,7 +155,7 @@ class EfficientZeroPolicy(MuZeroPolicy): # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. fixed_temperature_value=0.25, - # (bool) Whether to use the true chance in MCTS in 2048 env. + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. use_ture_chance_label_in_chance_encoder=False, # ****** Priority ****** @@ -626,11 +626,11 @@ def _forward_collect( action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output @@ -646,7 +646,7 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,): """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -718,11 +718,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output diff --git a/lzero/policy/gumbel_muzero.py b/lzero/policy/gumbel_muzero.py index b4f88a387..218ab99b6 100644 --- a/lzero/policy/gumbel_muzero.py +++ b/lzero/policy/gumbel_muzero.py @@ -486,7 +486,7 @@ def _forward_collect( action_mask: list = None, temperature: float = 1, to_play: List = [-1], - ready_env_id=None + ready_env_id: np.array = None, ) -> Dict: """ Overview: @@ -572,13 +572,13 @@ def _forward_collect( action = np.argmax([v for v in valid_value]) output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, + 'searched_value': value, 'roots_completed_value': roots_completed_value, 'improved_policy_probs': improved_policy_probs, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output @@ -594,7 +594,7 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -674,11 +674,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py index 136e7d250..80bbc53ba 100644 --- a/lzero/policy/muzero.py +++ b/lzero/policy/muzero.py @@ -152,7 +152,7 @@ class MuZeroPolicy(Policy): # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. fixed_temperature_value=0.25, - # (bool) Whether to use the true chance in MCTS in 2048 env. + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. use_ture_chance_label_in_chance_encoder=False, # ****** Priority ****** @@ -477,7 +477,7 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id=None + ready_env_id: np.array = None, ) -> Dict: """ Overview: @@ -562,11 +562,11 @@ def _forward_collect( action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output @@ -606,7 +606,7 @@ def _get_target_obs_index_in_step_k(self, step): end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num) return beg_index, end_index - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -676,11 +676,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output diff --git a/lzero/policy/random_policy.py b/lzero/policy/random_policy.py index 3abd568ce..735a4122d 100644 --- a/lzero/policy/random_policy.py +++ b/lzero/policy/random_policy.py @@ -20,7 +20,7 @@ def __init__( cfg: dict, model: Optional[Union[type, torch.nn.Module]] = None, enable_field: Optional[List[str]] = None, - action_space = None, + action_space: Any = None, ): if cfg.type == 'muzero': from lzero.mcts import MuZeroMCTSCtree as MCTSCtree @@ -65,15 +65,15 @@ def default_model(self) -> Tuple[str, List[str]]: elif self._cfg.type == 'muzero': return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp'] elif self._cfg.type == 'sampled_efficientzero': - return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_modelMLP'] + return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp'] else: raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type)) def _init_collect(self) -> None: """ - Overview: - Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. - """ + Overview: + Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils. + """ self._collect_model = self._model if self._cfg.mcts_ctree: self._mcts_collect = self.MCTSCtree(self._cfg) @@ -92,8 +92,8 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id=None, - ): + ready_env_id: np.array = None, + ) -> Dict: """ Overview: The forward function for collecting data in collect mode. Use model to execute MCTS search. @@ -141,7 +141,7 @@ def _forward_collect( ) policy_logits = policy_logits.detach().cpu().numpy().tolist() - if self._cfg.model.continuous_action_space is True: + if self._cfg.model.continuous_action_space: # when the action space of the environment is continuous, action_mask[:] is None. # NOTE: in continuous action space env: we set all legal_actions as -1 legal_actions = [ @@ -208,11 +208,12 @@ def _forward_collect( distributions, value = roots_visit_count_distributions[i], roots_values[i] if self._cfg.type in ['sampled_efficientzero']: - try: - root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]]) - except Exception: - # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + if self._cfg.mcts_ctree: + # In ctree, the method roots.get_sampled_actions() returns a list object. root_sampled_actions = np.array([action for action in roots_sampled_actions[i]]) + else: + # In ptree, the same method roots.get_sampled_actions() returns an Action object. + root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]]) # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. @@ -220,32 +221,32 @@ def _forward_collect( distributions, temperature=self._collect_mcts_temperature, deterministic=False ) - # ****** sample a random action from the legal action set ******** - if self._cfg.type in ['sampled_efficientzero']: - random_action = self.action_space.sample() - else: - # all items except action are formally obtained from MCTS - random_action = int(np.random.choice(legal_actions[env_id], 1)) # **************************************************************** - # NOTE: The action is randomly selected from the legal action set, the distribution is the real visit count distribution from the MCTS seraech. + # NOTE: The action is randomly selected from the legal action set, + # the distribution is the real visit count distribution from the MCTS search. if self._cfg.type in ['sampled_efficientzero']: + # ****** sample a random action from the legal action set ******** + random_action = self.action_space.sample() output[env_id] = { 'action': random_action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'root_sampled_actions': root_sampled_actions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } else: + # ****** sample a random action from the legal action set ******** + random_action = int(np.random.choice(legal_actions[env_id], 1)) + # all items except action are formally obtained from MCTS output[env_id] = { 'action': random_action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output @@ -268,7 +269,7 @@ def _init_learn(self) -> None: def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]: pass - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,): pass def _monitor_vars_learn(self) -> List[str]: diff --git a/lzero/policy/sampled_efficientzero.py b/lzero/policy/sampled_efficientzero.py index 5f6509d68..184ab8c42 100644 --- a/lzero/policy/sampled_efficientzero.py +++ b/lzero/policy/sampled_efficientzero.py @@ -170,7 +170,7 @@ class SampledEfficientZeroPolicy(MuZeroPolicy): # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. fixed_temperature_value=0.25, - # (bool) Whether to use the true chance in MCTS in 2048 env. + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. use_ture_chance_label_in_chance_encoder=False, # ****** Priority ****** @@ -788,7 +788,7 @@ def _init_collect(self) -> None: def _forward_collect( self, data: torch.Tensor, action_mask: list = None, temperature: np.ndarray = 1, to_play=-1, - epsilon: float = 0.25, ready_env_id=None + epsilon: float = 0.25, ready_env_id: np.array = None, ): """ Overview: @@ -878,23 +878,25 @@ def _forward_collect( for i, env_id in enumerate(ready_env_id): distributions, value = roots_visit_count_distributions[i], roots_values[i] - try: - root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]]) - except Exception: - # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + if self._cfg.mcts_ctree: + # In ctree, the method roots.get_sampled_actions() returns a list object. root_sampled_actions = np.array([action for action in roots_sampled_actions[i]]) + else: + # In ptree, the same method roots.get_sampled_actions() returns an Action object. + root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]]) # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents # the index within the legal action set, rather than the index in the entire action set. action, visit_count_distribution_entropy = select_action( distributions, temperature=self._collect_mcts_temperature, deterministic=False ) - try: - action = roots_sampled_actions[i][action].value - # logging.warning('ptree_sampled_efficientzero roots.get_sampled_actions() return array') - except Exception: - # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list') + + if self._cfg.mcts_ctree: + # In ctree, the method roots.get_sampled_actions() returns a list object. action = np.array(roots_sampled_actions[i][action]) + else: + # In ptree, the same method roots.get_sampled_actions() returns an Action object. + action = roots_sampled_actions[i][action].value if not self._cfg.model.continuous_action_space: if len(action.shape) == 0: @@ -904,12 +906,12 @@ def _forward_collect( output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'root_sampled_actions': root_sampled_actions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output @@ -925,7 +927,7 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None): + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,): """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -1040,12 +1042,12 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'root_sampled_actions': root_sampled_actions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output diff --git a/lzero/policy/stochastic_muzero.py b/lzero/policy/stochastic_muzero.py index 78f66213f..96a9f7ff4 100644 --- a/lzero/policy/stochastic_muzero.py +++ b/lzero/policy/stochastic_muzero.py @@ -575,7 +575,7 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id=None + ready_env_id: np.array = None, ) -> Dict: """ Overview: @@ -652,11 +652,11 @@ def _forward_collect( action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output @@ -672,7 +672,7 @@ def _init_eval(self) -> None: else: self._mcts_eval = MCTSPtree(self._cfg) - def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict: + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. \ @@ -742,11 +742,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 output[env_id] = { 'action': action, - 'distributions': distributions, + 'visit_count_distributions': distributions, 'visit_count_distribution_entropy': visit_count_distribution_entropy, - 'value': value, - 'pred_value': pred_values[i], - 'policy_logits': policy_logits[i], + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], } return output diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 331c72d17..aca581c47 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -411,14 +411,14 @@ def collect(self, policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon) actions_no_env_id = {k: v['action'] for k, v in policy_output.items()} - distributions_dict_no_env_id = {k: v['distributions'] for k, v in policy_output.items()} + distributions_dict_no_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} if self.policy_config.sampled_algo: root_sampled_actions_dict_no_env_id = { k: v['root_sampled_actions'] for k, v in policy_output.items() } - value_dict_no_env_id = {k: v['value'] for k, v in policy_output.items()} - pred_value_dict_no_env_id = {k: v['pred_value'] for k, v in policy_output.items()} + value_dict_no_env_id = {k: v['searched_value'] for k, v in policy_output.items()} + pred_value_dict_no_env_id = {k: v['predicted_value'] for k, v in policy_output.items()} visit_entropy_dict_no_env_id = { k: v['visit_count_distribution_entropy'] for k, v in policy_output.items() diff --git a/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py index 4de24a8bf..5d1ede4b2 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py @@ -30,6 +30,7 @@ manager=dict(shared_memory=False, ), ), policy=dict( + mcts_ctree=True, model=dict( observation_shape=8, action_space_size=2, @@ -53,7 +54,7 @@ grad_clip_value=0.5, num_simulations=num_simulations, reanalyze_ratio=reanalyze_ratio, - random_collect_episode_num=8, + random_collect_episode_num=0, # NOTE: for continuous gaussian policy, we use the policy_entropy_loss as in the original Sampled MuZero paper. policy_entropy_loss_weight=5e-3, n_episode=n_episode,