feature(pu): add random_policy support for continuous env (#118)
* polish(pu): add random_policy support for continuous env

* polish(pu): polish some variables name and redundant code

* fix(pu): fix mcts_ptree/mcts_ctree unittest

* fix(pu): fix test_mcts_sampled_ctree
puyuan1996 authored Oct 30, 2023
1 parent 4de2a9e commit d845648
Showing 12 changed files with 196 additions and 118 deletions.
2 changes: 1 addition & 1 deletion lzero/entry/utils.py
@@ -17,7 +17,7 @@ def random_collect(
) -> None: # noqa
assert policy_cfg.random_collect_episode_num > 0

random_policy = RandomPolicy(cfg=policy_cfg)
random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space)
# set the policy to random policy
collector.reset_policy(random_policy.collect_mode)

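Note on the change above: passing the collector env's action_space into RandomPolicy is what lets random collection cover continuous as well as discrete environments. A minimal illustrative sketch of the idea (this is not the repository's RandomPolicy implementation; the class, method names, and usage below are assumptions for illustration only):

import gym
import numpy as np


class MinimalRandomPolicy:
    # Hypothetical stand-in for illustration: sampling directly from the env's
    # action_space means the same code path handles Discrete and Box (continuous) spaces.
    def __init__(self, action_space: gym.Space) -> None:
        self.action_space = action_space

    def collect(self, ready_env_ids) -> dict:
        # One independent random action per ready environment.
        return {env_id: {'action': self.action_space.sample()} for env_id in ready_env_ids}


# Example with a continuous (Box) action space:
continuous_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(2,), dtype=np.float32)
policy_sketch = MinimalRandomPolicy(continuous_space)
print(policy_sketch.collect([0, 1]))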
4 changes: 2 additions & 2 deletions lzero/mcts/tests/cprofile_mcts_ptree.py
@@ -27,7 +27,7 @@ def initial_inference(self, observation):
reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))

output = {
'value': value,
'searched_value': value,
'value_prefix': value_prefix,
'policy_logits': policy_logits,
'latent_state': latent_state,
@@ -45,7 +45,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
policy_logits = torch.zeros(size=(batch_size, self.action_num))

output = {
'value': value,
'searched_value': value,
'value_prefix': value_prefix,
'policy_logits': policy_logits,
'latent_state': latent_state,
4 changes: 2 additions & 2 deletions lzero/mcts/tests/eval_tree_speed.py
@@ -32,7 +32,7 @@ def initial_inference(self, observation):
reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))

output = {
'value': value,
'searched_value': value,
'value_prefix': value_prefix,
'policy_logits': policy_logits,
'latent_state': latent_state,
@@ -50,7 +50,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
policy_logits = torch.zeros(size=(batch_size, self.action_num))

output = {
'value': value,
'searched_value': value,
'value_prefix': value_prefix,
'policy_logits': policy_logits,
'latent_state': latent_state,
48 changes: 27 additions & 21 deletions lzero/mcts/tests/test_mcts_ctree.py
@@ -5,7 +5,6 @@

from lzero.policy import inverse_scalar_transform, select_action


policy = 'GumbelMuZero'

if policy == 'EfficientZero':
@@ -14,6 +13,8 @@
from lzero.mcts.tree_search.mcts_ctree import GumbelMuZeroMCTSCtree as MCTSCtree
else:
raise KeyError('Only support test for EfficientZero and GumbelMuZero.')


class MuZeroModelFake(torch.nn.Module):
"""
Overview:
@@ -50,7 +51,7 @@ def initial_inference(self, observation):

def recurrent_inference(self, latent_states, reward_hidden_states, actions=None):
if policy == 'GumbelMuZero':
assert actions==None
assert actions == None
actions = reward_hidden_states
batch_size = latent_states.shape[0]
latent_state = torch.zeros(size=(batch_size, 12, 3, 3))
@@ -78,7 +79,7 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None)
batch_size=16,
pb_c_base=1,
pb_c_init=1,
max_num_considered_actions = 6,
max_num_considered_actions=6,
discount_factor=0.9,
root_dirichlet_alpha=0.3,
root_noise_weight=0.2,
@@ -163,17 +164,18 @@ def test_mcts_vs_bot_to_play():
policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
)
MCTSCtree(policy_config
).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
elif policy == 'GumbelMuZero':
roots.prepare(
policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
[0 for _ in range(env_nums)]
)
MCTSCtree(policy_config
).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
roots_distributions = roots.get_distributions()
roots_values = roots.get_values()
assert np.array(roots_distributions).shape == (batch_size, action_space_size)
assert np.array(roots_values).shape == (batch_size, )
assert np.array(roots_values).shape == (batch_size,)


@pytest.mark.unittest
@@ -220,17 +222,18 @@ def test_mcts_vs_bot_to_play_large():
policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
)
MCTSCtree(policy_config
).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
elif policy == 'GumbelMuZero':
roots.prepare(
policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
[0 for _ in range(env_nums)]
)
MCTSCtree(policy_config
).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
roots_distributions = roots.get_distributions()
roots_values = roots.get_values()
assert np.array(roots_distributions).shape == (policy_config.batch_size, policy_config.model.action_space_size)
assert np.array(roots_values).shape == (policy_config.batch_size, )
assert np.array(roots_values).shape == (policy_config.batch_size,)


@pytest.mark.unittest
@@ -250,13 +253,14 @@ def test_mcts_vs_bot_to_play_legal_action():
policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
)
MCTSCtree(policy_config
).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
elif policy == 'GumbelMuZero':
roots.prepare(
policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
[0 for _ in range(env_nums)]
)
MCTSCtree(policy_config
).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
roots_distributions = roots.get_distributions()
roots_values = roots.get_values()
assert len(roots_values) == env_nums
@@ -290,17 +294,18 @@ def test_mcts_self_play():
policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
)
MCTSCtree(policy_config
).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
elif policy == 'GumbelMuZero':
roots.prepare(
policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
[0 for _ in range(env_nums)]
)
MCTSCtree(policy_config
).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
roots_distributions = roots.get_distributions()
roots_values = roots.get_values()
assert np.array(roots_distributions).shape == (batch_size, action_space_size)
assert np.array(roots_values).shape == (batch_size, )
assert np.array(roots_values).shape == (batch_size,)


@pytest.mark.unittest
@@ -319,13 +324,14 @@ def test_mcts_self_play_legal_action():
policy_config.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, [0 for _ in range(env_nums)]
)
MCTSCtree(policy_config
).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
).search(roots, model, latent_state_roots, reward_hidden_state_roots, [0 for _ in range(env_nums)])
elif policy == 'GumbelMuZero':
roots.prepare(
policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool, [0 for _ in range(env_nums)]
policy_config.root_noise_weight, noises, value_prefix_pool, list(pred_values_pool), policy_logits_pool,
[0 for _ in range(env_nums)]
)
MCTSCtree(policy_config
).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
).search(roots, model, latent_state_roots, [0 for _ in range(env_nums)])
roots_distributions = roots.get_distributions()
roots_values = roots.get_values()
assert len(roots_values) == env_nums
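A note on the test above: the two supported policies call `search` with different arguments, which is why the test branches on `policy`. The EfficientZero ctree search also receives the LSTM reward hidden states, while the GumbelMuZero ctree search does not. A condensed sketch of that dispatch (the helper function is illustrative, not part of the test file):

def run_ctree_search(policy, mcts, roots, model, latent_state_roots,
                     reward_hidden_state_roots, to_play):
    # Mirrors the branching used in test_mcts_ctree.py: only EfficientZero's
    # search consumes the reward hidden states.
    if policy == 'EfficientZero':
        mcts.search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
    elif policy == 'GumbelMuZero':
        mcts.search(roots, model, latent_state_roots, to_play)
    else:
        raise KeyError('Only support test for EfficientZero and GumbelMuZero.')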
20 changes: 11 additions & 9 deletions lzero/policy/efficientzero.py
@@ -155,6 +155,8 @@ class EfficientZeroPolicy(MuZeroPolicy):
# (float) The fixed temperature value for MCTS action selection, which is used to control the exploration.
# The larger the value, the more exploration. This value is only used when manual_temperature_decay=False.
fixed_temperature_value=0.25,
# (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048.
use_ture_chance_label_in_chance_encoder=False,

# ****** Priority ******
# (bool) Whether to use priority when sampling training data from the buffer.
@@ -624,11 +626,11 @@ def _forward_collect(
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
output[env_id] = {
'action': action,
'distributions': distributions,
'visit_count_distributions': distributions,
'visit_count_distribution_entropy': visit_count_distribution_entropy,
'value': value,
'pred_value': pred_values[i],
'policy_logits': policy_logits[i],
'searched_value': value,
'predicted_value': pred_values[i],
'predicted_policy_logits': policy_logits[i],
}

return output
@@ -644,7 +646,7 @@ def _init_eval(self) -> None:
else:
self._mcts_eval = MCTSPtree(self._cfg)

def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None):
def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,):
"""
Overview:
The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
@@ -716,11 +718,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
output[env_id] = {
'action': action,
'distributions': distributions,
'visit_count_distributions': distributions,
'visit_count_distribution_entropy': visit_count_distribution_entropy,
'value': value,
'pred_value': pred_values[i],
'policy_logits': policy_logits[i],
'searched_value': value,
'predicted_value': pred_values[i],
'predicted_policy_logits': policy_logits[i],
}

return output
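The signature change above (`ready_env_id: np.array = None`) appears only in the hunk header and defaults; the function body is not part of this excerpt. A common way such a default is resolved, stated here purely as an assumption rather than the repository's actual code:

import numpy as np


def resolve_ready_env_id(ready_env_id, active_env_num: int) -> np.ndarray:
    # Assumed behaviour: when no subset of environments is specified,
    # fall back to every active environment index.
    if ready_env_id is None:
        return np.arange(active_env_num)
    return np.asarray(ready_env_id)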
20 changes: 10 additions & 10 deletions lzero/policy/gumbel_muzero.py
@@ -486,7 +486,7 @@ def _forward_collect(
action_mask: list = None,
temperature: float = 1,
to_play: List = [-1],
ready_env_id=None
ready_env_id: np.array = None,
) -> Dict:
"""
Overview:
@@ -572,13 +572,13 @@ def _forward_collect(
action = np.argmax([v for v in valid_value])
output[env_id] = {
'action': action,
'distributions': distributions,
'visit_count_distributions': distributions,
'visit_count_distribution_entropy': visit_count_distribution_entropy,
'value': value,
'searched_value': value,
'roots_completed_value': roots_completed_value,
'improved_policy_probs': improved_policy_probs,
'pred_value': pred_values[i],
'policy_logits': policy_logits[i],
'predicted_value': pred_values[i],
'predicted_policy_logits': policy_logits[i],
}

return output
@@ -594,7 +594,7 @@ def _init_eval(self) -> None:
else:
self._mcts_eval = MCTSPtree(self._cfg)

def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict:
def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict:
"""
Overview:
The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
@@ -674,11 +674,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1

output[env_id] = {
'action': action,
'distributions': distributions,
'visit_count_distributions': distributions,
'visit_count_distribution_entropy': visit_count_distribution_entropy,
'value': value,
'pred_value': pred_values[i],
'policy_logits': policy_logits[i],
'searched_value': value,
'predicted_value': pred_values[i],
'predicted_policy_logits': policy_logits[i],
}

return output
22 changes: 11 additions & 11 deletions lzero/policy/muzero.py
@@ -159,7 +159,7 @@ class MuZeroPolicy(Policy):
# (float) The fixed temperature value for MCTS action selection, which is used to control the exploration.
# The larger the value, the more exploration. This value is only used when manual_temperature_decay=False.
fixed_temperature_value=0.25,
# (bool) Whether to use the true chance in MCTS.
# (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048.
use_ture_chance_label_in_chance_encoder=False,

# ****** Priority ******
@@ -515,7 +515,7 @@ def _forward_collect(
temperature: float = 1,
to_play: List = [-1],
epsilon: float = 0.25,
ready_env_id: List = None,
ready_env_id: np.array = None,
) -> Dict:
"""
Overview:
Expand Down Expand Up @@ -600,11 +600,11 @@ def _forward_collect(
action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
output[env_id] = {
'action': action,
'distributions': distributions,
'visit_count_distributions': distributions,
'visit_count_distribution_entropy': visit_count_distribution_entropy,
'value': value,
'pred_value': pred_values[i],
'policy_logits': policy_logits[i],
'searched_value': value,
'predicted_value': pred_values[i],
'predicted_policy_logits': policy_logits[i],
}

return output
@@ -644,7 +644,7 @@ def _get_target_obs_index_in_step_k(self, step):
end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num)
return beg_index, end_index

def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict:
def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict:
"""
Overview:
The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
Expand Down Expand Up @@ -714,11 +714,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1

output[env_id] = {
'action': action,
'distributions': distributions,
'visit_count_distributions': distributions,
'visit_count_distribution_entropy': visit_count_distribution_entropy,
'value': value,
'pred_value': pred_values[i],
'policy_logits': policy_logits[i],
'searched_value': value,
'predicted_value': pred_values[i],
'predicted_policy_logits': policy_logits[i],
}

return output
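Across efficientzero.py, gumbel_muzero.py, and muzero.py this commit renames the per-environment output keys of `_forward_collect` and `_forward_eval`: 'distributions' becomes 'visit_count_distributions', 'value' becomes 'searched_value', 'pred_value' becomes 'predicted_value', and 'policy_logits' becomes 'predicted_policy_logits'. Downstream readers of these dicts have to follow the rename; a small hedged sketch of such a consumer (the helper is hypothetical, not code from the repository):

def summarize_policy_output(step_output: dict) -> dict:
    # Reads the renamed keys produced by _forward_collect / _forward_eval
    # after this commit; the old keys ('distributions', 'value', 'pred_value',
    # 'policy_logits') no longer exist in the returned dict.
    return {
        'action': step_output['action'],
        'visit_counts': step_output['visit_count_distributions'],
        'searched_value': step_output['searched_value'],
        'predicted_value': step_output['predicted_value'],
        'predicted_policy_logits': step_output['predicted_policy_logits'],
    }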