fix(pu): fix mcts_ptree/mcts_ctree unittest
puyuan1996 committed Oct 30, 2023
1 parent 4bbb309 commit 4ac9c8f
Showing 2 changed files with 31 additions and 25 deletions.
52 changes: 29 additions & 23 deletions lzero/mcts/tests/test_mcts_ctree.py
@@ -5,7 +5,6 @@
 
 from lzero.policy import inverse_scalar_transform, select_action
 
-
 policy = 'GumbelMuZero'
 
 if policy == 'EfficientZero':
@@ -14,6 +13,8 @@
     from lzero.mcts.tree_search.mcts_ctree import GumbelMuZeroMCTSCtree as MCTSCtree
 else:
     raise KeyError('Only support test for EfficientZero and GumbelMuZero.')
+
+
 class MuZeroModelFake(torch.nn.Module):
     """
     Overview:
@@ -37,7 +38,7 @@ def initial_inference(self, observation):
         reward_hidden_state_roots = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
 
         output = {
-            'searched_value': value,
+            'value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
@@ -50,7 +51,7 @@ def initial_inference(self, observation):
 
     def recurrent_inference(self, latent_states, reward_hidden_states, actions=None):
         if policy == 'GumbelMuZero':
-            assert actions==None
+            assert actions == None
             actions = reward_hidden_states
         batch_size = latent_states.shape[0]
         latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
@@ -60,7 +61,7 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None):
         policy_logits = torch.zeros(size=(batch_size, self.action_num))
 
         output = {
-            'searched_value': value,
+            'value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
@@ -78,7 +79,7 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None):
    batch_size=16,
    pb_c_base=1,
    pb_c_init=1,
-    max_num_considered_actions = 6,
+    max_num_considered_actions=6,
    discount_factor=0.9,
    root_dirichlet_alpha=0.3,
    root_noise_weight=0.2,
@@ -163,17 +164,18 @@ def test_mcts_vs_bot_to_play():
             policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
+        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
     elif policy == 'GumbelMuZero':
         roots.prepare(
-            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
+            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
+            [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
+        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
     roots_distributions = roots.get_distributions()
     roots_values = roots.get_values()
     assert np.array(roots_distributions).shape == (batch_size, action_space_size)
-    assert np.array(roots_values).shape == (batch_size, )
+    assert np.array(roots_values).shape == (batch_size,)
 
 
 @pytest.mark.unittest
@@ -220,17 +222,18 @@ def test_mcts_vs_bot_to_play_large():
             policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
+        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
     elif policy == 'GumbelMuZero':
         roots.prepare(
-            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
+            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
+            [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
+        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
     roots_distributions = roots.get_distributions()
     roots_values = roots.get_values()
     assert np.array(roots_distributions).shape == (policy_config.batch_size, policy_config.model.action_space_size)
-    assert np.array(roots_values).shape == (policy_config.batch_size, )
+    assert np.array(roots_values).shape == (policy_config.batch_size,)
 
 
 @pytest.mark.unittest
@@ -250,13 +253,14 @@ def test_mcts_vs_bot_to_play_legal_action():
             policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
+        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
     elif policy == 'GumbelMuZero':
         roots.prepare(
-            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
+            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
+            [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
+        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
     roots_distributions = roots.get_distributions()
     roots_values = roots.get_values()
     assert len(roots_values) == env_nums
@@ -290,17 +294,18 @@ def test_mcts_self_play():
             policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
+        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
     elif policy == 'GumbelMuZero':
         roots.prepare(
-            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
+            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
+            [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
+        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
     roots_distributions = roots.get_distributions()
     roots_values = roots.get_values()
     assert np.array(roots_distributions).shape == (batch_size, action_space_size)
-    assert np.array(roots_values).shape == (batch_size, )
+    assert np.array(roots_values).shape == (batch_size,)
 
 
 @pytest.mark.unittest
@@ -319,13 +324,14 @@ def test_mcts_self_play_legal_action():
             policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
        )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
+        ).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
     elif policy == 'GumbelMuZero':
         roots.prepare(
-            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
+            policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
+            [0 for _ in range(env_nums)]
         )
         MCTSCtree(policy_config
-        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
+        ).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
     roots_distributions = roots.get_distributions()
     roots_values = roots.get_values()
     assert len(roots_values) == env_nums
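For orientation, below is a condensed sketch of the fake-model interface these ctree tests rely on after the change. Two points follow from the diff: the predicted value is now exposed under the key 'value' rather than 'searched_value', and in the GumbelMuZero branch recurrent_inference receives the action batch through its second positional argument, so the fake model asserts that the actions keyword is unset and reuses that argument. The class name, the default action count, and the shapes used for value and value_prefix below are illustrative placeholders, not taken from the repository.

import torch

policy = 'GumbelMuZero'  # module-level switch, as in the test file


class MuZeroModelFakeSketch(torch.nn.Module):
    # Condensed stand-in for MuZeroModelFake; only keys visible in the diff are included.

    def __init__(self, action_num: int = 9):
        super().__init__()
        self.action_num = action_num

    def recurrent_inference(self, latent_states, reward_hidden_states, actions=None):
        if policy == 'GumbelMuZero':
            # In the Gumbel branch the search passes (latent_state, action) positionally,
            # so the second argument actually carries the actions here.
            assert actions is None  # the test file itself writes 'actions == None'
            actions = reward_hidden_states
        batch_size = latent_states.shape[0]
        return {
            'value': torch.zeros(batch_size, 1),  # key renamed from 'searched_value'
            'value_prefix': torch.zeros(batch_size, 1),  # placeholder shape
            'policy_logits': torch.zeros(batch_size, self.action_num),
            'latent_state': torch.zeros(batch_size, 12, 3, 3),
            # ...remaining keys omitted; see the full test file above.
        }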
4 changes: 2 additions & 2 deletions lzero/mcts/tests/test_mcts_ptree.py
@@ -29,7 +29,7 @@ def initial_inference(self, observation):
         reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
 
         output = {
-            'searched_value': value,
+            'value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
@@ -47,7 +47,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
         policy_logits = torch.zeros(size=(batch_size, self.action_num))
 
         output = {
-            'searched_value': value,
+            'value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
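Both updated test modules can be run directly with pytest. A suggested invocation from Python is sketched below; the file paths come from the diff, and the -m unittest filter matches the @pytest.mark.unittest marker visible in the ctree tests (drop the filter to run everything in the two files):

import pytest

# Run only the MCTS tree-search unit tests touched by this commit.
pytest.main([
    "-m", "unittest",
    "lzero/mcts/tests/test_mcts_ctree.py",
    "lzero/mcts/tests/test_mcts_ptree.py",
])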
