diff --git a/lzero/mcts/tests/test_mcts_ctree.py b/lzero/mcts/tests/test_mcts_ctree.py
index 0e569d329..e9a6424f0 100644
--- a/lzero/mcts/tests/test_mcts_ctree.py
+++ b/lzero/mcts/tests/test_mcts_ctree.py
@@ -5,7 +5,6 @@
 
 from lzero.policy import inverse_scalar_transform, select_action
 
-
 policy = 'GumbelMuZero'
 
 if policy == 'EfficientZero':
@@ -14,6 +13,8 @@
     from lzero.mcts.tree_search.mcts_ctree import GumbelMuZeroMCTSCtree as MCTSCtree
 else:
     raise KeyError('Only support test for EfficientZero and GumbelMuZero.')
+
+
 class MuZeroModelFake(torch.nn.Module):
     """
     Overview:
@@ -37,7 +38,7 @@ def initial_inference(self, observation):
         reward_hidden_state_roots = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
 
         output = {
-            'searched_value': value,
+            'value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
@@ -50,7 +51,7 @@ def initial_inference(self, observation):
 
     def recurrent_inference(self, latent_states, reward_hidden_states, actions=None):
         if policy == 'GumbelMuZero':
-            assert actions==None
+            assert actions == None
             actions = reward_hidden_states
         batch_size = latent_states.shape[0]
         latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
@@ -60,7 +61,7 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None)
         policy_logits = torch.zeros(size=(batch_size, self.action_num))
 
         output = {
-            'searched_value': value,
+            'value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
@@ -78,7 +79,7 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None)
         batch_size=16,
         pb_c_base=1,
         pb_c_init=1,
-        max_num_considered_actions = 6,
+        max_num_considered_actions=6,
         discount_factor=0.9,
         root_dirichlet_alpha=0.3,
         root_noise_weight=0.2,
@@ -163,17 +164,18 @@ def test_mcts_vs_bot_to_play():
             policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
+                  ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
     elif policy == 'GumbelMuZero':
         roots.prepare(
-            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
+            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
+            [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
+                  ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
     roots_distributions = roots.get_distributions()
     roots_values = roots.get_values()
     assert np.array(roots_distributions).shape == (batch_size, action_space_size)
-    assert np.array(roots_values).shape == (batch_size, )
+    assert np.array(roots_values).shape == (batch_size,)
 
 
 @pytest.mark.unittest
@@ -220,17 +222,18 @@ def test_mcts_vs_bot_to_play_large():
             policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
+                  ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
     elif policy == 'GumbelMuZero':
         roots.prepare(
-            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
+            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
+            [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
+                  ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
     roots_distributions = roots.get_distributions()
     roots_values = roots.get_values()
     assert np.array(roots_distributions).shape == (policy_config.batch_size, policy_config.model.action_space_size)
-    assert np.array(roots_values).shape == (policy_config.batch_size, )
+    assert np.array(roots_values).shape == (policy_config.batch_size,)
 
 
 @pytest.mark.unittest
@@ -250,13 +253,14 @@ def test_mcts_vs_bot_to_play_legal_action():
             policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
+                  ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
     elif policy == 'GumbelMuZero':
         roots.prepare(
-            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
+            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
+            [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
+                  ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
     roots_distributions = roots.get_distributions()
     roots_values = roots.get_values()
     assert len(roots_values) == env_nums
@@ -290,17 +294,18 @@ def test_mcts_self_play():
             policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
+                  ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
     elif policy == 'GumbelMuZero':
         roots.prepare(
-            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
+            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
+            [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
+                  ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
     roots_distributions = roots.get_distributions()
     roots_values = roots.get_values()
     assert np.array(roots_distributions).shape == (batch_size, action_space_size)
-    assert np.array(roots_values).shape == (batch_size, )
+    assert np.array(roots_values).shape == (batch_size,)
 
 
 @pytest.mark.unittest
@@ -319,13 +324,14 @@ def test_mcts_self_play_legal_action():
             policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
+                  ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
     elif policy == 'GumbelMuZero':
         roots.prepare(
-            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
+            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
+            [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
+                  ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
     roots_distributions = roots.get_distributions()
     roots_values = roots.get_values()
     assert len(roots_values) == env_nums
diff --git a/lzero/mcts/tests/test_mcts_ptree.py b/lzero/mcts/tests/test_mcts_ptree.py
index 75ec347ca..613c0a205 100644
--- a/lzero/mcts/tests/test_mcts_ptree.py
+++ b/lzero/mcts/tests/test_mcts_ptree.py
@@ -29,7 +29,7 @@ def initial_inference(self, observation):
         reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
 
         output = {
-            'searched_value': value,
+            'value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
@@ -47,7 +47,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
         policy_logits = torch.zeros(size=(batch_size, self.action_num))
 
         output = {
-            'searched_value': value,
+            'value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,