polish(pu): polish some variables name and redundant code
puyuan1996 committed Oct 30, 2023
1 parent c093d77 commit 4969d87
Showing 13 changed files with 107 additions and 103 deletions.
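In short, the polish renames the keys of the policy/model output dictionaries so that each field's provenance is explicit, and tightens a few signatures. Below is a minimal summary sketch of the renames, collected from the hunks that follow; the mapping object itself is not part of the repository.

```python
# Sketch only: a summary of the output-key renames applied in this commit,
# collected from the hunks below. This dict is not part of the LightZero code base.
OUTPUT_KEY_RENAMES = {
    'distributions': 'visit_count_distributions',  # raw MCTS visit-count distributions
    'value': 'searched_value',                     # root value returned by the MCTS search
    'pred_value': 'predicted_value',               # value predicted by the model
    'policy_logits': 'predicted_policy_logits',    # policy logits predicted by the model
}
```

Besides the renames, `ready_env_id` gains an explicit `np.array = None` annotation, the random policy's `_forward_collect` declares a `-> Dict` return type, and the 2048-specific config comment is generalized.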
4 changes: 2 additions & 2 deletions lzero/mcts/tests/cprofile_mcts_ptree.py
@@ -27,7 +27,7 @@ def initial_inference(self, observation):
         reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
 
         output = {
-            'value': value,
+            'searched_value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
@@ -45,7 +45,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
         policy_logits = torch.zeros(size=(batch_size, self.action_num))
 
         output = {
-            'value': value,
+            'searched_value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
4 changes: 2 additions & 2 deletions lzero/mcts/tests/eval_tree_speed.py
@@ -32,7 +32,7 @@ def initial_inference(self, observation):
         reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
 
         output = {
-            'value': value,
+            'searched_value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
@@ -50,7 +50,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
         policy_logits = torch.zeros(size=(batch_size, self.action_num))
 
         output = {
-            'value': value,
+            'searched_value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
4 changes: 2 additions & 2 deletions lzero/mcts/tests/test_mcts_ctree.py
@@ -37,7 +37,7 @@ def initial_inference(self, observation):
         reward_hidden_state_roots = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
 
         output = {
-            'value': value,
+            'searched_value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
@@ -60,7 +60,7 @@ def recurrent_inference(self, latent_states, reward_hidden_states, actions=None)
         policy_logits = torch.zeros(size=(batch_size, self.action_num))
 
         output = {
-            'value': value,
+            'searched_value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
4 changes: 2 additions & 2 deletions lzero/mcts/tests/test_mcts_ptree.py
@@ -29,7 +29,7 @@ def initial_inference(self, observation):
         reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
 
         output = {
-            'value': value,
+            'searched_value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
@@ -47,7 +47,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
         policy_logits = torch.zeros(size=(batch_size, self.action_num))
 
         output = {
-            'value': value,
+            'searched_value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
4 changes: 2 additions & 2 deletions lzero/mcts/tests/test_mcts_sampled_ctree.py
@@ -29,7 +29,7 @@ def initial_inference(self, observation):
         reward_hidden_state_state = (torch.zeros(size=(1, batch_size, 16)), torch.zeros(size=(1, batch_size, 16)))
 
         output = {
-            'value': value,
+            'searched_value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
@@ -48,7 +48,7 @@ def recurrent_inference(self, hidden_states, reward_hidden_states, actions):
         # policy_logits = torch.zeros(size=(batch_size, self.action_num))
 
         output = {
-            'value': value,
+            'searched_value': value,
             'value_prefix': value_prefix,
             'policy_logits': policy_logits,
             'latent_state': latent_state,
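The five test helpers above share the same pattern: a stub model whose `initial_inference`/`recurrent_inference` return fixed tensors under the output keys shown, with 'value' now renamed to 'searched_value'. A condensed, self-contained sketch of that output construction follows; the shapes of `value`, `value_prefix`, and `latent_state` are assumptions (only the `policy_logits` and reward-hidden-state shapes appear in the hunks), and the real stubs contain further entries that are collapsed in the diff.

```python
# Sketch only: the renamed output dictionary built by the stub models in the MCTS
# tests above. Shapes marked "assumed" are not visible in the diff; the real stubs
# also return reward hidden states and other fields hidden by the collapsed context.
import torch

def stub_initial_inference(batch_size: int, action_num: int) -> dict:
    value = torch.zeros(size=(batch_size, 1))            # assumed shape
    value_prefix = torch.zeros(size=(batch_size, 1))     # assumed shape
    policy_logits = torch.zeros(size=(batch_size, action_num))
    latent_state = torch.zeros(size=(batch_size, 16))    # assumed latent size
    return {
        'searched_value': value,  # renamed from 'value' in this commit
        'value_prefix': value_prefix,
        'policy_logits': policy_logits,
        'latent_state': latent_state,
    }
```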
20 changes: 10 additions & 10 deletions lzero/policy/efficientzero.py
@@ -155,7 +155,7 @@ class EfficientZeroPolicy(MuZeroPolicy):
         # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration.
         # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False.
         fixed_temperature_value=0.25,
-        # (bool) Whether to use the true chance in MCTS in 2048 env.
+        # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048.
         use_ture_chance_label_in_chance_encoder=False,
 
         # ****** Priority ******
@@ -626,11 +626,11 @@ def _forward_collect(
                 action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
                 output[env_id] = {
                     'action': action,
-                    'distributions': distributions,
+                    'visit_count_distributions': distributions,
                     'visit_count_distribution_entropy': visit_count_distribution_entropy,
-                    'value': value,
-                    'pred_value': pred_values[i],
-                    'policy_logits': policy_logits[i],
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
                 }
 
         return output
@@ -646,7 +646,7 @@ def _init_eval(self) -> None:
         else:
             self._mcts_eval = MCTSPtree(self._cfg)
 
-    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None):
+    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,):
         """
         Overview:
             The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
@@ -718,11 +718,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, read
                 action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
                 output[env_id] = {
                     'action': action,
-                    'distributions': distributions,
+                    'visit_count_distributions': distributions,
                     'visit_count_distribution_entropy': visit_count_distribution_entropy,
-                    'value': value,
-                    'pred_value': pred_values[i],
-                    'policy_logits': policy_logits[i],
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
                 }
 
         return output
20 changes: 10 additions & 10 deletions lzero/policy/gumbel_muzero.py
@@ -486,7 +486,7 @@ def _forward_collect(
             action_mask: list = None,
             temperature: float = 1,
             to_play: List = [-1],
-            ready_env_id=None
+            ready_env_id: np.array = None,
     ) -> Dict:
         """
         Overview:
@@ -572,13 +572,13 @@ def _forward_collect(
                 action = np.argmax([v for v in valid_value])
                 output[env_id] = {
                     'action': action,
-                    'distributions': distributions,
+                    'visit_count_distributions': distributions,
                     'visit_count_distribution_entropy': visit_count_distribution_entropy,
-                    'value': value,
+                    'searched_value': value,
                     'roots_completed_value': roots_completed_value,
                     'improved_policy_probs': improved_policy_probs,
-                    'pred_value': pred_values[i],
-                    'policy_logits': policy_logits[i],
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
                 }
 
         return output
@@ -594,7 +594,7 @@ def _init_eval(self) -> None:
         else:
             self._mcts_eval = MCTSPtree(self._cfg)
 
-    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict:
+    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict:
         """
         Overview:
             The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
@@ -674,11 +674,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
 
                 output[env_id] = {
                     'action': action,
-                    'distributions': distributions,
+                    'visit_count_distributions': distributions,
                     'visit_count_distribution_entropy': visit_count_distribution_entropy,
-                    'value': value,
-                    'pred_value': pred_values[i],
-                    'policy_logits': policy_logits[i],
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
                 }
 
         return output
22 changes: 11 additions & 11 deletions lzero/policy/muzero.py
@@ -152,7 +152,7 @@ class MuZeroPolicy(Policy):
         # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration.
         # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False.
         fixed_temperature_value=0.25,
-        # (bool) Whether to use the true chance in MCTS in 2048 env.
+        # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048.
         use_ture_chance_label_in_chance_encoder=False,
 
         # ****** Priority ******
@@ -477,7 +477,7 @@ def _forward_collect(
             temperature: float = 1,
             to_play: List = [-1],
             epsilon: float = 0.25,
-            ready_env_id=None
+            ready_env_id: np.array = None,
     ) -> Dict:
         """
         Overview:
@@ -562,11 +562,11 @@ def _forward_collect(
                 action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set]
                 output[env_id] = {
                     'action': action,
-                    'distributions': distributions,
+                    'visit_count_distributions': distributions,
                     'visit_count_distribution_entropy': visit_count_distribution_entropy,
-                    'value': value,
-                    'pred_value': pred_values[i],
-                    'policy_logits': policy_logits[i],
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
                 }
 
         return output
@@ -606,7 +606,7 @@ def _get_target_obs_index_in_step_k(self, step):
         end_index = self._cfg.model.observation_shape * (step + self._cfg.model.frame_stack_num)
         return beg_index, end_index
 
-    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id=None) -> Dict:
+    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None,) -> Dict:
         """
         Overview:
             The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search.
@@ -676,11 +676,11 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1
 
                 output[env_id] = {
                     'action': action,
-                    'distributions': distributions,
+                    'visit_count_distributions': distributions,
                     'visit_count_distribution_entropy': visit_count_distribution_entropy,
-                    'value': value,
-                    'pred_value': pred_values[i],
-                    'policy_logits': policy_logits[i],
+                    'searched_value': value,
+                    'predicted_value': pred_values[i],
+                    'predicted_policy_logits': policy_logits[i],
                 }
 
         return output
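Across EfficientZero, Gumbel MuZero, and MuZero, the collect/eval outputs now distinguish the value obtained by the search from the value predicted by the model: 'searched_value' is read from the MCTS roots, while 'predicted_value' comes from the model's value head. Below is a minimal sketch of how a caller might unpack the renamed per-environment output; the helper and surrounding names are hypothetical, only the key names are taken from the diffs above.

```python
# Sketch only: unpacking the renamed per-environment output of _forward_collect /
# _forward_eval. The helper is hypothetical; only the key names come from the diffs.
from typing import Any, Dict

def unpack_policy_output(policy_output: Dict[int, Dict[str, Any]]) -> Dict[str, Dict[int, Any]]:
    actions, searched_values, predicted_values = {}, {}, {}
    for env_id, out in policy_output.items():
        actions[env_id] = out['action']
        searched_values[env_id] = out['searched_value']    # root value from the MCTS search
        predicted_values[env_id] = out['predicted_value']  # value-head prediction of the model
    return {
        'action': actions,
        'searched_value': searched_values,
        'predicted_value': predicted_values,
    }
```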
57 changes: 29 additions & 28 deletions lzero/policy/random_policy.py
@@ -20,7 +20,7 @@ def __init__(
             cfg: dict,
             model: Optional[Union[type, torch.nn.Module]] = None,
             enable_field: Optional[List[str]] = None,
-            action_space = None,
+            action_space: Any = None,
     ):
         if cfg.type == 'muzero':
             from lzero.mcts import MuZeroMCTSCtree as MCTSCtree
@@ -65,15 +65,15 @@ def default_model(self) -> Tuple[str, List[str]]:
         elif self._cfg.type == 'muzero':
             return 'MuZeroModelMLP', ['lzero.model.muzero_model_mlp']
         elif self._cfg.type == 'sampled_efficientzero':
-            return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_modelMLP']
+            return 'SampledEfficientZeroModelMLP', ['lzero.model.sampled_efficientzero_model_mlp']
         else:
             raise NotImplementedError("need to implement pipeline: {}".format(self._cfg.type))
 
     def _init_collect(self) -> None:
         """
-            Overview:
-                Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
-            """
+        Overview:
+            Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
+        """
         self._collect_model = self._model
         if self._cfg.mcts_ctree:
             self._mcts_collect = self.MCTSCtree(self._cfg)
@@ -92,8 +92,8 @@ def _forward_collect(
             temperature: float = 1,
             to_play: List = [-1],
             epsilon: float = 0.25,
-            ready_env_id=None,
-    ):
+            ready_env_id: np.array = None,
+    ) -> Dict:
         """
         Overview:
             The forward function for collecting data in collect mode. Use model to execute MCTS search.
@@ -141,7 +141,7 @@ def _forward_collect(
         )
         policy_logits = policy_logits.detach().cpu().numpy().tolist()
 
-        if self._cfg.model.continuous_action_space is True:
+        if self._cfg.model.continuous_action_space:
             # when the action space of the environment is continuous, action_mask[:] is None.
             # NOTE: in continuous action space env: we set all legal_actions as -1
             legal_actions = [
@@ -208,44 +208,45 @@ def _forward_collect(
                 distributions, value = roots_visit_count_distributions[i], roots_values[i]
 
                 if self._cfg.type in ['sampled_efficientzero']:
-                    try:
-                        root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]])
-                    except Exception:
-                        # logging.warning('ctree_sampled_efficientzero roots.get_sampled_actions() return list')
+                    if self._cfg.mcts_ctree:
+                        # In ctree, the method roots.get_sampled_actions() returns a list object.
                         root_sampled_actions = np.array([action for action in roots_sampled_actions[i]])
+                    else:
+                        # In ptree, the same method roots.get_sampled_actions() returns an Action object.
+                        root_sampled_actions = np.array([action.value for action in roots_sampled_actions[i]])
 
                 # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents
                 # the index within the legal action set, rather than the index in the entire action set.
                 action_index_in_legal_action_set, visit_count_distribution_entropy = select_action(
                     distributions, temperature=self._collect_mcts_temperature, deterministic=False
                 )
 
-                # ****** sample a random action from the legal action set ********
-                if self._cfg.type in ['sampled_efficientzero']:
-                    random_action = self.action_space.sample()
-                else:
-                    # all items except action are formally obtained from MCTS
-                    random_action = int(np.random.choice(legal_actions[env_id], 1))
-                # ****************************************************************
-                # NOTE: The action is randomly selected from the legal action set, the distribution is the real visit count distribution from the MCTS seraech.
+                # NOTE: The action is randomly selected from the legal action set,
+                # the distribution is the real visit count distribution from the MCTS search.
                 if self._cfg.type in ['sampled_efficientzero']:
+                    # ****** sample a random action from the legal action set ********
+                    random_action = self.action_space.sample()
                     output[env_id] = {
                         'action': random_action,
-                        'distributions': distributions,
+                        'visit_count_distributions': distributions,
                         'root_sampled_actions': root_sampled_actions,
                         'visit_count_distribution_entropy': visit_count_distribution_entropy,
-                        'value': value,
-                        'pred_value': pred_values[i],
-                        'policy_logits': policy_logits[i],
+                        'searched_value': value,
+                        'predicted_value': pred_values[i],
+                        'predicted_policy_logits': policy_logits[i],
                     }
                 else:
+                    # ****** sample a random action from the legal action set ********
+                    random_action = int(np.random.choice(legal_actions[env_id], 1))
+                    # all items except action are formally obtained from MCTS
                     output[env_id] = {
                         'action': random_action,
-                        'distributions': distributions,
+                        'visit_count_distributions': distributions,
                         'visit_count_distribution_entropy': visit_count_distribution_entropy,
-                        'value': value,
-                        'pred_value': pred_values[i],
-                        'policy_logits': policy_logits[i],
+                        'searched_value': value,
+                        'predicted_value': pred_values[i],
+                        'predicted_policy_logits': policy_logits[i],
                     }
 
         return output
@@ -268,7 +269,7 @@ def _init_learn(self) -> None:
     def _forward_learn(self, data: torch.Tensor) -> Dict[str, Union[float, int]]:
         pass
 
-    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id=None):
+    def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: -1, ready_env_id: np.array = None,):
         pass
 
     def _monitor_vars_learn(self) -> List[str]:
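The restructured `_forward_collect` in the random policy makes two things explicit: how `roots.get_sampled_actions()` differs between the ctree and ptree backends, and how the random action is drawn in each branch. The sampling rule is shown below, extracted into a standalone function for illustration; the function itself is not part of LightZero, and `action_space` is assumed to be a gym-style space with `.sample()`.

```python
# Sketch only: the per-environment random-action rule applied by
# RandomPolicy._forward_collect after this commit. Not part of LightZero;
# `action_space` is assumed to expose a gym-style .sample().
import numpy as np

def sample_random_action(cfg_type: str, action_space, legal_actions: list):
    if cfg_type in ['sampled_efficientzero']:
        # Sampled / continuous setting: draw directly from the action space.
        return action_space.sample()
    # Discrete setting: draw uniformly from the currently legal actions,
    # mirroring `int(np.random.choice(legal_actions[env_id], 1))` in the diff.
    return int(np.random.choice(legal_actions, 1))
```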
(Diffs for the remaining 4 changed files are not shown in this view.)
