polish(pu): add random_policy support for continuous env
puyuan1996 committed Oct 22, 2023
1 parent 030173a commit c093d77
Showing 6 changed files with 93 additions and 25 deletions.
2 changes: 1 addition & 1 deletion lzero/entry/utils.py
@@ -17,7 +17,7 @@ def random_collect(
) -> None: # noqa
assert policy_cfg.random_collect_episode_num > 0

random_policy = RandomPolicy(cfg=policy_cfg)
random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space)
# set the policy to random policy
collector.reset_policy(random_policy.collect_mode)

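The hunk above is the entry-level half of this change: random_collect now hands the collector environment's action space to RandomPolicy, because a continuous-control env provides no discrete legal-action mask to sample from. A minimal standalone sketch of the two sampling regimes (the gym spaces and sizes are illustrative assumptions, not taken from this commit):

import numpy as np
import gym

# Discrete env: an action mask exists, so a random legal index is enough.
action_mask = np.array([1, 0, 1, 1, 0, 1])
legal = np.flatnonzero(action_mask)
random_discrete_action = int(np.random.choice(legal))

# Continuous env: action_mask is None, so the policy has to sample directly
# from the env's Box space; that is why the env's action_space is now threaded
# through random_collect into RandomPolicy.
continuous_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float32)
random_continuous_action = continuous_space.sample()  # np.ndarray of shape (1,)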
2 changes: 2 additions & 0 deletions lzero/policy/efficientzero.py
@@ -155,6 +155,8 @@ class EfficientZeroPolicy(MuZeroPolicy):
# (float) The fixed temperature value for MCTS action selection, which is used to control the exploration.
# The larger the value, the more exploration. This value is only used when manual_temperature_decay=False.
fixed_temperature_value=0.25,
# (bool) Whether to use the true chance in MCTS in 2048 env.
use_ture_chance_label_in_chance_encoder=False,

# ****** Priority ******
# (bool) Whether to use priority when sampling training data from the buffer.
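The new EfficientZero default exposes the same flag that muzero.py already documents and that sampled_efficientzero.py picks up later in this commit: whether the true chance is used in MCTS, which per the inline comment is relevant to the 2048 env. A hedged config fragment showing how a user config might flip it on (the EasyDict usage and surrounding fields are assumptions, not part of this diff):

from easydict import EasyDict

# Illustrative override only; the identifier is spelled exactly as in the policy defaults.
policy_cfg = EasyDict(dict(
    type='efficientzero',
    use_ture_chance_label_in_chance_encoder=True,  # per the inline comment above, intended for the 2048 env
))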
2 changes: 1 addition & 1 deletion lzero/policy/muzero.py
@@ -152,7 +152,7 @@ class MuZeroPolicy(Policy):
# (float) The fixed temperature value for MCTS action selection, which is used to control the exploration.
# The larger the value, the more exploration. This value is only used when manual_temperature_decay=False.
fixed_temperature_value=0.25,
# (bool) Whether to use the true chance in MCTS.
# (bool) Whether to use the true chance in MCTS in 2048 env.
use_ture_chance_label_in_chance_encoder=False,

# ****** Priority ******
108 changes: 85 additions & 23 deletions lzero/policy/random_policy.py
@@ -19,18 +19,23 @@ def __init__(
self,
cfg: dict,
model: Optional[Union[type, torch.nn.Module]] = None,
enable_field: Optional[List[str]] = None
enable_field: Optional[List[str]] = None,
action_space = None,
):
if cfg.type == 'muzero':
from lzero.mcts import MuZeroMCTSCtree as MCTSCtree
from lzero.mcts import MuZeroMCTSPtree as MCTSPtree
elif cfg.type == 'efficientzero':
from lzero.mcts import EfficientZeroMCTSCtree as MCTSCtree
from lzero.mcts import EfficientZeroMCTSPtree as MCTSPtree
elif cfg.type == 'sampled_efficientzero':
from lzero.mcts import SampledEfficientZeroMCTSCtree as MCTSCtree
from lzero.mcts import SampledEfficientZeroMCTSPtree as MCTSPtree
else:
raise NotImplementedError("need to implement pipeline: {}".format(cfg.type))
self.MCTSCtree = MCTSCtree
self.MCTSPtree = MCTSPtree
self.action_space = action_space
super().__init__(cfg, model, enable_field)
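The constructor hunk above does two things: it adds a 'sampled_efficientzero' branch to the cfg.type dispatch that selects the MCTS tree implementation, and it stores the optional action_space handle for later use at collect time. A standalone sketch of the same dispatch pattern (class names as in the diff; the helper function itself is illustrative, not code from this commit):

MCTS_BACKENDS = {
    'muzero': ('MuZeroMCTSCtree', 'MuZeroMCTSPtree'),
    'efficientzero': ('EfficientZeroMCTSCtree', 'EfficientZeroMCTSPtree'),
    'sampled_efficientzero': ('SampledEfficientZeroMCTSCtree', 'SampledEfficientZeroMCTSPtree'),
}

def resolve_mcts(cfg_type: str):
    # Mirrors the if/elif chain in __init__: unknown pipeline types fail loudly.
    try:
        return MCTS_BACKENDS[cfg_type]
    except KeyError:
        raise NotImplementedError("need to implement pipeline: {}".format(cfg_type))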

def default_model(self) -> Tuple[str, List[str]]:
Expand All @@ -50,13 +55,17 @@ def default_model(self) -> Tuple[str, List[str]]:
return 'EfficientZeroModel', ['lzero.model.efficientzero_model']
elif self._cfg.type == 'muzero':
return 'MuZeroModel', ['lzero.model.muzero_model']
elif self._cfg.type == 'sampled_efficientzero':
return 'SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model']
else:
raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type))
elif self._cfg.model.model_type == "mlp":
if self._cfg.type == 'efficientzero':
return 'EfficientZeroModelMLP', ['lzero.model.efficientzero_model_mlp']
elif self._cfg.type == 'muzero':
return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp']
elif self._cfg.type == 'sampled_efficientzero':
return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_modelMLP']
else:
raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type))
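default_model gains matching 'sampled_efficientzero' branches; the returned pair is a model class name plus candidate import paths. A sketch of how such a pair can be resolved into a class (the importlib helper is an assumption about how the surrounding framework consumes this tuple, not code from this diff):

import importlib

def resolve_default_model(class_name, import_names):
    # Try each candidate module path and return the first one that exports the class.
    for path in import_names:
        module = importlib.import_module(path)
        if hasattr(module, class_name):
            return getattr(module, class_name)
    raise ImportError("{} not found in {}".format(class_name, import_names))

# e.g. resolve_default_model('SampledEfficientZeroModel', ['lzero.model.sampled_efficientzero_model'])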

@@ -83,7 +92,7 @@ def _forward_collect(
temperature: float = 1,
to_play: List = [-1],
epsilon: float = 0.25,
ready_env_id = None
ready_env_id=None,
):
"""
Overview:
@@ -114,7 +123,7 @@
with torch.no_grad():
# data shape [B, S x C, W, H], e.g. {Tensor:(B, 12, 96, 96)}
network_output = self._collect_model.initial_inference(data)
if self._cfg.type == 'efficientzero':
if self._cfg.type in ['efficientzero', 'sampled_efficientzero']:
latent_state_roots, value_prefix_roots, reward_hidden_state_roots, pred_values, policy_logits = ez_network_output_unpack(
network_output
)
@@ -125,26 +134,56 @@

pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy()
latent_state_roots = latent_state_roots.detach().cpu().numpy()
if self._cfg.type == 'efficientzero':
if self._cfg.type in ['efficientzero', 'sampled_efficientzero']:
reward_hidden_state_roots = (
reward_hidden_state_roots[0].detach().cpu().numpy(),
reward_hidden_state_roots[1].detach().cpu().numpy()
)
policy_logits = policy_logits.detach().cpu().numpy().tolist()

legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)]
if self._cfg.model.continuous_action_space is True:
# when the action space of the environment is continuous, action_mask[:] is None.
# NOTE: in continuous action space env: we set all legal_actions as -1
legal_actions = [
[-1 for _ in range(self._cfg.model.num_of_sampled_actions)] for _ in range(active_collect_env_num)
]
else:
legal_actions = [
[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)
]

# the only difference between collect and eval is the dirichlet noise.
noises = [
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
).astype(np.float32).tolist() for j in range(active_collect_env_num)
]
if self._cfg.type in ['sampled_efficientzero']:
noises = [
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(self._cfg.model.num_of_sampled_actions)
).astype(np.float32).tolist() for j in range(active_collect_env_num)
]
else:
noises = [
np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
).astype(np.float32).tolist() for j in range(active_collect_env_num)
]

if self._cfg.mcts_ctree:
# cpp mcts_tree
roots = self.MCTSCtree.roots(active_collect_env_num, legal_actions)
if self._cfg.type in ['sampled_efficientzero']:
roots = self.MCTSCtree.roots(
active_collect_env_num, legal_actions, self._cfg.model.action_space_size,
self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
)
else:
roots = self.MCTSCtree.roots(active_collect_env_num, legal_actions)
else:
# python mcts_tree
roots = self.MCTSPtree.roots(active_collect_env_num, legal_actions)
if self._cfg.type == 'efficientzero':
if self._cfg.type in ['sampled_efficientzero']:
roots = self.MCTSPtree.roots(
active_collect_env_num, legal_actions, self._cfg.model.action_space_size,
self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
)
else:
roots = self.MCTSPtree.roots(active_collect_env_num, legal_actions)

if self._cfg.type in ['efficientzero', 'sampled_efficientzero']:
roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_roots, policy_logits, to_play)
self._mcts_collect.search(
roots, self._collect_model, latent_state_roots, reward_hidden_state_roots, to_play
@@ -157,6 +196,8 @@ def _forward_collect(

roots_visit_count_distributions = roots.get_distributions()
roots_values = roots.get_values() # shape: {list: batch_size}
if self._cfg.type in ['sampled_efficientzero']:
roots_sampled_actions = roots.get_sampled_actions()

data_id = [i for i in range(active_collect_env_num)]
output = {i: None for i in data_id}
@@ -165,26 +206,47 @@ def _forward_collect(

for i, env_id in enumerate(ready_env_id):
distributions, value = roots_visit_count_distributions[i], roots_values[i]

if self._cfg.type in ['sampled_efficientzero']:
try:
root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]])
except Exception:
# logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list')
root_sampled_actions = np.array([action for action in roots_sampled_actions[i]])

# NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
# the index within the legal action set, rather than the index in the entire action set.
action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
distributions, temperature=self._collect_mcts_temperature, deterministic=False
)

# ****** sample a random action from the legal action set ********
# all items except action are formally obtained from MCTS
random_action = int(np.random.choice(legal_actions[env_id], 1))
if self._cfg.type in ['sampled_efficientzero']:
random_action = self.action_space.sample()
else:
# all items except action are formally obtained from MCTS
random_action = int(np.random.choice(legal_actions[env_id], 1))
# ****************************************************************

# NOTE: The action is randomly selected from the legal action set; the distribution is the real visit count distribution from the MCTS search.
output[env_id] = {
'action': random_action,
'distributions': distributions,
'visit_count_distribution_entropy': visit_count_distribution_entropy,
'value': value,
'pred_value': pred_values[i],
'policy_logits': policy_logits[i],
}
if self._cfg.type in ['sampled_efficientzero']:
output[env_id] = {
'action': random_action,
'distributions': distributions,
'root_sampled_actions': root_sampled_actions,
'visit_count_distribution_entropy': visit_count_distribution_entropy,
'value': value,
'pred_value': pred_values[i],
'policy_logits': policy_logits[i],
}
else:
output[env_id] = {
'action': random_action,
'distributions': distributions,
'visit_count_distribution_entropy': visit_count_distribution_entropy,
'value': value,
'pred_value': pred_values[i],
'policy_logits': policy_logits[i],
}

return output
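Taken together, the continuous (sampled_efficientzero) branch in _forward_collect above replaces three discrete-only assumptions: legal actions become num_of_sampled_actions placeholder slots of -1, the root Dirichlet noise is sized by num_of_sampled_actions rather than by the action mask, and the random action handed back to the collector comes from action_space.sample(). A small standalone sketch of those three pieces (env, batch size, and hyperparameter values are illustrative, not from the diff):

import numpy as np
import gym

num_envs = 2
num_of_sampled_actions = 5
root_dirichlet_alpha = 0.3
action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1,), dtype=np.float32)

# 1) No action mask in a continuous env, so every root gets
#    num_of_sampled_actions placeholder "legal actions", all set to -1.
legal_actions = [[-1] * num_of_sampled_actions for _ in range(num_envs)]

# 2) Root exploration noise is therefore sized by num_of_sampled_actions,
#    not by sum(action_mask) as in the discrete branch.
noises = [
    np.random.dirichlet([root_dirichlet_alpha] * num_of_sampled_actions).astype(np.float32)
    for _ in range(num_envs)
]

# 3) The random action is drawn from the env's Box space instead of
#    np.random.choice over a set of legal action indices.
random_actions = [action_space.sample() for _ in range(num_envs)]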

3 changes: 3 additions & 0 deletions lzero/policy/sampled_efficientzero.py
@@ -170,6 +170,8 @@ class SampledEfficientZeroPolicy(MuZeroPolicy):
# (float) The fixed temperature value for MCTS action selection, which is used to control the exploration.
# The larger the value, the more exploration. This value is only used when manual_temperature_decay=False.
fixed_temperature_value=0.25,
# (bool) Whether to use the true chance in MCTS in 2048 env.
use_ture_chance_label_in_chance_encoder=False,

# ****** Priority ******
# (bool) Whether to use priority when sampling training data from the buffer.
@@ -881,6 +883,7 @@ def _forward_collect(
except Exception:
# logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list')
root_sampled_actions = np.array([action for action in roots_sampled_actions[i]])

# NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
# the index within the legal action set, rather than the index in the entire action set.
action, visit_count_distribution_entropy = select_action(
1 change: 1 addition & 0 deletions
@@ -53,6 +53,7 @@
grad_clip_value=0.5,
num_simulations=num_simulations,
reanalyze_ratio=reanalyze_ratio,
random_collect_episode_num=8,
# NOTE: for continuous gaussian policy, we use the policy_entropy_loss as in the original Sampled MuZero paper.
policy_entropy_loss_weight=5e-3,
n_episode=n_episode,
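The config hunk above is what actually turns on the warm-up phase for a continuous env: random_collect (see lzero/entry/utils.py at the top of this commit) asserts random_collect_episode_num > 0 before swapping the collector over to the random policy. A hedged config fragment grouping the relevant fields (EasyDict usage and every value except random_collect_episode_num=8 are illustrative placeholders):

from easydict import EasyDict

policy = EasyDict(dict(
    model=dict(
        continuous_action_space=True,   # triggers the -1 placeholder legal actions at collect time
        num_of_sampled_actions=20,      # placeholder value
    ),
    # Collect this many episodes with the random policy before normal training starts.
    random_collect_episode_num=8,
    policy_entropy_loss_weight=5e-3,
))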
