From 6ffcc4d48588f8fb2816585f359409ec793cd2b1 Mon Sep 17 00:00:00 2001 From: HarryXuancy <52876902+HarryXuancy@users.noreply.github.com> Date: Thu, 8 Feb 2024 12:14:09 +0800 Subject: [PATCH] polish(xcy): polish comments in tree search files (#185) * polish(xcy): polish comment in tree_search * polish(xcy):polish the comments for ptree * polish(xcy): add comment to ptree files * polish(xcy):add comments for ctree files --- .../ctree/ctree_efficientzero/lib/cnode.cpp | 12 +- .../ctree/ctree_gumbel_muzero/lib/cnode.cpp | 30 +- lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp | 12 +- .../ctree_sampled_efficientzero/lib/cnode.cpp | 16 +- .../ctree_stochastic_muzero/lib/cnode.cpp | 13 +- lzero/mcts/ptree/minimax.py | 56 ++++ lzero/mcts/ptree/ptree_ez.py | 229 ++++++++----- lzero/mcts/ptree/ptree_mz.py | 267 ++++++++------- lzero/mcts/ptree/ptree_sez.py | 308 +++++++++++------- lzero/mcts/ptree/ptree_stochastic_mz.py | 269 ++++++++------- lzero/mcts/tree_search/mcts_ctree.py | 157 ++++++--- lzero/mcts/tree_search/mcts_ctree_sampled.py | 53 ++- .../mcts/tree_search/mcts_ctree_stochastic.py | 49 ++- lzero/mcts/tree_search/mcts_ptree.py | 110 +++++-- lzero/mcts/tree_search/mcts_ptree_sampled.py | 52 ++- .../mcts/tree_search/mcts_ptree_stochastic.py | 48 ++- 16 files changed, 1074 insertions(+), 607 deletions(-) diff --git a/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp index 59846f1ac..c07924836 100644 --- a/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp @@ -248,7 +248,7 @@ namespace tree /* Overview: Find the current best trajectory starts from the current node. - Outputs: + Returns: - traj: a vector of node index, which is the current best trajectory from this node. */ std::vector traj; @@ -270,7 +270,7 @@ namespace tree /* Overview: Get the distribution of child nodes in the format of visit_count. - Outputs: + Returns: - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). */ std::vector distribution; @@ -378,7 +378,7 @@ namespace tree /* Overview: Find the current best trajectory starts from each root. - Outputs: + Returns: - traj: a vector of node index, which is the current best trajectory from each root. */ std::vector > trajs; @@ -396,7 +396,7 @@ namespace tree /* Overview: Get the children distribution of each root. - Outputs: + Returns: - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). */ std::vector > distributions; @@ -618,7 +618,7 @@ namespace tree - disount_factor: the discount factor of reward. - mean_q: the mean q value of the parent node. - players: the number of players. - Outputs: + Returns: - action: the action to select. */ float max_score = FLOAT_MIN; @@ -667,7 +667,7 @@ namespace tree - pb_c_init: constants c1 in muzero. - disount_factor: the discount factor of reward. - players: the number of players. - Outputs: + Returns: - ucb_value: the ucb score of the child. */ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; diff --git a/lzero/mcts/ctree/ctree_gumbel_muzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_gumbel_muzero/lib/cnode.cpp index 123d29c88..1adf1c1d2 100644 --- a/lzero/mcts/ctree/ctree_gumbel_muzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_gumbel_muzero/lib/cnode.cpp @@ -265,7 +265,7 @@ namespace tree{ /* Overview: Find the current best trajectory starts from the current node. 
- Outputs: + Returns: - traj: a vector of node index, which is the current best trajectory from this node. */ std::vector traj; @@ -287,7 +287,7 @@ namespace tree{ /* Overview: Get the distribution of child nodes in the format of visit_count. - Outputs: + Returns: - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). */ std::vector distribution; @@ -311,7 +311,7 @@ namespace tree{ /* Overview: Get the completed value of child nodes. - Outputs: + Returns: - discount_factor: the discount_factor of reward. - action_space_size: the size of action space. */ @@ -468,7 +468,7 @@ namespace tree{ /* Overview: Find the current best trajectory starts from each root. - Outputs: + Returns: - traj: a vector of node index, which is the current best trajectory from each root. */ std::vector > trajs; @@ -486,7 +486,7 @@ namespace tree{ /* Overview: Get the children distribution of each root. - Outputs: + Returns: - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). */ std::vector > distributions; @@ -664,7 +664,7 @@ namespace tree{ - disount_factor: the discount factor of reward. - mean_q: the mean q value of the parent node. - players: the number of players. - Outputs: + Returns: - action: the action to select. */ float max_score = FLOAT_MIN; @@ -708,7 +708,7 @@ namespace tree{ - disount_factor: the discount factor of reward. - num_simulations: the upper limit number of simulations. - max_num_considered_actions: the maximum number of considered actions. - Outputs: + Returns: - action: the action to select. */ std::vector child_visit_count; @@ -752,7 +752,7 @@ namespace tree{ Arguments: - root: the roots to select the child node. - disount_factor: the discount factor of reward. - Outputs: + Returns: - action: the action to select. */ std::vector child_visit_count; @@ -803,7 +803,7 @@ namespace tree{ - pb_c_init: constants c1 in muzero. - disount_factor: the discount factor of reward. - players: the number of players. - Outputs: + Returns: - ucb_value: the ucb score of the child. */ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; @@ -942,7 +942,7 @@ namespace tree{ - q_value: the q value of the current node. - child_visit: the visit counts of the child nodes. - child_prior: the prior of the child nodes. - Outputs: + Returns: - mixed Q value. */ float visit_count_sum = 0.0; @@ -1002,7 +1002,7 @@ namespace tree{ - value_cale: the scale of value. - rescale_values: whether to rescale the values. - epsilon: the lower limit of gap in max-min normalization - Outputs: + Returns: - completed Q value. */ assert (child_visit.size() == child_prior.size()); @@ -1047,7 +1047,7 @@ namespace tree{ Arguments: - max_num_considered_actions: the maximum number of considered actions. - num_simulations: the upper limit number of simulations. - Outputs: + Returns: - the considered visit sequence. */ std::vector visit_seq; @@ -1084,7 +1084,7 @@ namespace tree{ Arguments: - max_num_considered_actions: the maximum number of considered actions. - num_simulations: the upper limit number of simulations. - Outputs: + Returns: - the table of considered visits. */ std::vector > table; @@ -1105,7 +1105,7 @@ namespace tree{ - logits: the logits vector of child nodes. - normalized_qvalues: the normalized Q values of child nodes. - visit_counts: the visit counts of child nodes. - Outputs: + Returns: - the score of nodes to be considered. 
*/ float low_logit = -1e9; @@ -1139,7 +1139,7 @@ namespace tree{ - gumbel_scale: the scale of gumbel. - gumbel_rng: the seed to generate gumbel. - shape: the shape of gumbel vectors to be generated - Outputs: + Returns: - gumbel vectors. */ std::mt19937 gen(static_cast(gumbel_rng)); diff --git a/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp index d8891f42e..fba68b25e 100644 --- a/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_muzero/lib/cnode.cpp @@ -239,7 +239,7 @@ namespace tree /* Overview: Find the current best trajectory starts from the current node. - Outputs: + Returns: - traj: a vector of node index, which is the current best trajectory from this node. */ std::vector traj; @@ -261,7 +261,7 @@ namespace tree /* Overview: Get the distribution of child nodes in the format of visit_count. - Outputs: + Returns: - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). */ std::vector distribution; @@ -371,7 +371,7 @@ namespace tree /* Overview: Find the current best trajectory starts from each root. - Outputs: + Returns: - traj: a vector of node index, which is the current best trajectory from each root. */ std::vector > trajs; @@ -389,7 +389,7 @@ namespace tree /* Overview: Get the children distribution of each root. - Outputs: + Returns: - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). */ std::vector > distributions; @@ -561,7 +561,7 @@ namespace tree - disount_factor: the discount factor of reward. - mean_q: the mean q value of the parent node. - players: the number of players. - Outputs: + Returns: - action: the action to select. */ float max_score = FLOAT_MIN; @@ -609,7 +609,7 @@ namespace tree - pb_c_init: constants c1 in muzero. - disount_factor: the discount factor of reward. - players: the number of players. - Outputs: + Returns: - ucb_value: the ucb score of the child. */ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; diff --git a/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp index 7eddc4962..6bc4ea2e8 100644 --- a/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_sampled_efficientzero/lib/cnode.cpp @@ -94,7 +94,7 @@ namespace tree { /* Overview: - get the final combined hash value from the hash values of each dimension of the multi-dimensional action. + Get the final combined hash value from the hash values of each dimension of the multi-dimensional action. */ std::vector hash = this->get_hash(); size_t combined_hash = hash[0]; @@ -558,7 +558,7 @@ namespace tree /* Overview: Find the current best trajectory starts from the current node. - Outputs: + Returns: - traj: a vector of node index, which is the current best trajectory from this node. */ std::vector traj; @@ -585,7 +585,7 @@ namespace tree /* Overview: Get the distribution of child nodes in the format of visit_count. - Outputs: + Returns: - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). */ std::vector distribution; @@ -724,7 +724,7 @@ namespace tree /* Overview: Find the current best trajectory starts from each root. - Outputs: + Returns: - traj: a vector of node index, which is the current best trajectory from each root. */ std::vector > > trajs; @@ -742,7 +742,7 @@ namespace tree /* Overview: Get the children distribution of each root. 
- Outputs: + Returns: - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). */ std::vector > distributions; @@ -761,7 +761,7 @@ namespace tree /* Overview: Get the sampled_actions of each root. - Outputs: + Returns: - python_sampled_actions: a vector of sampled_actions for each root, e.g. the size of original action space is 6, the K=3, python_sampled_actions = [[1,3,0], [2,4,0], [5,4,1]]. */ @@ -984,7 +984,7 @@ namespace tree - mean_q: the mean q value of the parent node. - players: the number of players. - continuous_action_space: whether the action space is continous in current env. - Outputs: + Returns: - action: the action to select. */ // sampled related core code @@ -1040,7 +1040,7 @@ namespace tree - disount_factor: the discount factor of reward. - players: the number of players. - continuous_action_space: whether the action space is continous in current env. - Outputs: + Returns: - ucb_value: the ucb score of the child. */ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; diff --git a/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp b/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp index 004b11099..f848e07e4 100644 --- a/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp +++ b/lzero/mcts/ctree/ctree_stochastic_muzero/lib/cnode.cpp @@ -76,6 +76,7 @@ namespace tree Arguments: - prior: the prior value of this node. - legal_actions: a vector of legal actions of this node. + - is_chance: Whether the node is a chance node. */ this->prior = prior; this->legal_actions = legal_actions; @@ -264,7 +265,7 @@ namespace tree /* Overview: Find the current best trajectory starts from the current node. - Outputs: + Returns: - traj: a vector of node index, which is the current best trajectory from this node. */ std::vector traj; @@ -286,7 +287,7 @@ namespace tree /* Overview: Get the distribution of child nodes in the format of visit_count. - Outputs: + Returns: - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). */ std::vector distribution; @@ -398,7 +399,7 @@ namespace tree /* Overview: Find the current best trajectory starts from each root. - Outputs: + Returns: - traj: a vector of node index, which is the current best trajectory from each root. */ std::vector > trajs; @@ -416,7 +417,7 @@ namespace tree /* Overview: Get the children distribution of each root. - Outputs: + Returns: - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). */ std::vector > distributions; @@ -603,7 +604,7 @@ namespace tree - disount_factor: the discount factor of reward. - mean_q: the mean q value of the parent node. - players: the number of players. - Outputs: + Returns: - action: the action to select. */ if (root->is_chance) { @@ -675,7 +676,7 @@ namespace tree - pb_c_init: constants c1 in muzero. - disount_factor: the discount factor of reward. - players: the number of players. - Outputs: + Returns: - ucb_value: the ucb score of the child. */ float pb_c = 0.0, prior_score = 0.0, value_score = 0.0; diff --git a/lzero/mcts/ptree/minimax.py b/lzero/mcts/ptree/minimax.py index 62a71d5a7..4ca10996a 100644 --- a/lzero/mcts/ptree/minimax.py +++ b/lzero/mcts/ptree/minimax.py @@ -1,27 +1,65 @@ +""" +This code defines two classes, ``MinMaxStats`` and ``MinMaxStatsList``, \ +for tracking and normalizing minimum and maximum values. 
+""" FLOAT_MAX = 1000000.0 FLOAT_MIN = -float('inf') class MinMaxStats: + """ + Overview: + Class for tracking and normalizing minimum and maximum values. + Interfaces: + ``__init__``,``set_delta``,``update``,``clear``,``normalize``. + """ def __init__(self) -> None: + """ + Overview: + Initializes an instance of the class. + """ self.clear() self.value_delta_max = 0 def set_delta(self, value_delta_max: float) -> None: + """ + Overview: + Sets the maximum delta value. + Arguments: + - value_delta_max (:obj:`float`): The maximum delta value. + """ self.value_delta_max = value_delta_max def update(self, value: float) -> None: + """ + Overview: + Updates the minimum and maximum values. + Arguments: + - value (:obj:`float`): The value to update. + """ if value > self.maximum: self.maximum = value if value < self.minimum: self.minimum = value def clear(self) -> None: + """ + Overview: + Clears the minimum and maximum values. + """ self.minimum = FLOAT_MAX self.maximum = FLOAT_MIN def normalize(self, value: float) -> float: + """ + Overview: + Normalizes a value based on the minimum and maximum values. + Arguments: + - value (:obj:`float`): The value to normalize. + Returns: + - norm_value (:obj:`float`): The normalized value. + """ norm_value = value delta = self.maximum - self.minimum if delta > 0: @@ -33,11 +71,29 @@ def normalize(self, value: float) -> float: class MinMaxStatsList: + """ + Overview: + Class for managing a list of MinMaxStats instances. + Interfaces: + ``__init__``,``set_delta``. + """ def __init__(self, num: int) -> None: + """ + Overview: + Initializes a list of MinMaxStats instances. + Arguments: + - num (:obj:`int`): The number of MinMaxStats instances to create. + """ self.num = num self.stats_lst = [MinMaxStats() for _ in range(self.num)] def set_delta(self, value_delta_max: float) -> None: + """ + Overview: + Sets the maximum delta value for each MinMaxStats instance in the list. + Arguments: + - value_delta_max (:obj:`float`): The maximum delta value. + """ for i in range(self.num): self.stats_lst[i].set_delta(value_delta_max) diff --git a/lzero/mcts/ptree/ptree_ez.py b/lzero/mcts/ptree/ptree_ez.py index b97505441..7486e5581 100644 --- a/lzero/mcts/ptree/ptree_ez.py +++ b/lzero/mcts/ptree/ptree_ez.py @@ -13,11 +13,22 @@ class Node: """ - Overview: - the node base class for EfficientZero. - """ + Overview: + The Node class for EfficientZero. The basic functions of the node are implemented. + Interfaces: + ``__init__``, ``expand``, ``add_exploration_noise``, ``compute_mean_q``, ``get_trajectory``, \ + ``get_children_distribution``, ``get_child``, ``expanded``, ``value``. + """ def __init__(self, prior: float, legal_actions: List = None, action_space_size: int = 9) -> None: + """ + Overview: + Initializes a Node instance. + Arguments: + - prior (:obj:`float`): The prior probability of the node. + - legal_actions (:obj:`List`, optional): The list of legal actions of the node. Defaults to None. + - action_space_size (:obj:`int`, optional): The size of the action space. Defaults to 9. 
+ """ self.prior = prior self.legal_actions = legal_actions self.action_space_size = action_space_size @@ -32,7 +43,6 @@ def __init__(self, prior: float, legal_actions: List = None, action_space_size: self.children_index = [] self.simulation_index = 0 self.batch_index = 0 - self.parent_value_prefix = 0 # only used in update_tree_q method def expand( self, to_play: int, simulation_index: int, batch_index: int, value_prefix: float, policy_logits: List[float] @@ -41,11 +51,11 @@ def expand( Overview: Expand the child nodes of the current node. Arguments: - - to_play (:obj:`Class int`): which player to play the game in the current node. - - simulation_index (:obj:`Class int`): the x/first index of hidden state vector of the current node, i.e. the search depth. - - batch_index (:obj:`Class int`): the y/second index of hidden state vector of the current node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. - - value_prefix: (:obj:`Class float`): the value prefix of the current node. - - policy_logits: (:obj:`Class List`): the policy logit of the child nodes. + - to_play (:obj:`int`): The player to play the game in the current node. + - simulation_index (:obj:`int`): The x/first index of the hidden state vector of the current node, i.e., the search depth. + - batch_index (:obj:`int`): The y/second index of the hidden state vector of the current node, i.e., the index of the batch root node, its maximum is `batch_size`/`env_num`. + - value_prefix: (:obj:`float`): the value prefix of the current node. + - policy_logits (:obj:`List[float]`): The policy logits providing the priors for the child nodes. """ self.to_play = to_play if self.legal_actions is None: @@ -63,10 +73,10 @@ def expand( def add_exploration_noise(self, exploration_fraction: float, noises: List[float]) -> None: """ Overview: - Add a noise to the prior of the child nodes. + Add exploration noise to the priors of the child nodes.. Arguments: - - exploration_fraction: the fraction to add noise. - - noises (:obj: list): the vector of noises added to each child node. length is len(self.legal_actions) + - exploration_fraction (:obj:`float`): The fraction of exploration noise to be added. + - noises (:obj:`List[float]`): The list of noises to be added to the priors. """ for i, a in enumerate(self.legal_actions): """ @@ -83,11 +93,13 @@ def add_exploration_noise(self, exploration_fraction: float, noises: List[float] def compute_mean_q(self, is_root: bool, parent_q: float, discount_factor: float) -> float: """ Overview: - Compute the mean q value of the current node. + Compute the mean of the action values of all legal actions. Arguments: - - is_root (:obj:`bool`): whether the current node is a root node. - - parent_q (:obj:`float`): the q value of the parent node. - - discount_factor (:obj:`float`): the discount_factor of reward. + - is_root (:obj:`bool`): Whether the current node is a root node. + - parent_q (:obj:`float`): The q value of the parent node. + - discount_factor (:obj:`float`): The discount factor of the reward. + Returns: + - mean_q (:obj:`float`): The mean of the action values. """ total_unsigned_q = 0.0 total_visits = 0 @@ -117,9 +129,9 @@ def print_out(self) -> None: def get_trajectory(self) -> List[Union[int, float]]: """ Overview: - Find the current best trajectory starts from the current node. - Outputs: - - traj: a vector of node index, which is the current best trajectory from this node. + Find the current best trajectory starting from the current node. 
+ Returns: + - traj (:obj:`List[Union[int, float]]`): A vector of node indices representing the current best trajectory. """ # TODO(pu): best action traj = [] @@ -133,6 +145,13 @@ def get_trajectory(self) -> List[Union[int, float]]: return traj def get_children_distribution(self) -> List[Union[int, float]]: + """ + Overview: + Get the distribution of visit counts among the child nodes. + Returns: + - distribution (:obj:`List[Union[int, float]]` or :obj:`None`): The distribution of visit counts among the children nodes. \ + If the legal_actions list is empty, returns None. + """ if self.legal_actions == []: return None distribution = {a: 0 for a in self.legal_actions} @@ -147,7 +166,11 @@ def get_children_distribution(self) -> List[Union[int, float]]: def get_child(self, action: Union[int, float]) -> "Node": """ Overview: - get children node according to the input action. + Get the child node according to the input action. + Arguments: + - action (:obj:`Union[int, float]`): The action for which the child node is to be retrieved. + Returns: + - child (:obj:`Node`): The child node corresponding to the input action. """ if not isinstance(action, np.int64): action = int(action) @@ -155,13 +178,21 @@ def get_child(self, action: Union[int, float]) -> "Node": @property def expanded(self) -> bool: + """ + Overview: + Check if the node has been expanded. + Returns: + - expanded (:obj:`bool`): True if the node has been expanded, False otherwise. + """ return len(self.children) > 0 @property def value(self) -> float: """ Overview: - Return the estimated value of the current root node. + Return the estimated value of the current node. + Returns: + - value (:obj:`float`): The estimated value of the current node. """ if self.visit_count == 0: return 0 @@ -170,8 +201,22 @@ def value(self) -> float: class Roots: + """ + Overview: + The class to process a batch of roots(Node instances) at the same time. + Interfaces: + ``__init__``, ``prepare``, ``prepare_no_noise``, ``clear``, ``get_trajectories``, \ + ``get_distributions``, ``get_values`` + """ def __init__(self, root_num: int, legal_actions_list: List) -> None: + """ + Overview: + Initializes an instance of the Roots class with the specified number of roots and legal actions. + Arguments: + - root_num (:obj:`int`): The number of roots. + - legal_actions_list(:obj:`List`): A list of the legal actions for each root. + """ self.num = root_num self.root_num = root_num self.legal_actions_list = legal_actions_list # list of list @@ -198,13 +243,13 @@ def prepare( ) -> None: """ Overview: - Expand the roots and add noises. + Expand the roots and add noises for exploration. Arguments: - - root_noise_weight: the exploration fraction of roots - - noises: the vector of noise add to the roots. - - value_prefixs: the vector of value prefixs of each root. - - policies: the vector of policy logits of each root. - - to_play_batch: the vector of the player side of each root. + - root_noise_weight (:obj:`float`): the exploration fraction of roots + - noises (:obj:`List[float]`): the vector of noise add to the roots. + - value_prefixs (:obj:`List[float]`): the vector of rewards of each root. + - policies (:obj:`List[List[float]]`): the vector of policy logits of each root. + - to_play(:obj:`List`): The vector of the player side of each root. """ for i in range(self.root_num): @@ -223,9 +268,9 @@ def prepare_no_noise(self, value_prefixs: List[float], policies: List[List[float Overview: Expand the roots without noise. 
Arguments: - - value_prefixs: the vector of value prefixs of each root. - - policies: the vector of policy logits of each root. - - to_play_batch: the vector of the player side of each root. + - value_prefixs (:obj:`List[float]`): the vector of value prefixs of each root. + - policies (:obj:`List[List[float]]`): the vector of policy logits of each root. + - to_play(:obj:`List`): The vector of the player side of each root. """ for i in range(self.root_num): if to_play in [-1, None]: @@ -236,14 +281,18 @@ def prepare_no_noise(self, value_prefixs: List[float], policies: List[List[float self.roots[i].visit_count += 1 def clear(self) -> None: + """ + Overview: + Clear all the roots in the list. + """ self.roots.clear() def get_trajectories(self) -> List[List[Union[int, float]]]: """ Overview: Find the current best trajectory starts from each root. - Outputs: - - traj: a vector of node index, which is the current best trajectory from each root. + Returns: + - traj (:obj:`List[List[Union[int, float]]]`): a vector of node index, which is the current best trajectory from each root. """ trajs = [] for i in range(self.root_num): @@ -253,9 +302,9 @@ def get_trajectories(self) -> List[List[Union[int, float]]]: def get_distributions(self) -> List[List[Union[int, float]]]: """ Overview: - Get the children distribution of each root. - Outputs: - - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). + Get the visit count distribution of child nodes for each root. + Returns: + - distribution (:obj:`List[List[Union[int, float]]]`): a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). """ distributions = [] for i in range(self.root_num): @@ -266,7 +315,9 @@ def get_distributions(self) -> List[List[Union[int, float]]]: def get_values(self) -> List[float]: """ Overview: - Return the estimated value of each root. + Get the estimated value of each root. + Returns: + - values (:obj:`List[float]`): The estimated value of each root. """ values = [] for i in range(self.root_num): @@ -275,8 +326,20 @@ def get_values(self) -> List[float]: class SearchResults: + """ + Overview: + The class to record the results of the simulations for the batch of roots. + Interfaces: + ``__init__``. + """ def __init__(self, num: int) -> None: + """ + Overview: + Initiaizes the attributes to be recorded. + Arguments: + -num (:obj:`int`): The number of search results(equal to ``batch_size``). + """ self.num = num self.nodes = [] self.search_paths = [] @@ -294,13 +357,13 @@ def select_child( Overview: Select the child node of the roots according to ucb scores. Arguments: - - root: the roots to select the child node. - - min_max_stats (:obj:`Class MinMaxStats`): a tool used to min-max normalize the score. - - pb_c_base (:obj:`Class Float`): constant c1 used in pUCT rule, typically 1.25. - - pb_c_int (:obj:`Class Float`): constant c2 used in pUCT rule, typically 19652. - - discount_factor (:obj:`Class Float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. - - mean_q (:obj:`Class Float`): the mean q value of the parent node. - - players (:obj:`Class Int`): the number of players. one/in self-play-mode board games. + - root(:obj:`Node`): The root to select the child node. + - min_max_stats (:obj:`MinMaxStats`): A tool used to min-max normalize the score. + - pb_c_base (:obj:`float`): Constant c1 used in pUCT rule, typically 1.25. 
+ - pb_c_int (:obj:`float`): Constant c2 used in pUCT rule, typically 19652. + - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. + - mean_q (:obj:`float`): The mean q value of the parent node. + - players (:obj:`int`): The number of players. In two-player games such as board games, the value need to be negatived when backpropating. Returns: - action (:obj:`Union[int, float]`): Choose the action with the highest ucb score. """ @@ -342,24 +405,27 @@ def compute_ucb_score( """ Overview: Compute the ucb score of the child. - Arguments: - - child: the child node to compute ucb score. - - min_max_stats: a tool used to min-max normalize the score. - - parent_mean_q: the mean q value of the parent node. - - is_reset: whether the value prefix needs to be reset. - - total_children_visit_counts: the total visit counts of the child nodes of the parent node. - - parent_value_prefix: the value prefix of parent node. - - pb_c_base: constants c2 in muzero. - - pb_c_init: constants c1 in muzero. - - disount_factor: the discount factor of reward. - - players: the number of players. - Outputs: - - ucb_value: the ucb score of the child. + Arguments: + - child (:obj:`Node`): the child node to compute ucb score. + - min_max_stats (:obj:`MinMaxStats`): a tool used to min-max normalize the score. + - parent_mean_q (:obj:`float`): the mean q value of the parent node. + - is_reset (:obj:`float`): whether the value prefix needs to be reset. + - total_children_visit_counts (:obj:`float`): the total visit counts of the child nodes of the parent node. + - parent_value_prefix(:obj:`float`): The value prefix of parent node. + - pb_c_base (:obj:`float`): constants c2 in muzero. + - pb_c_init (:obj:`float`): constants c1 in muzero. + - disount_factor (:obj:`float`): the discount factor of reward. + - players (:obj:`int`): the number of players. + Returns: + - ucb_value (:obj:`float`): the ucb score of the child. """ + # Compute the prior score. pb_c = math.log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init pb_c *= (math.sqrt(total_children_visit_counts) / (child.visit_count + 1)) prior_score = pb_c * child.prior + + # Compute the value score. if child.visit_count == 0: value_score = parent_mean_q else: @@ -371,11 +437,13 @@ def compute_ucb_score( elif players == 2: value_score = true_reward + discount_factor * (-child.value) + # Normalize the value score. value_score = min_max_stats.normalize(value_score) if value_score < 0: value_score = 0 if value_score > 1: value_score = 1 + ucb_score = prior_score + value_score return ucb_score @@ -392,19 +460,20 @@ def batch_traverse( ) -> Tuple[List[None], List[None], List[None], Union[list, int]]: """ Overview: - traverse, also called expansion. process a batch roots parallely. + traverse, also called expansion. Process a batch roots at once. Arguments: - - roots (:obj:`Any`): a batch of root nodes to be expanded. - - pb_c_base (:obj:`float`): constant c1 used in pUCT rule, typically 1.25. - - pb_c_init (:obj:`float`): constant c2 used in pUCT rule, typically 19652. + - roots (:obj:`Any`): A batch of root nodes to be expanded. + - pb_c_base (:obj:`float`): Constant c1 used in pUCT rule, typically 1.25. + - pb_c_init (:obj:`float`): Constant c2 used in pUCT rule, typically 19652. - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. 
- - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and training in board games, + - results (:obj:`SearchResults`): An instance to record the simulation results for all the roots in the batch. + - virtual_to_play (:obj:`list`): The to_play list used in self_play collecting and training in board games, `virtual` is to emphasize that actions are performed on an imaginary hidden state. Returns: - - latent_state_index_in_search_path (:obj:`list`): the list of x/first index of hidden state vector of the searched node, i.e. the search depth. - - latent_state_index_in_batch (:obj:`list`): the list of y/second index of hidden state vector of the searched node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. - - last_actions (:obj:`list`): the action performed by the previous node. - - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and trainin gin board games, + - latent_state_index_in_search_path (:obj:`list`): The list of x/first index of hidden state vector of the searched node, i.e. the search depth. + - latent_state_index_in_batch (:obj:`list`): The list of y/second index of hidden state vector of the searched node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. + - last_actions (:obj:`list`): The action performed by the previous node. + - virtual_to_play (:obj:`list`): The to_play list used in self_play collecting and trainin gin board games, `virtual` is to emphasize that actions are performed on an imaginary hidden state. """ parent_q = 0.0 @@ -480,11 +549,11 @@ def backpropagate( Overview: Update the value sum and visit count of nodes along the search path. Arguments: - - search_path: a vector of nodes on the search path. - - min_max_stats: a tool used to min-max normalize the q value. - - to_play: which player to play the game in the current node. - - value: the value to propagate along the search path. - - discount_factor: the discount factor of reward. + - search_path (:obj:`List[Node]`): a vector of nodes on the search path. + - min_max_stats (:obj:`MinMaxStats`): a tool used to min-max normalize the q value. + - to_play (:obj:`int`): which player to play the game in the current node. + - value (:obj:`float`): the value to propagate along the search path. + - discount_factor (:obj:`float`): the discount factor of reward. """ assert to_play is None or to_play in [-1, 1, 2], f'to_play is {to_play}!' if to_play is None or to_play == -1: @@ -553,17 +622,17 @@ def batch_backpropagate( ) -> None: """ Overview: - Backpropagation along the search path to update the attributes. + Update the value sum and visit count of nodes along the search paths for each root in the batch. Arguments: - - simulation_index (:obj:`Class Int`): The index of latent state of the leaf node in the search path. - - discount_factor (:obj:`Class Float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. - - value_prefixs (:obj:`Class List`): the value prefixs of nodes along the search path. - - values (:obj:`Class List`): the values to propagate along the search path. - - policies (:obj:`Class List`): the policy logits of nodes along the search path. - - min_max_stats_lst (:obj:`Class List[MinMaxStats]`): a tool used to min-max normalize the q value. - - results (:obj:`Class List`): the search results. 
- - is_reset_list (:obj:`Class List`): the vector of is_reset nodes along the search path, where is_reset represents for whether the parent value prefix needs to be reset. - - to_play (:obj:`Class List`): the batch of which player is playing on this node. + - simulation_index (:obj:`int`): The index of latent state of the leaf node in the search path. + - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. + - value_prefixs (:obj:`List`): the value prefixs of nodes along the search path. + - values (:obj:`List`): the values to propagate along the search path. + - policies (:obj:`List`): the policy logits of nodes along the search path. + - min_max_stats_lst (:obj:`List[MinMaxStats]`): a tool used to min-max normalize the q value. + - results (:obj:`List`): the search results. + - is_reset_list (:obj:`List`): the vector of is_reset nodes along the search path, where is_reset represents for whether the parent value prefix needs to be reset. + - to_play (:obj:`List`): the batch of which player is playing on this node. """ for i in range(results.num): # ****** expand the leaf node ****** diff --git a/lzero/mcts/ptree/ptree_mz.py b/lzero/mcts/ptree/ptree_mz.py index 00e672bc5..909d48e4f 100644 --- a/lzero/mcts/ptree/ptree_mz.py +++ b/lzero/mcts/ptree/ptree_mz.py @@ -13,10 +13,21 @@ class Node: """ - Overview: - The base class of Node for MuZero. - """ + Overview: + The Node class for MuZero. The basic functions of the node are implemented. + Interfaces: + ``__init__``, ``expand``, ``add_exploration_noise``, ``compute_mean_q``, ``get_trajectory``, \ + ``get_children_distribution``, ``get_child``, ``expanded``, ``value``. + """ def __init__(self, prior: float, legal_actions: List = None, action_space_size: int = 9) -> None: + """ + Overview: + Initializes a Node instance. + Arguments: + - prior (:obj:`float`): The prior probability of the node. + - legal_actions (:obj:`List`, optional): The list of legal actions of the node. Defaults to None. + - action_space_size (:obj:`int`, optional): The size of the action space. Defaults to 9. + """ self.prior = prior self.legal_actions = legal_actions self.action_space_size = action_space_size @@ -30,7 +41,6 @@ def __init__(self, prior: float, legal_actions: List = None, action_space_size: self.children_index = [] self.simulation_index = 0 self.batch_index = 0 - self.parent_value_prefix = 0 # only used in update_tree_q method def expand(self, to_play: int, simulation_index: int, batch_index: int, reward: float, policy_logits: List[float]) -> None: @@ -38,11 +48,11 @@ def expand(self, to_play: int, simulation_index: int, batch_index: int, reward: Overview: Expand the child nodes of the current node. Arguments: - - to_play (:obj:`Class int`): which player to play the game in the current node. - - simulation_index (:obj:`Class int`): the x/first index of hidden state vector of the current node, i.e. the search depth. - - batch_index (:obj:`Class int`): the y/second index of hidden state vector of the current node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. - - value_prefix: (:obj:`Class float`): the value prefix of the current node. - - policy_logits: (:obj:`Class List`): the policy logit of the child nodes. + - to_play (:obj:`int`): The player to play the game in the current node. + - simulation_index (:obj:`int`): The x/first index of the hidden state vector of the current node, i.e., the search depth. 
+ - batch_index (:obj:`int`): The y/second index of the hidden state vector of the current node, i.e., the index of the batch root node, its maximum is `batch_size`/`env_num`. + - reward (:obj:`float`): The reward associated with the current node. + - policy_logits (:obj:`List[float]`): The policy logits providing the priors for the child nodes. """ self.to_play = to_play if self.legal_actions is None: @@ -60,9 +70,10 @@ def expand(self, to_play: int, simulation_index: int, batch_index: int, reward: def add_exploration_noise(self, exploration_fraction: float, noises: List[float]) -> None: """ Overview: - add exploration noise to priors + Add exploration noise to the priors. Arguments: - - noises (:obj: list): length is len(self.legal_actions) + - exploration_fraction (:obj:`float`): The fraction of exploration noise to be added. + - noises (:obj:`List[float]`): The list of noises to be added to the priors. """ for i, a in enumerate(self.legal_actions): """ @@ -73,19 +84,22 @@ def add_exploration_noise(self, exploration_fraction: float, noises: List[float] prior = child.prior child.prior = prior * (1 - exploration_fraction) + noise * exploration_fraction - def compute_mean_q(self, is_root: int, parent_q: float, discount_factor: float) -> float: + def compute_mean_q(self, is_root: bool, parent_q: float, discount_factor: float) -> float: """ Overview: - Compute the mean q value of the current node. + Compute the mean of the action values of all legal actions. Arguments: - - is_root (:obj:`int`): whether the current node is a root node. - - parent_q (:obj:`float`): the q value of the parent node. - - discount_factor (:obj:`float`): the discount_factor of reward. + - is_root (:obj:`bool`): Whether the current node is a root node. + - parent_q (:obj:`float`): The q value of the parent node. + - discount_factor (:obj:`float`): The discount factor of the reward. + Returns: + - mean_q (:obj:`float`): The mean of the action values. """ total_unsigned_q = 0.0 total_visits = 0 for a in self.legal_actions: child = self.get_child(a) + # Only count the child nodes which have been visited. if child.visit_count > 0: true_reward = child.reward # TODO(pu): why only one step bootstrap? @@ -102,9 +116,9 @@ def compute_mean_q(self, is_root: int, parent_q: float, discount_factor: float) def get_trajectory(self) -> List[Union[int, float]]: """ Overview: - Find the current best trajectory starts from the current node. - Outputs: - - traj: a vector of node index, which is the current best trajectory from this node. + Find the current best trajectory starting from the current node. + Returns: + - traj (:obj:`List[Union[int, float]]`): A vector of node indices representing the current best trajectory. """ traj = [] node = self @@ -116,6 +130,13 @@ def get_trajectory(self) -> List[Union[int, float]]: return traj def get_children_distribution(self) -> List[Union[int, float]]: + """ + Overview: + Get the distribution of visit counts among the child nodes. + Returns: + - distribution (:obj:`List[Union[int, float]]` or :obj:`None`): The distribution of visit counts among the children nodes. \ + If the legal_actions list is empty, returns None. + """ if self.legal_actions == []: return None distribution = {a: 0 for a in self.legal_actions} @@ -130,7 +151,11 @@ def get_children_distribution(self) -> List[Union[int, float]]: def get_child(self, action: Union[int, float]) -> "Node": """ Overview: - get children node according to the input action. + Get the child node according to the input action. 
+ Arguments: + - action (:obj:`Union[int, float]`): The action for which the child node is to be retrieved. + Returns: + - child (:obj:`Node`): The child node corresponding to the input action. """ if not isinstance(action, np.int64): action = int(action) @@ -138,13 +163,21 @@ def get_child(self, action: Union[int, float]) -> "Node": @property def expanded(self) -> bool: + """ + Overview: + Check if the node has been expanded. + Returns: + - expanded (:obj:`bool`): True if the node has been expanded, False otherwise. + """ return len(self.children) > 0 @property def value(self) -> float: """ Overview: - Return the estimated value of the current root node. + Return the estimated value of the current node. + Returns: + - value (:obj:`float`): The estimated value of the current node. """ if self.visit_count == 0: return 0 @@ -153,8 +186,22 @@ def value(self) -> float: class Roots: + """ + Overview: + The class to process a batch of roots(Node instances) at the same time. + Interfaces: + ``__init__``, ``prepare``, ``prepare_no_noise``, ``clear``, ``get_trajectories``, \ + ``get_distributions``, ``get_values`` + """ def __init__(self, root_num: int, legal_actions_list: List) -> None: + """ + Overview: + Initializes an instance of the Roots class with the specified number of roots and legal actions. + Arguments: + - root_num (:obj:`int`): The number of roots. + - legal_actions_list(:obj:`List`): A list of the legal actions for each root. + """ self.num = root_num self.root_num = root_num self.legal_actions_list = legal_actions_list # list of list @@ -176,14 +223,13 @@ def prepare( ) -> None: """ Overview: - Expand the roots and add noises. + Expand the roots and add noises for exploration. Arguments: - - root_noise_weight: the exploration fraction of roots - - noises: the vector of noise add to the roots. - - rewards: the vector of rewards of each root. - - policies: the vector of policy logits of each root. - - to_play_batch: the vector of the player side of each root. - + - root_noise_weight (:obj:`float`): the exploration fraction of roots + - noises (:obj:`List[float]`): the vector of noise add to the roots. + - rewards (:obj:`List[float]`): the vector of rewards of each root. + - policies (:obj:`List[List[float]]`): the vector of policy logits of each root. + - to_play(:obj:`List`): The vector of the player side of each root. """ for i in range(self.root_num): if to_play is None: @@ -199,9 +245,9 @@ def prepare_no_noise(self, rewards: List[float], policies: List[List[float]], to Overview: Expand the roots without noise. Arguments: - - rewards: the vector of rewards of each root. - - policies: the vector of policy logits of each root. - - to_play_batch: the vector of the player side of each root. + - rewards (:obj:`List[float]`): the vector of rewards of each root. + - policies (:obj:`List[List[float]]`): the vector of policy logits of each root. + - to_play(:obj:`List`): The vector of the player side of each root. """ for i in range(self.root_num): if to_play is None: @@ -212,14 +258,18 @@ def prepare_no_noise(self, rewards: List[float], policies: List[List[float]], to self.roots[i].visit_count += 1 def clear(self) -> None: + """ + Overview: + Clear all the roots in the list. + """ self.roots.clear() def get_trajectories(self) -> List[List[Union[int, float]]]: """ Overview: Find the current best trajectory starts from each root. - Outputs: - - traj: a vector of node index, which is the current best trajectory from each root. 
+ Returns: + - traj (:obj:`List[List[Union[int, float]]]`): a vector of node index, which is the current best trajectory from each root. """ trajs = [] for i in range(self.root_num): @@ -229,9 +279,9 @@ def get_trajectories(self) -> List[List[Union[int, float]]]: def get_distributions(self) -> List[List[Union[int, float]]]: """ Overview: - Get the children distribution of each root. - Outputs: - - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). + Get the visit count distribution of child nodes for each root. + Returns: + - distribution (:obj:`List[List[Union[int, float]]]`): a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). """ distributions = [] for i in range(self.root_num): @@ -239,10 +289,12 @@ def get_distributions(self) -> List[List[Union[int, float]]]: return distributions - def get_values(self) -> float: + def get_values(self) -> List[float]: """ Overview: - Return the estimated value of each root. + Get the estimated value of each root. + Returns: + - values (:obj:`List[float]`): The estimated value of each root. """ values = [] for i in range(self.root_num): @@ -251,8 +303,20 @@ def get_values(self) -> float: class SearchResults: + """ + Overview: + The class to record the results of the simulations for the batch of roots. + Interfaces: + ``__init__``. + """ def __init__(self, num: int) -> None: + """ + Overview: + Initiaizes the attributes to be recorded. + Arguments: + -num (:obj:`int`): The number of search results(equal to ``batch_size``). + """ self.num = num self.nodes = [] self.search_paths = [] @@ -262,38 +326,6 @@ def __init__(self, num: int) -> None: self.search_lens = [] -def update_tree_q(root: Node, min_max_stats: MinMaxStats, discount_factor: float, players: int = 1) -> None: - """ - Overview: - Update the value sum and visit count of nodes along the search path. - Arguments: - - search_path: a vector of nodes on the search path. - - min_max_stats: a tool used to min-max normalize the q value. - - to_play: which player to play the game in the current node. - - value: the value to propagate along the search path. - - discount_factor: the discount factor of reward. - """ - node_stack = [] - node_stack.append(root) - while len(node_stack) > 0: - node = node_stack[-1] - node_stack.pop() - - if node != root: - true_reward = node.reward - if players == 1: - q_of_s_a = true_reward + discount_factor * node.value - elif players == 2: - q_of_s_a = true_reward + discount_factor * (-node.value) - - min_max_stats.update(q_of_s_a) - - for a in node.legal_actions: - child = node.get_child(a) - if child.expanded: - node_stack.append(child) - - def select_child( root: Node, min_max_stats: MinMaxStats, pb_c_base: float, pb_c_int: float, discount_factor: float, mean_q: float, players: int @@ -302,13 +334,13 @@ def select_child( Overview: Select the child node of the roots according to ucb scores. Arguments: - - root: the roots to select the child node. - - min_max_stats (:obj:`Class MinMaxStats`): a tool used to min-max normalize the score. - - pb_c_base (:obj:`Class Float`): constant c1 used in pUCT rule, typically 1.25. - - pb_c_int (:obj:`Class Float`): constant c2 used in pUCT rule, typically 19652. - - discount_factor (:obj:`Class Float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. - - mean_q (:obj:`Class Float`): the mean q value of the parent node. - - players (:obj:`Class Int`): the number of players. 
one/in self-play-mode board games. + - root(:obj:`Node`): The root to select the child node. + - min_max_stats (:obj:`MinMaxStats`): A tool used to min-max normalize the score. + - pb_c_base (:obj:`float`): Constant c1 used in pUCT rule, typically 1.25. + - pb_c_int (:obj:`float`): Constant c2 used in pUCT rule, typically 19652. + - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. + - mean_q (:obj:`float`): The mean q value of the parent node. + - players (:obj:`int`): The number of players. In two-player games such as board games, the value need to be negatived when backpropating. Returns: - action (:obj:`Union[int, float]`): Choose the action with the highest ucb score. """ @@ -347,25 +379,24 @@ def compute_ucb_score( """ Overview: Compute the ucb score of the child. - Arguments: - - child: the child node to compute ucb score. - - min_max_stats: a tool used to min-max normalize the score. - - parent_mean_q: the mean q value of the parent node. - - is_reset: whether the value prefix needs to be reset. - - total_children_visit_counts: the total visit counts of the child nodes of the parent node. - - parent_value_prefix: the value prefix of parent node. - - pb_c_base: constants c2 in muzero. - - pb_c_init: constants c1 in muzero. - - disount_factor: the discount factor of reward. - - players: the number of players. - - continuous_action_space: whether the action space is continuous in current env. - Outputs: - - ucb_value: the ucb score of the child. + Arguments: + - child (:obj:`Node`): the child node to compute ucb score. + - min_max_stats (:obj:`MinMaxStats`): a tool used to min-max normalize the score. + - parent_mean_q (:obj:`float`): the mean q value of the parent node. + - total_children_visit_counts (:obj:`float`): the total visit counts of the child nodes of the parent node. + - pb_c_base (:obj:`float`): constants c2 in muzero. + - pb_c_init (:obj:`float`): constants c1 in muzero. + - disount_factor (:obj:`float`): the discount factor of reward. + - players (:obj:`int`): the number of players. + Returns: + - ucb_value (:obj:`float`): the ucb score of the child. """ + # Compute the prior score. pb_c = math.log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init pb_c *= (math.sqrt(total_children_visit_counts) / (child.visit_count + 1)) - prior_score = pb_c * child.prior + + # Compute the value score. if child.visit_count == 0: value_score = parent_mean_q else: @@ -375,11 +406,13 @@ def compute_ucb_score( elif players == 2: value_score = true_reward + discount_factor * (-child.value) + # Normalize the value score. value_score = min_max_stats.normalize(value_score) if value_score < 0: value_score = 0 if value_score > 1: value_score = 1 + ucb_score = prior_score + value_score return ucb_score @@ -393,23 +426,23 @@ def batch_traverse( min_max_stats_lst: List[MinMaxStats], results: SearchResults, virtual_to_play: List, -) -> Tuple[List[None], List[None], List[None], list]: +) -> Tuple[List[None], List[None], List[None], Union[list, int]]: """ Overview: - traverse, also called expansion. process a batch roots parallelly. + traverse, also called expansion. Process a batch roots at once. Arguments: - - roots (:obj:`Any`): a batch of root nodes to be expanded. - - pb_c_base (:obj:`float`): constant c1 used in pUCT rule, typically 1.25. - - pb_c_init (:obj:`float`): constant c2 used in pUCT rule, typically 19652. + - roots (:obj:`Any`): A batch of root nodes to be expanded. 
+ - pb_c_base (:obj:`float`): Constant c1 used in pUCT rule, typically 1.25. + - pb_c_init (:obj:`float`): Constant c2 used in pUCT rule, typically 19652. - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. - - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and training in board games, + - results (:obj:`SearchResults`): An instance to record the simulation results for all the roots in the batch. + - virtual_to_play (:obj:`list`): The to_play list used in self_play collecting and training in board games, `virtual` is to emphasize that actions are performed on an imaginary hidden state. - - continuous_action_space: whether the action space is continuous in current env. Returns: - - latent_state_index_in_search_path (:obj:`list`): the list of x/first index of hidden state vector of the searched node, i.e. the search depth. - - latent_state_index_in_batch (:obj:`list`): the list of y/second index of hidden state vector of the searched node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. - - last_actions (:obj:`list`): the action performed by the previous node. - - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and trainin gin board games, + - latent_state_index_in_search_path (:obj:`list`): The list of x/first index of hidden state vector of the searched node, i.e. the search depth. + - latent_state_index_in_batch (:obj:`list`): The list of y/second index of hidden state vector of the searched node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. + - last_actions (:obj:`list`): The action performed by the previous node. + - virtual_to_play (:obj:`list`): The to_play list used in self_play collecting and trainin gin board games, `virtual` is to emphasize that actions are performed on an imaginary hidden state. """ parent_q = 0.0 @@ -481,11 +514,11 @@ def backpropagate( Overview: Update the value sum and visit count of nodes along the search path. Arguments: - - search_path: a vector of nodes on the search path. - - min_max_stats: a tool used to min-max normalize the q value. - - to_play: which player to play the game in the current node. - - value: the value to propagate along the search path. - - discount_factor: the discount factor of reward. + - search_path (:obj:`List[Node]`): a vector of nodes on the search path. + - min_max_stats (:obj:`MinMaxStats`): a tool used to min-max normalize the q value. + - to_play (:obj:`int`): which player to play the game in the current node. + - value (:obj:`float`): the value to propagate along the search path. + - discount_factor (:obj:`float`): the discount factor of reward. """ assert to_play is None or to_play in [-1, 1, 2], to_play if to_play is None or to_play == -1: @@ -540,16 +573,16 @@ def batch_backpropagate( ) -> None: """ Overview: - Backpropagation along the search path to update the attributes. + Update the value sum and visit count of nodes along the search paths for each root in the batch. Arguments: - - simulation_index (:obj:`Class Int`): The index of latent state of the leaf node in the search path. - - discount_factor (:obj:`Class Float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. - - value_prefixs (:obj:`Class List`): the value prefixs of nodes along the search path. - - values (:obj:`Class List`): the values to propagate along the search path. 
- - policies (:obj:`Class List`): the policy logits of nodes along the search path. - - min_max_stats_lst (:obj:`Class List[MinMaxStats]`): a tool used to min-max normalize the q value. - - results (:obj:`Class List`): the search results. - - to_play (:obj:`Class List`): the batch of which player is playing on this node. + - simulation_index (:obj:`int`): The index of latent state of the leaf node in the search path. + - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. + - value_prefixs (:obj:`List`): the value prefixs of nodes along the search path. + - values (:obj:`List`): the values to propagate along the search path. + - policies (:obj:`List`): the policy logits of nodes along the search path. + - min_max_stats_lst (:obj:`List[MinMaxStats]`): a tool used to min-max normalize the q value. + - results (:obj:`List`): the search results. + - to_play (:obj:`List`): the batch of which player is playing on this node. """ for i in range(results.num): # ****** expand the leaf node ****** diff --git a/lzero/mcts/ptree/ptree_sez.py b/lzero/mcts/ptree/ptree_sez.py index 4e9f0890b..3b9f04437 100644 --- a/lzero/mcts/ptree/ptree_sez.py +++ b/lzero/mcts/ptree/ptree_sez.py @@ -14,9 +14,12 @@ class Node: """ - Overview: - the node base class for Sampled EfficientZero. - """ + Overview: + The Node class for Sampled EfficientZero. The basic functions of the node are implemented. + Interfaces: + ``__init__``, ``expand``, ``add_exploration_noise``, ``compute_mean_q``, ``get_trajectory``, \ + ``get_children_distribution``, ``get_child``, ``expanded``, ``value``. + """ def __init__( self, @@ -26,6 +29,16 @@ def __init__( num_of_sampled_actions: int = 20, continuous_action_space: bool = False, ) -> None: + """ + Overview: + Initializes a Node instance. + Arguments: + - prior (:obj:`float`): The prior probability of the node. + - legal_actions (:obj:`List`, optional): The list of legal actions of the node. Defaults to None. + - action_space_size (:obj:`int`, optional): The size of the action space. Defaults to 9. + - num_of_sampled_actions (:obj:`int`): The number of sampled actions, i.e. K in the Sampled MuZero paper. + - continuous_action_space (:obj:'bool'): Whether the action space is continous in current env. + """ self.prior = prior self.mu = None self.sigma = None @@ -52,11 +65,11 @@ def expand( Overview: Expand the child nodes of the current node. Arguments: - - to_play (:obj:`Class int`): which player to play the game in the current node. - - simulation_index (:obj:`Class int`): the x/first index of hidden state vector of the current node, i.e. the search depth. - - batch_index (:obj:`Class int`): the y/second index of hidden state vector of the current node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. - - value_prefix: (:obj:`Class float`): the value prefix of the current node. - - policy_logits: (:obj:`Class List`): the policy logit of the child nodes. + - to_play (:obj:`int`): The player to play the game in the current node. + - simulation_index (:obj:`int`): The x/first index of the hidden state vector of the current node, i.e., the search depth. + - batch_index (:obj:`int`): The y/second index of the hidden state vector of the current node, i.e., the index of the batch root node, its maximum is `batch_size`/`env_num`. + - value_prefix: (:obj:`float`): the value prefix of the current node. 
+ - policy_logits (:obj:`List[float]`): The policy logits providing the priors for the child nodes. """ """ to varify ctree_efficientzero: @@ -154,10 +167,10 @@ def add_exploration_noise_to_sample_distribution( def add_exploration_noise(self, exploration_fraction: float, noises: List[float]) -> None: """ Overview: - Add a noise to the prior of the child nodes. + Add exploration noise to the priors of the child nodes.. Arguments: - - exploration_fraction: the fraction to add noise. - - noises (:obj: list): the vector of noises added to each child node. length is len(self.legal_actions) + - exploration_fraction (:obj:`float`): The fraction of exploration noise to be added. + - noises (:obj:`List[float]`): The list of noises to be added to the priors. Length is ``len(self.legal_actions)``. """ # ============================================================== # sampled related core code @@ -176,11 +189,13 @@ def add_exploration_noise(self, exploration_fraction: float, noises: List[float] def compute_mean_q(self, is_root: int, parent_q: float, discount_factor: float) -> float: """ Overview: - Compute the mean q value of the current node. + Compute the mean of the action values of all legal actions. Arguments: - - is_root (:obj:`int`): whether the current node is a root node. - - parent_q (:obj:`float`): the q value of the parent node. - - discount_factor (:obj:`float`): the discount_factor of reward. + - is_root (:obj:`bool`): Whether the current node is a root node. + - parent_q (:obj:`float`): The q value of the parent node. + - discount_factor (:obj:`float`): The discount factor of the reward. + Returns: + - mean_q (:obj:`float`): The mean of the action values. """ total_unsigned_q = 0.0 total_visits = 0 @@ -209,9 +224,9 @@ def print_out(self) -> None: def get_trajectory(self) -> List[Union[int, float]]: """ Overview: - Find the current best trajectory starts from the current node. - Outputs: - - traj: a vector of node index, which is the current best trajectory from this node. + Find the current best trajectory starting from the current node. + Returns: + - traj (:obj:`List[Union[int, float]]`): A vector of node indices representing the current best trajectory. """ traj = [] node = self @@ -223,6 +238,13 @@ def get_trajectory(self) -> List[Union[int, float]]: return traj def get_children_distribution(self) -> List[Union[int, float]]: + """ + Overview: + Get the distribution of visit counts among the child nodes. + Returns: + - distribution (:obj:`List[Union[int, float]]` or :obj:`None`): The distribution of visit counts among the children nodes. \ + If the legal_actions list is empty, returns None. + """ if self.legal_actions == []: return None # distribution = {a: 0 for a in self.legal_actions} @@ -238,7 +260,11 @@ def get_children_distribution(self) -> List[Union[int, float]]: def get_child(self, action: Union[int, float]) -> "Node": """ Overview: - get children node according to the input action. + Get the child node according to the input action. + Arguments: + - action (:obj:`Union[int, float]`): The action for which the child node is to be retrieved. + Returns: + - child (:obj:`Node`): The child node corresponding to the input action. """ if isinstance(action, Action): return self.children[action] @@ -248,13 +274,21 @@ def get_child(self, action: Union[int, float]) -> "Node": @property def expanded(self) -> bool: + """ + Overview: + Check if the node has been expanded. + Returns: + - expanded (:obj:`bool`): True if the node has been expanded, False otherwise. 
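For intuition, the prior mixing that ``add_exploration_noise`` documents amounts to a convex combination of the network prior and Dirichlet noise. A small standalone sketch (the function name and default alpha are illustrative, not part of this patch):

    import numpy as np

    def mix_exploration_noise(priors, exploration_fraction=0.25, alpha=0.3, rng=None):
        """Blend Dirichlet noise into root priors: p' = (1 - f) * p + f * noise."""
        rng = rng or np.random.default_rng()
        priors = np.asarray(priors, dtype=np.float64)
        noise = rng.dirichlet([alpha] * len(priors))
        return (1.0 - exploration_fraction) * priors + exploration_fraction * noise

    # mix_exploration_noise([0.5, 0.3, 0.2]) still sums to 1 and keeps every action explorable.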
+ """ return len(self.children) > 0 @property def value(self) -> float: """ Overview: - Return the estimated value of the current root node. + Return the estimated value of the current node. + Returns: + - value (:obj:`float`): The estimated value of the current node. """ if self.visit_count == 0: return 0 @@ -263,7 +297,13 @@ def value(self) -> float: class Roots: - + """ + Overview: + The class to process a batch of roots(Node instances) at the same time. + Interfaces: + ``__init__``, ``prepare``, ``prepare_no_noise``, ``clear``, ``get_trajectories``, \ + ``get_distributions``, ``get_values`` + """ def __init__( self, root_num: int, @@ -272,6 +312,16 @@ def __init__( num_of_sampled_actions: int = 20, continuous_action_space: bool = False, ) -> None: + """ + Overview: + Initializes an instance of the Roots class with the specified number of roots and legal actions. + Arguments: + - root_num (:obj:`int`): The number of roots. + - legal_actions_list(:obj:`List`): A list of the legal actions for each root. + - action_space_size (:obj:'int'): the size of action space of the current env. + - num_of_sampled_actions (:obj:'int'): The number of sampled actions, i.e. K in the Sampled MuZero paper. + - continuous_action_space (:obj:'bool'): whether the action space is continous in current env. + """ self.num = root_num self.root_num = root_num self.legal_actions_list = legal_actions_list # list of list @@ -328,13 +378,13 @@ def prepare( ) -> None: """ Overview: - Expand the roots and add noises. + Expand the roots and add noises for exploration. Arguments: - - root_noise_weight: the exploration fraction of roots - - noises: the vector of noise add to the roots. - - value_prefixs: the vector of value prefixs of each root. - - policies: the vector of policy logits of each root. - - to_play_batch: the vector of the player side of each root. + - root_noise_weight (:obj:`float`): the exploration fraction of roots + - noises (:obj:`List[float]`): the vector of noise add to the roots. + - value_prefixs (:obj:`List[float]`): the vector of rewards of each root. + - policies (:obj:`List[List[float]]`): the vector of policy logits of each root. + - to_play(:obj:`List`): The vector of the player side of each root. """ for i in range(self.root_num): @@ -351,9 +401,9 @@ def prepare_no_noise(self, value_prefixs: List[float], policies: List[List[float Overview: Expand the roots without noise. Arguments: - - value_prefixs: the vector of value prefixs of each root. - - policies: the vector of policy logits of each root. - - to_play_batch: the vector of the player side of each root. + - value_prefixs (:obj:`List[float]`): the vector of value prefixs of each root. + - policies (:obj:`List[List[float]]`): the vector of policy logits of each root. + - to_play(:obj:`List`): The vector of the player side of each root. """ for i in range(self.root_num): if to_play is None: @@ -364,14 +414,18 @@ def prepare_no_noise(self, value_prefixs: List[float], policies: List[List[float self.roots[i].visit_count += 1 def clear(self) -> None: + """ + Overview: + Clear all the roots in the list. + """ self.roots.clear() def get_trajectories(self) -> List[List[Union[int, float]]]: """ Overview: Find the current best trajectory starts from each root. - Outputs: - - traj: a vector of node index, which is the current best trajectory from each root. + Returns: + - traj (:obj:`List[List[Union[int, float]]]`): a vector of node index, which is the current best trajectory from each root. 
""" trajs = [] for i in range(self.root_num): @@ -381,9 +435,9 @@ def get_trajectories(self) -> List[List[Union[int, float]]]: def get_distributions(self) -> List[List[Union[int, float]]]: """ Overview: - Get the children distribution of each root. - Outputs: - - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). + Get the visit count distribution of child nodes for each root. + Returns: + - distribution (:obj:`List[List[Union[int, float]]]`): a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). """ distributions = [] for i in range(self.root_num): @@ -398,9 +452,9 @@ def get_sampled_actions(self) -> List[List[Union[int, float]]]: """ Overview: Get the sampled_actions of each root. - Outputs: - - python_sampled_actions: a vector of sampled_actions for each root, e.g. the size of original action space is 6, the K=3, - python_sampled_actions = [[1,3,0], [2,4,0], [5,4,1]]. + Returns: + - sampled_actions (:obj:`List[List[Union[int, float]]]`): a vector of sampled_actions for each root, \ + e.g. the size of original action space is 6, K=3, sampled_actions = [[1,3,0], [2,4,0], [5,4,1]]. """ # TODO(pu): root_sampled_actions bug in discere action space? sampled_actions = [] @@ -412,7 +466,9 @@ def get_sampled_actions(self) -> List[List[Union[int, float]]]: def get_values(self) -> float: """ Overview: - Return the estimated value of each root. + Get the estimated value of each root. + Returns: + - values (:obj:`List[float]`): The estimated value of each root. """ values = [] for i in range(self.root_num): @@ -421,8 +477,20 @@ def get_values(self) -> float: class SearchResults: + """ + Overview: + The class to record the results of the simulations for the batch of roots. + Interfaces: + ``__init__``. + """ - def __init__(self, num: int): + def __init__(self, num: int) -> None: + """ + Overview: + Initiaizes the attributes to be recorded. + Arguments: + -num (:obj:`int`): The number of search results(equal to ``batch_size``). + """ self.num = num self.nodes = [] self.search_paths = [] @@ -446,14 +514,14 @@ def select_child( Overview: Select the child node of the roots according to ucb scores. Arguments: - - root: the roots to select the child node. - - min_max_stats (:obj:`Class MinMaxStats`): a tool used to min-max normalize the score. - - pb_c_base (:obj:`Class Float`): constant c1 used in pUCT rule, typically 1.25. - - pb_c_int (:obj:`Class Float`): constant c2 used in pUCT rule, typically 19652. - - discount_factor (:obj:`Class Float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. - - mean_q (:obj:`Class Float`): the mean q value of the parent node. - - players (:obj:`Class Float`): the number of players. one/in self-play-mode board games. - - continuous_action_space: whether the action space is continous in current env. + - root(:obj:`Node`): The root to select the child node. + - min_max_stats (:obj:`MinMaxStats`): A tool used to min-max normalize the score. + - pb_c_base (:obj:`float`): Constant c1 used in pUCT rule, typically 1.25. + - pb_c_int (:obj:`float`): Constant c2 used in pUCT rule, typically 19652. + - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. + - mean_q (:obj:`float`): The mean q value of the parent node. + - players (:obj:`int`): The number of players. 
In two-player games such as board games, the value need to be negatived when backpropating. + - continuous_action_space (:obj: `bool`): Whether the action space is continous in current env. Returns: - action (:obj:`Union[int, float]`): Choose the action with the highest ucb score. """ @@ -504,20 +572,20 @@ def compute_ucb_score( """ Overview: Compute the ucb score of the child. - Arguments: - - child: the child node to compute ucb score. - - min_max_stats: a tool used to min-max normalize the score. - - parent_mean_q: the mean q value of the parent node. - - is_reset: whether the value prefix needs to be reset. - - total_children_visit_counts: the total visit counts of the child nodes of the parent node. - - parent_value_prefix: the value prefix of parent node. - - pb_c_base: constants c2 in muzero. - - pb_c_init: constants c1 in muzero. - - disount_factor: the discount factor of reward. - - players: the number of players. - - continuous_action_space: whether the action space is continous in current env. - Outputs: - - ucb_value: the ucb score of the child. + Arguments: + - child (:obj:`Node`): the child node to compute ucb score. + - min_max_stats (:obj:`MinMaxStats`): a tool used to min-max normalize the score. + - parent_mean_q (:obj:`float`): the mean q value of the parent node. + - is_reset (:obj:`float`): whether the value prefix needs to be reset. + - total_children_visit_counts (:obj:`float`): the total visit counts of the child nodes of the parent node. + - parent_value_prefix(:obj:`float`): The value prefix of parent node. + - pb_c_base (:obj:`float`): constants c2 in muzero. + - pb_c_init (:obj:`float`): constants c1 in muzero. + - disount_factor (:obj:`float`): the discount factor of reward. + - players (:obj:`int`): the number of players. + - continuous_action_space (:obj: `bool`): Whether the action space is continous in current env. + Returns: + - ucb_value (:obj:`float`): the ucb score of the child. """ assert total_children_visit_counts == parent.visit_count pb_c = math.log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init @@ -557,11 +625,13 @@ def compute_ucb_score( elif players == 2: value_score = true_reward + discount_factor * (-child.value) + # Normalize the value score. value_score = min_max_stats.normalize(value_score) if value_score < 0: value_score = 0 if value_score > 1: value_score = 1 + ucb_score = prior_score + value_score return ucb_score @@ -572,27 +642,28 @@ def batch_traverse( pb_c_base: float, pb_c_init: float, discount_factor: float, - min_max_stats_lst, + min_max_stats_lst: List[MinMaxStats], results: SearchResults, virtual_to_play: List, continuous_action_space: bool = False, ) -> Tuple[List[int], List[int], List[Union[int, float]], List]: """ Overview: - traverse, also called expansion. process a batch roots parallely. + traverse, also called expansion. Process a batch roots at once. Arguments: - - roots (:obj:`Any`): a batch of root nodes to be expanded. - - pb_c_base (:obj:`float`): constant c1 used in pUCT rule, typically 1.25. - - pb_c_init (:obj:`float`): constant c2 used in pUCT rule, typically 19652. + - roots (:obj:`Any`): A batch of root nodes to be expanded. + - pb_c_base (:obj:`float`): Constant c1 used in pUCT rule, typically 1.25. + - pb_c_init (:obj:`float`): Constant c2 used in pUCT rule, typically 19652. - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. 
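The score computed here follows the standard pUCT rule with pb_c_base = 19652 and pb_c_init = 1.25; stripped of the value-prefix bookkeeping, it reduces to the following simplified sketch (not the function in this patch):

    import math

    def puct_score(prior, normalized_value, child_visit, parent_visit,
                   pb_c_base=19652, pb_c_init=1.25):
        """prior_score + value_score, with the value already min-max normalized."""
        pb_c = math.log((parent_visit + pb_c_base + 1) / pb_c_base) + pb_c_init
        pb_c *= math.sqrt(parent_visit) / (child_visit + 1)
        prior_score = pb_c * prior
        value_score = min(max(normalized_value, 0.0), 1.0)  # clipped to [0, 1] as above
        return prior_score + value_score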
- - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and training in board games, + - results (:obj:`SearchResults`): An instance to record the simulation results for all the roots in the batch. + - virtual_to_play (:obj:`list`): The to_play list used in self_play collecting and training in board games, `virtual` is to emphasize that actions are performed on an imaginary hidden state. - continuous_action_space: whether the action space is continous in current env. Returns: - - latent_state_index_in_search_path (:obj:`list`): the list of x/first index of hidden state vector of the searched node, i.e. the search depth. - - latent_state_index_in_batch (:obj:`list`): the list of y/second index of hidden state vector of the searched node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. - - last_actions (:obj:`list`): the action performed by the previous node. - - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and trainin gin board games, + - latent_state_index_in_search_path (:obj:`list`): The list of x/first index of hidden state vector of the searched node, i.e. the search depth. + - latent_state_index_in_batch (:obj:`list`): The list of y/second index of hidden state vector of the searched node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. + - last_actions (:obj:`list`): The action performed by the previous node. + - virtual_to_play (:obj:`list`): The to_play list used in self_play collecting and trainin gin board games, `virtual` is to emphasize that actions are performed on an imaginary hidden state. """ parent_q = 0.0 @@ -666,13 +737,13 @@ def backpropagate( Overview: Update the value sum and visit count of nodes along the search path. Arguments: - - search_path: a vector of nodes on the search path. - - min_max_stats: a tool used to min-max normalize the q value. - - to_play: which player to play the game in the current node. - - value: the value to propagate along the search path. - - discount_factor: the discount factor of reward. + - search_path (:obj:`List[Node]`): a vector of nodes on the search path. + - min_max_stats (:obj:`MinMaxStats`): a tool used to min-max normalize the q value. + - to_play (:obj:`int`): which player to play the game in the current node. + - value (:obj:`float`): the value to propagate along the search path. + - discount_factor (:obj:`float`): the discount factor of reward. """ - assert to_play is None or to_play in [-1, 1, 2], to_play + assert to_play is None or to_play in [-1, 1, 2], f'to_play is {to_play}!' if to_play is None or to_play == -1: # for play-with-bot-mode bootstrap_value = value @@ -738,21 +809,21 @@ def batch_backpropagate( min_max_stats_lst: List[MinMaxStats], results: SearchResults, is_reset_list: List, - to_play: list = None + to_play: list = None, ) -> None: """ Overview: - Backpropagation along the search path to update the attributes. + Update the value sum and visit count of nodes along the search paths for each root in the batch. Arguments: - - simulation_index (:obj:`Class Int`): The index of latent state of the leaf node in the search path. - - discount_factor (:obj:`Class Float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. - - value_prefixs (:obj:`Class List`): the value prefixs of nodes along the search path. - - values (:obj:`Class List`): the values to propagate along the search path. 
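The play-with-bot branch of ``backpropagate`` described here boils down to a single pass from leaf to root that accumulates a discounted bootstrap value. A self-contained simplification (using a toy node type and a plain per-node reward instead of the value prefix handled by the real code):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class _NodeSketch:
        reward: float = 0.0
        value_sum: float = 0.0
        visit_count: int = 0

        @property
        def value(self) -> float:
            return self.value_sum / self.visit_count if self.visit_count else 0.0

    def backpropagate_sketch(search_path: List[_NodeSketch], leaf_value: float, discount: float) -> None:
        """Update value_sum/visit_count along the path, bootstrapping with the discounted value."""
        bootstrap_value = leaf_value
        for node in reversed(search_path):
            node.value_sum += bootstrap_value
            node.visit_count += 1
            bootstrap_value = node.reward + discount * bootstrap_value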
- - policies (:obj:`Class List`): the policy logits of nodes along the search path. - - min_max_stats_lst (:obj:`Class List[MinMaxStats]`): a tool used to min-max normalize the q value. - - results (:obj:`Class List`): the search results. - - is_reset_list (:obj:`Class List`): the vector of is_reset nodes along the search path, where is_reset represents for whether the parent value prefix needs to be reset. - - to_play (:obj:`Class List`): the batch of which player is playing on this node. + - simulation_index (:obj:`int`): The index of latent state of the leaf node in the search path. + - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. + - value_prefixs (:obj:`List`): the value prefixs of nodes along the search path. + - values (:obj:`List`): the values to propagate along the search path. + - policies (:obj:`List`): the policy logits of nodes along the search path. + - min_max_stats_lst (:obj:`List[MinMaxStats]`): a tool used to min-max normalize the q value. + - results (:obj:`List`): the search results. + - is_reset_list (:obj:`List`): the vector of is_reset nodes along the search path, where is_reset represents for whether the parent value prefix needs to be reset. + - to_play (:obj:`List`): the batch of which player is playing on this node. """ for i in range(results.num): # ****** expand the leaf node ****** @@ -779,30 +850,26 @@ def batch_backpropagate( class Action: """ - Class that represents an action of a game. - - Attributes: - value (Union[int, np.ndarray]): The value of the action. Can be either an integer or a numpy array. + Overview: + Class that represents an action of the game. """ - def __init__(self, value: Union[int, np.ndarray]) -> None: """ - Initializes the Action with the given value. - - Args: - value (Union[int, np.ndarray]): The value of the action. + Overview: + Initializes the Action with the given value. + Arguments: + - value (:obj:`Union[int, np.ndarray]`): The value of the action. Can be either an integer or a numpy array. """ self.value = value def __hash__(self) -> int: """ - Returns a hash of the Action's value. - - If the value is a numpy array, it is flattened to a tuple and then hashed. - If the value is a single integer, it is hashed directly. - + Overview: + Returns a hash of the Action's value. \ + If the value is a numpy array, it is flattened to a tuple and then hashed. \ + If the value is a single integer, it is hashed directly. Returns: - int: The hash of the Action's value. + hash (:obj:`int`): The hash of the Action's value. """ if isinstance(self.value, np.ndarray): if self.value.ndim == 0: @@ -814,16 +881,14 @@ def __hash__(self) -> int: def __eq__(self, other: "Action") -> bool: """ - Determines if this Action is equal to another Action. - - If both values are numpy arrays, they are compared element-wise. - Otherwise, they are compared directly. - - Args: - other (Action): The Action to compare with. - + Overview: + Determines if this Action is equal to another Action. \ + If both values are numpy arrays, they are compared element-wise. \ + Otherwise, they are compared directly. + Arguments: + - other (:obj:`Action`): The Action to compare with. Returns: - bool: True if the two Actions are equal, False otherwise. + - bool (:obj:`bool`): True if the two Actions are equal, False otherwise. 
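The custom ``__hash__``/``__eq__`` shown here exist so that ``Action`` objects wrapping numpy arrays can serve as dictionary keys for ``children`` in the sampled (possibly continuous) action setting. A small check, assuming the ``Action`` class exactly as defined in this file:

    import numpy as np
    from lzero.mcts.ptree.ptree_sez import Action

    a1 = Action(np.array([0.1, -0.3]))
    a2 = Action(np.array([0.1, -0.3]))
    children = {a1: "child node"}

    assert a1 == a2                 # element-wise comparison via np.array_equal
    assert hash(a1) == hash(a2)     # arrays are flattened to tuples before hashing
    print(children[a2])             # -> "child node": equal actions index the same child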
""" if isinstance(self.value, np.ndarray) and isinstance(other.value, np.ndarray): return np.array_equal(self.value, other.value) @@ -832,21 +897,20 @@ def __eq__(self, other: "Action") -> bool: def __gt__(self, other: "Action") -> bool: """ - Determines if this Action's value is greater than another Action's value. - - Args: - other (Action): The Action to compare with. - + Overview: + Determines if this Action's value is greater than another Action's value. + Arguments: + - other (:obj:`Action`): The Action to compare with. Returns: - bool: True if this Action's value is greater, False otherwise. + - bool (:obj:`bool`): True if the two Actions are equal, False otherwise. """ return self.value > other.value def __repr__(self) -> str: """ - Returns a string representation of this Action. - + Overview: + Returns a string representation of this Action. Returns: - str: A string representation of the Action's value. + - str (:obj:`str`): A string representation of the Action's value. """ return str(self.value) diff --git a/lzero/mcts/ptree/ptree_stochastic_mz.py b/lzero/mcts/ptree/ptree_stochastic_mz.py index 4384f2ab6..7dee5d383 100644 --- a/lzero/mcts/ptree/ptree_stochastic_mz.py +++ b/lzero/mcts/ptree/ptree_stochastic_mz.py @@ -13,12 +13,23 @@ class Node: """ - Overview: - the node base class for Stochastic MuZero. - Arguments: - """ + Overview: + The Node class for Stochastic MuZero. The basic functions of the node are implemented. + Interfaces: + ``__init__``, ``expand``, ``add_exploration_noise``, ``compute_mean_q``, ``get_trajectory``, \ + ``get_children_distribution``, ``get_child``, ``expanded``, ``value``. + """ def __init__(self, prior: float, legal_actions: List = None, action_space_size: int = 9, is_chance: bool = False, chance_space_size: int = 2) -> None: + """ + Overview: + Initializes a Node instance. + Arguments: + - prior (:obj:`float`): The prior probability of the node. + - legal_actions (:obj:`List`, optional): The list of legal actions of the node. Defaults to None. + - action_space_size (:obj:`int`, optional): The size of the action space. Defaults to 9. + - is_chance (:obj:`bool`) Whether the node is a chance node. + """ self.prior = prior self.legal_actions = legal_actions self.action_space_size = action_space_size @@ -46,11 +57,11 @@ def expand( Overview: Expand the child nodes of the current node. Arguments: - - to_play (:obj:`Class int`): which player to play the game in the current node. - - latent_state_index_in_search_path (:obj:`Class int`): the x/first index of latent state vector of the current node, i.e. the search depth. - - latent_state_index_in_batch (:obj:`Class int`): the y/second index of latent state vector of the current node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. - - value_prefix: (:obj:`Class float`): the value prefix of the current node. - - policy_logits: (:obj:`Class List`): the policy logit of the child nodes. + - to_play (:obj:`int`): which player to play the game in the current node. + - latent_state_index_in_search_path (:obj:`int`): the x/first index of latent state vector of the current node, i.e. the search depth. + - latent_state_index_in_batch (:obj:`int`): the y/second index of latent state vector of the current node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. + - reward: (:obj:`float`): the value prefix of the current node. + - policy_logits: (:obj:`List`): the policy logit of the child nodes. 
""" self.to_play = to_play self.reward = reward @@ -80,9 +91,10 @@ def expand( def add_exploration_noise(self, exploration_fraction: float, noises: List[float]) -> None: """ Overview: - add exploration noise to priors + Add exploration noise to the priors. Arguments: - - noises (:obj: list): length is len(self.legal_actions) + - exploration_fraction (:obj:`float`): The fraction of exploration noise to be added. + - noises (:obj:`List[float]`): The list of noises to be added to the priors. """ for i, a in enumerate(self.legal_actions): """ @@ -96,14 +108,16 @@ def add_exploration_noise(self, exploration_fraction: float, noises: List[float] prior = child.prior child.prior = prior * (1 - exploration_fraction) + noise * exploration_fraction - def compute_mean_q(self, is_root: int, parent_q: float, discount_factor: float) -> float: + def compute_mean_q(self, is_root: bool, parent_q: float, discount_factor: float) -> float: """ Overview: - Compute the mean q value of the current node. + Compute the mean of the action values of all legal actions. Arguments: - - is_root (:obj:`int`): whether the current node is a root node. - - parent_q (:obj:`float`): the q value of the parent node. - - discount_factor (:obj:`float`): the discount_factor of reward. + - is_root (:obj:`bool`): Whether the current node is a root node. + - parent_q (:obj:`float`): The q value of the parent node. + - discount_factor (:obj:`float`): The discount factor of the reward. + Returns: + - mean_q (:obj:`float`): The mean of the action values. """ total_unsigned_q = 0.0 total_visits = 0 @@ -126,9 +140,9 @@ def compute_mean_q(self, is_root: int, parent_q: float, discount_factor: float) def get_trajectory(self) -> List[Union[int, float]]: """ Overview: - Find the current best trajectory starts from the current node. - Outputs: - - traj: a vector of node index, which is the current best trajectory from this node. + Find the current best trajectory starting from the current node. + Returns: + - traj (:obj:`List[Union[int, float]]`): A vector of node indices representing the current best trajectory. """ # TODO(pu): best action traj = [] @@ -142,6 +156,13 @@ def get_trajectory(self) -> List[Union[int, float]]: return traj def get_children_distribution(self) -> List[Union[int, float]]: + """ + Overview: + Get the distribution of visit counts among the child nodes. + Returns: + - distribution (:obj:`List[Union[int, float]]` or :obj:`None`): The distribution of visit counts among the children nodes. \ + If the legal_actions list is empty, returns None. + """ if self.legal_actions == []: return None distribution = {a: 0 for a in self.legal_actions} @@ -156,7 +177,11 @@ def get_children_distribution(self) -> List[Union[int, float]]: def get_child(self, action: Union[int, float]) -> "Node": """ Overview: - get children node according to the input action. + Get the child node according to the input action. + Arguments: + - action (:obj:`Union[int, float]`): The action for which the child node is to be retrieved. + Returns: + - child (:obj:`Node`): The child node corresponding to the input action. """ if not isinstance(action, np.int64): action = int(action) @@ -164,13 +189,21 @@ def get_child(self, action: Union[int, float]) -> "Node": @property def expanded(self) -> bool: + """ + Overview: + Check if the node has been expanded. + Returns: + - expanded (:obj:`bool`): True if the node has been expanded, False otherwise. 
+ """ return len(self.children) > 0 @property def value(self) -> float: """ Overview: - Return the estimated value of the current root node. + Return the estimated value of the current node. + Returns: + - value (:obj:`float`): The estimated value of the current node. """ if self.visit_count == 0: return 0 @@ -179,8 +212,22 @@ def value(self) -> float: class Roots: + """ + Overview: + The class to process a batch of roots(Node instances) at the same time. + Interfaces: + ``__init__``, ``prepare``, ``prepare_no_noise``, ``clear``, ``get_trajectories``, \ + ``get_distributions``, ``get_values`` + """ def __init__(self, root_num: int, legal_actions_list: List) -> None: + """ + Overview: + Initializes an instance of the Roots class with the specified number of roots and legal actions. + Arguments: + - root_num (:obj:`int`): The number of roots. + - legal_actions_list(:obj:`List`): A list of the legal actions for each root. + """ self.num = root_num self.root_num = root_num self.legal_actions_list = legal_actions_list # list of list @@ -203,13 +250,13 @@ def prepare( ) -> None: """ Overview: - Expand the roots and add noises. + Expand the roots and add noises for exploration. Arguments: - - root_noise_weight: the exploration fraction of roots - - noises: the vector of noise add to the roots. - - rewards: the vector of rewards of each root. - - policies: the vector of policy logits of each root. - - to_play_batch: the vector of the player side of each root. + - root_noise_weight (:obj:`float`): the exploration fraction of roots + - noises (:obj:`List[float]`): the vector of noise add to the roots. + - rewards (:obj:`List[float]`): the vector of rewards of each root. + - policies (:obj:`List[List[float]]`): the vector of policy logits of each root. + - to_play(:obj:`List`): The vector of the player side of each root. """ for i in range(self.root_num): # to_play: int, latent_state_index_in_search_path: int, latent_state_index_in_batch: int, @@ -227,9 +274,9 @@ def prepare_no_noise(self, rewards: List[float], policies: List[List[float]], to Overview: Expand the roots without noise. Arguments: - - rewards: the vector of rewards of each root. - - policies: the vector of policy logits of each root. - - to_play_batch: the vector of the player side of each root. + - rewards (:obj:`List[float]`): the vector of rewards of each root. + - policies (:obj:`List[List[float]]`): the vector of policy logits of each root. + - to_play(:obj:`List`): The vector of the player side of each root. """ for i in range(self.root_num): if to_play is None: @@ -240,14 +287,18 @@ def prepare_no_noise(self, rewards: List[float], policies: List[List[float]], to self.roots[i].visit_count += 1 def clear(self) -> None: + """ + Overview: + Clear all the roots in the list. + """ self.roots.clear() def get_trajectories(self) -> List[List[Union[int, float]]]: """ Overview: Find the current best trajectory starts from each root. - Outputs: - - traj: a vector of node index, which is the current best trajectory from each root. + Returns: + - traj (:obj:`List[List[Union[int, float]]]`): a vector of node index, which is the current best trajectory from each root. """ trajs = [] for i in range(self.root_num): @@ -257,9 +308,9 @@ def get_trajectories(self) -> List[List[Union[int, float]]]: def get_distributions(self) -> List[List[Union[int, float]]]: """ Overview: - Get the children distribution of each root. - Outputs: - - distribution: a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). 
+ Get the visit count distribution of child nodes for each root. + Returns: + - distribution (:obj:`List[List[Union[int, float]]]`): a vector of distribution of child nodes in the format of visit count (i.e. [1,3,0,2,5]). """ distributions = [] for i in range(self.root_num): @@ -267,10 +318,12 @@ def get_distributions(self) -> List[List[Union[int, float]]]: return distributions - def get_values(self) -> float: + def get_values(self) -> List[float]: """ Overview: - Return the estimated value of each root. + Get the estimated value of each root. + Returns: + - values (:obj:`List[float]`): The estimated value of each root. """ values = [] for i in range(self.root_num): @@ -279,8 +332,20 @@ def get_values(self) -> float: class SearchResults: + """ + Overview: + The class to record the results of the simulations for the batch of roots. + Interfaces: + ``__init__``. + """ def __init__(self, num: int) -> None: + """ + Overview: + Initiaizes the attributes to be recorded. + Arguments: + -num (:obj:`int`): The number of search results(equal to ``batch_size``). + """ self.num = num self.nodes = [] self.search_paths = [] @@ -289,39 +354,6 @@ def __init__(self, num: int) -> None: self.last_actions = [] self.search_lens = [] - -def update_tree_q(root: Node, min_max_stats: MinMaxStats, discount_factor: float, players: int = 1) -> None: - """ - Overview: - Update the value sum and visit count of nodes along the search path. - Arguments: - - search_path: a vector of nodes on the search path. - - min_max_stats: a tool used to min-max normalize the q value. - - to_play: which player to play the game in the current node. - - value: the value to propagate along the search path. - - discount_factor: the discount factor of reward. - """ - node_stack = [] - node_stack.append(root) - while len(node_stack) > 0: - node = node_stack[-1] - node_stack.pop() - - if node != root: - true_reward = node.reward - if players == 1: - q_of_s_a = true_reward + discount_factor * node.value - elif players == 2: - q_of_s_a = true_reward + discount_factor * (-node.value) - - min_max_stats.update(q_of_s_a) - - for a in node.legal_actions: - child = node.get_child(a) - if child.expanded: - node_stack.append(child) - - def select_child( node: Node, min_max_stats: MinMaxStats, pb_c_base: float, pb_c_int: float, discount_factor: float, mean_q: float, players: int @@ -330,13 +362,13 @@ def select_child( Overview: Select the child node of the roots according to ucb scores. Arguments: - - node: the node to select the child node. - - min_max_stats (:obj:`Class MinMaxStats`): a tool used to min-max normalize the score. - - pb_c_base (:obj:`Class Float`): constant c1 used in pUCT rule, typically 1.25. - - pb_c_int (:obj:`Class Float`): constant c2 used in pUCT rule, typically 19652. - - discount_factor (:obj:`Class Float`): discount_factor factor used i calculating bootstrapped value, if env is board_games, we set discount_factor=1. - - mean_q (:obj:`Class Float`): the mean q value of the parent node. - - players (:obj:`Class Int`): the number of players. one/two_player mode board games. + - node(:obj:`Node`): The root to select the child node. + - min_max_stats (:obj:`MinMaxStats`): A tool used to min-max normalize the score. + - pb_c_base (:obj:`float`): Constant c1 used in pUCT rule, typically 1.25. + - pb_c_int (:obj:`float`): Constant c2 used in pUCT rule, typically 19652. + - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. 
+ - mean_q (:obj:`float`): The mean q value of the parent node. + - players (:obj:`int`): The number of players. In two-player games such as board games, the value need to be negatived when backpropating. Returns: - action (:obj:`Union[int, float]`): Choose the action with the highest ucb score. """ @@ -387,25 +419,24 @@ def compute_ucb_score( """ Overview: Compute the ucb score of the child. - Arguments: - - child: the child node to compute ucb score. - - min_max_stats: a tool used to min-max normalize the score. - - parent_mean_q: the mean q value of the parent node. - - is_reset: whether the value prefix needs to be reset. - - total_children_visit_counts: the total visit counts of the child nodes of the parent node. - - parent_value_prefix: the value prefix of parent node. - - pb_c_base: constants c2 in muzero. - - pb_c_init: constants c1 in muzero. - - disount_factor: the discount factor of reward. - - players: the number of players. - - continuous_action_space: whether the action space is continous in current env. - Outputs: - - ucb_value: the ucb score of the child. + Arguments: + - child (:obj:`Node`): the child node to compute ucb score. + - min_max_stats (:obj:`MinMaxStats`): a tool used to min-max normalize the score. + - parent_mean_q (:obj:`float`): the mean q value of the parent node. + - total_children_visit_counts (:obj:`float`): the total visit counts of the child nodes of the parent node. + - pb_c_base (:obj:`float`): constants c2 in muzero. + - pb_c_init (:obj:`float`): constants c1 in muzero. + - disount_factor (:obj:`float`): the discount factor of reward. + - players (:obj:`int`): the number of players. + Returns: + - ucb_value (:obj:`float`): the ucb score of the child. """ + # Compute the prior score. pb_c = math.log((total_children_visit_counts + pb_c_base + 1) / pb_c_base) + pb_c_init pb_c *= (math.sqrt(total_children_visit_counts) / (child.visit_count + 1)) - prior_score = pb_c * child.prior + + # Compute the value score. if child.visit_count == 0: value_score = parent_mean_q else: @@ -415,11 +446,13 @@ def compute_ucb_score( elif players == 2: value_score = true_reward + discount_factor * (-child.value) + # Normalize the value score. value_score = min_max_stats.normalize(value_score) if value_score < 0: value_score = 0 if value_score > 1: value_score = 1 + ucb_score = prior_score + value_score return ucb_score @@ -437,20 +470,20 @@ def batch_traverse( """ Overview: - traverse, also called selection. process a batch roots parallely. + traverse, also called expansion. Process a batch roots at once. Arguments: - - roots (:obj:`Any`): a batch of root nodes to be expanded. - - pb_c_base (:obj:`float`): constant c1 used in pUCT rule, typically 1.25. - - pb_c_init (:obj:`float`): constant c2 used in pUCT rule, typically 19652. - - discount_factor (:obj:`float`): discount_factor factor used i calculating bootstrapped value, if env is board_games, we set discount_factor=1. - - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and training in board games, + - roots (:obj:`Any`): A batch of root nodes to be expanded. + - pb_c_base (:obj:`float`): Constant c1 used in pUCT rule, typically 1.25. + - pb_c_init (:obj:`float`): Constant c2 used in pUCT rule, typically 19652. + - discount_factor (:obj:`float`): The discount factor used in calculating bootstrapped value, if env is board_games, we set discount_factor=1. + - results (:obj:`SearchResults`): An instance to record the simulation results for all the roots in the batch. 
+ - virtual_to_play (:obj:`list`): The to_play list used in self_play collecting and training in board games, `virtual` is to emphasize that actions are performed on an imaginary hidden state. - - continuous_action_space: whether the action space is continous in current env. Returns: - - latent_state_index_in_search_path (:obj:`list`): the list of x/first index of latent state vector of the searched node, i.e. the search depth. - - latent_state_index_in_batch (:obj:`list`): the list of y/second index of latent state vector of the searched node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. - - last_actions (:obj:`list`): the action performed by the previous node. - - virtual_to_play (:obj:`list`): the to_play list used in self_play collecting and trainin gin board games, + - latent_state_index_in_search_path (:obj:`list`): The list of x/first index of hidden state vector of the searched node, i.e. the search depth. + - latent_state_index_in_batch (:obj:`list`): The list of y/second index of hidden state vector of the searched node, i.e. the index of batch root node, its maximum is ``batch_size``/``env_num``. + - last_actions (:obj:`list`): The action performed by the previous node. + - virtual_to_play (:obj:`list`): The to_play list used in self_play collecting and trainin gin board games, `virtual` is to emphasize that actions are performed on an imaginary hidden state. """ parent_q = 0.0 @@ -522,13 +555,13 @@ def backpropagate( Overview: Update the value sum and visit count of nodes along the search path. Arguments: - - search_path: a vector of nodes on the search path. - - min_max_stats: a tool used to min-max normalize the q value. - - to_play: which player to play the game in the current node. - - value: the value to propagate along the search path. - - discount_factor: the discount factor of reward. + - search_path (:obj:`List[Node]`): a vector of nodes on the search path. + - min_max_stats (:obj:`MinMaxStats`): a tool used to min-max normalize the q value. + - to_play (:obj:`int`): which player to play the game in the current node. + - value (:obj:`float`): the value to propagate along the search path. + - discount_factor (:obj:`float`): the discount factor of reward. """ - assert to_play is None or to_play in [-1, 1, 2] + assert to_play is None or to_play in [-1, 1, 2], to_play if to_play is None or to_play == -1: # for play-with-bot mode bootstrap_value = value @@ -578,17 +611,17 @@ def batch_backpropagate( ) -> None: """ Overview: - Backpropagation along the search path to update the attributes. + Update the value sum and visit count of nodes along the search paths for each root in the batch. Arguments: - - latent_state_index_in_search_path (:obj:`Class Int`): the index of latent state vector. - - discount_factor (:obj:`Class Float`): discount_factor factor used i calculating bootstrapped value, + - latent_state_index_in_search_path (:obj:`int`): the index of latent state vector. + - discount_factor (:obj:`float`): discount_factor factor used i calculating bootstrapped value, if env is board_games, we set discount_factor=1. - - value_prefixs (:obj:`Class List`): the value prefixs of nodes along the search path. - - values (:obj:`Class List`): the values to propagate along the search path. - - policies (:obj:`Class List`): the policy logits of nodes along the search path. - - min_max_stats_lst (:obj:`Class List[MinMaxStats]`): a tool used to min-max normalize the q value. - - results (:obj:`Class List`): the search results. 
- - to_play (:obj:`Class List`): the batch of which player is playing on this node. + - value_prefixs (:obj:`List`): the value prefixs of nodes along the search path. + - values (:obj:`List`): the values to propagate along the search path. + - policies (:obj:`List`): the policy logits of nodes along the search path. + - min_max_stats_lst (:obj:`List[MinMaxStats]`): a tool used to min-max normalize the q value. + - results (:obj:`List`): the search results. + - to_play (:obj:`List`): the batch of which player is playing on this node. """ if leaf_idx_list is None: leaf_idx_list = list(range(results.num)) diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index 20336277c..9d2c63f21 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -23,10 +23,15 @@ class EfficientZeroMCTSCtree(object): """ Overview: - MCTSCtree for EfficientZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in C++. + The C++ implementation of MCTS (batch format) for EfficientZero. \ + It completes the ``roots``and ``search`` methods by calling functions in module ``ctree_efficientzero``, \ + which are implemented in C++. Interfaces: - __init__, roots, search - + ``__init__``, ``roots``, ``search`` + + ..note:: + The benefit of searching for a batch of nodes at the same time is that \ + it can be parallelized during model inference, thus saving time. """ config = dict( @@ -44,7 +49,15 @@ class EfficientZeroMCTSCtree(object): @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + A class method that returns a default configuration in the form of an EasyDict object. + Returns: + - cfg (:obj:`EasyDict`): The dict of the default configuration. + """ + # Create a deep copy of the `config` attribute of the class. cfg = EasyDict(copy.deepcopy(cls.config)) + # Add a new attribute `cfg_type` to the `cfg` object. cfg.cfg_type = cls.__name__ + 'Dict' return cfg @@ -54,8 +67,12 @@ def __init__(self, cfg: EasyDict = None) -> None: Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key in the default configuration, the user-provided value will override the default configuration. Otherwise, the default configuration will be used. + Arguments: + - cfg (:obj:`EasyDict`): The configuration passed in by the user. """ + # Get the default configuration. default_config = self.default_config() + # Update the default configuration with the values provided by the user in ``cfg``. default_config.update(cfg) self._cfg = default_config self.inverse_scalar_transform_handle = InverseScalarTransform( @@ -66,13 +83,15 @@ def __init__(self, cfg: EasyDict = None) -> None: def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "ez_ctree.Roots": """ Overview: - The initialization of CRoots with root num and legal action lists. + Initializes a batch of roots to search parallelly later. Arguments: - - root_num (:obj:'int'): the number of the current root. - - legal_action_list (:obj:'List'): the vector of the legal action of this root. + - root_num (:obj:`int`): the number of the roots in a batch. + - legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots. + + ..note:: + The initialization is achieved by the ``Roots`` class from the ``ctree_efficientzero`` module. 
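The ``default_config``/``__init__`` pattern documented here is a small override mechanism: class-level defaults are deep-copied into an EasyDict and any user-supplied keys win. A toy stand-in showing the mechanism (the config keys and values below are illustrative only):

    import copy
    from easydict import EasyDict

    class ConfigurableSketch:
        config = dict(num_simulations=50, discount_factor=0.997)  # illustrative defaults

        @classmethod
        def default_config(cls) -> EasyDict:
            cfg = EasyDict(copy.deepcopy(cls.config))
            cfg.cfg_type = cls.__name__ + 'Dict'
            return cfg

        def __init__(self, cfg: EasyDict = None) -> None:
            default_config = self.default_config()
            default_config.update(cfg or {})
            self._cfg = default_config

    print(ConfigurableSketch(EasyDict(dict(num_simulations=25)))._cfg.num_simulations)  # 25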
""" - from lzero.mcts.ctree.ctree_efficientzero import ez_tree as ctree - return ctree.Roots(active_collect_env_num, legal_actions) + return tree_efficientzero.Roots(active_collect_env_num, legal_actions) def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], @@ -80,13 +99,17 @@ def search( ) -> None: """ Overview: - Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. - Use the cpp ctree. + Do MCTS for a batch of roots. Parallel in model inference. \ + Use C++ to implement the tree search. Arguments: - - roots (:obj:`Any`): a batch of expanded root nodes - - latent_state_roots (:obj:`list`): the hidden states of the roots - - reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots - - to_play_batch (:obj:`list`): the to_play_batch list used in self-play-mode board games + - roots (:obj:`Any`): a batch of expanded root nodes. + - latent_state_roots (:obj:`list`): the hidden states of the roots. + - reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots. + - model (:obj:`torch.nn.Module`): The model used for inference. + - to_play (:obj:`list`): the to_play list used in in self-play-mode board games. + + .. note:: + The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in C++. """ with torch.no_grad(): model.eval() @@ -202,10 +225,15 @@ def search( class MuZeroMCTSCtree(object): """ Overview: - MCTSCtree for MuZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in C++. - + The C++ implementation of MCTS (batch format) for MuZero. \ + It completes the ``roots``and ``search`` methods by calling functions in module ``ctree_muzero``, \ + which are implemented in C++. Interfaces: - __init__, roots, search + ``__init__``, ``roots``, ``search`` + + ..note:: + The benefit of searching for a batch of nodes at the same time is that \ + it can be parallelized during model inference, thus saving time. """ config = dict( @@ -223,18 +251,30 @@ class MuZeroMCTSCtree(object): @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + A class method that returns a default configuration in the form of an EasyDict object. + Returns: + - cfg (:obj:`EasyDict`): The dict of the default configuration. + """ + # Create a deep copy of the `config` attribute of the class. cfg = EasyDict(copy.deepcopy(cls.config)) + # Add a new attribute `cfg_type` to the `cfg` object. cfg.cfg_type = cls.__name__ + 'Dict' return cfg def __init__(self, cfg: EasyDict = None) -> None: """ Overview: - Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key - in the default configuration, the user-provided value will override the default configuration. Otherwise, + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key \ + in the default configuration, the user-provided value will override the default configuration. Otherwise, \ the default configuration will be used. + Arguments: + - cfg (:obj:`EasyDict`): The configuration passed in by the user. """ + # Get the default configuration. default_config = self.default_config() + # Update the default configuration with the values provided by the user in ``cfg``. 
default_config.update(cfg) self._cfg = default_config self.inverse_scalar_transform_handle = InverseScalarTransform( @@ -245,13 +285,15 @@ def __init__(self, cfg: EasyDict = None) -> None: def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "mz_ctree": """ Overview: - The initialization of CRoots with root num and legal action lists. + Initializes a batch of roots to search parallelly later. Arguments: - - root_num (:obj:`int`): the number of the current root. - - legal_action_list (:obj:`list`): the vector of the legal action of this root. + - root_num (:obj:`int`): the number of the roots in a batch. + - legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots. + + ..note:: + The initialization is achieved by the ``Roots`` class from the ``ctree_muzero`` module. """ - from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree - return ctree.Roots(active_collect_env_num, legal_actions) + return tree_muzero.Roots(active_collect_env_num, legal_actions) def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, @@ -259,12 +301,16 @@ def search( ) -> None: """ Overview: - Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. - Use the cpp ctree. + Do MCTS for a batch of roots. Parallel in model inference. \ + Use C++ to implement the tree search. Arguments: - - roots (:obj:`Any`): a batch of expanded root nodes - - latent_state_roots (:obj:`list`): the hidden states of the roots - - to_play_batch (:obj:`list`): the to_play_batch list used in in self-play-mode board games + - roots (:obj:`Any`): a batch of expanded root nodes. + - latent_state_roots (:obj:`list`): the hidden states of the roots. + - model (:obj:`torch.nn.Module`): The model used for inference. + - to_play (:obj:`list`): the to_play list used in in self-play-mode board games. + + .. note:: + The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in C++. """ with torch.no_grad(): model.eval() @@ -341,10 +387,15 @@ def search( class GumbelMuZeroMCTSCtree(object): """ Overview: - MCTSCtree for Gumbel MuZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in C++. + The C++ implementation of MCTS (batch format) for Gumbel MuZero. \ + It completes the ``roots``and ``search`` methods by calling functions in module ``ctree_gumbel_muzero``, \ + which are implemented in C++. Interfaces: - __init__, roots, search - + ``__init__``, ``roots``, ``search`` + + ..note:: + The benefit of searching for a batch of nodes at the same time is that \ + it can be parallelized during model inference, thus saving time. """ config = dict( # (int) The max limitation of simluation times during the simulation. @@ -359,18 +410,30 @@ class GumbelMuZeroMCTSCtree(object): @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + A class method that returns a default configuration in the form of an EasyDict object. + Returns: + - cfg (:obj:`EasyDict`): The dict of the default configuration. + """ + # Create a deep copy of the `config` attribute of the class. cfg = EasyDict(copy.deepcopy(cls.config)) + # Add a new attribute `cfg_type` to the `cfg` object. cfg.cfg_type = cls.__name__ + 'Dict' return cfg def __init__(self, cfg: EasyDict = None) -> None: """ Overview: - Use the default configuration mechanism. 
If a user passes in a cfg with a key that matches an existing key - in the default configuration, the user-provided value will override the default configuration. Otherwise, + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key \ + in the default configuration, the user-provided value will override the default configuration. Otherwise, \ the default configuration will be used. + Arguments: + - cfg (:obj:`EasyDict`): The configuration passed in by the user. """ + # Get the default configuration. default_config = self.default_config() + # Update the default configuration with the values provided by the user in ``cfg``. default_config.update(cfg) self._cfg = default_config self.inverse_scalar_transform_handle = InverseScalarTransform( @@ -381,25 +444,31 @@ def __init__(self, cfg: EasyDict = None) -> None: def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "gmz_ctree": """ Overview: - The initialization of CRoots with root num and legal action lists. + Initializes a batch of roots to search parallelly later. Arguments: - - root_num (:obj:`int`): the number of the current root. - - legal_action_list (:obj:`list`): the vector of the legal action of this root. + - root_num (:obj:`int`): the number of the roots in a batch. + - legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots. + + ..note:: + The initialization is achieved by the ``Roots`` class from the ``ctree_gumbel_muzero`` module. """ - from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as ctree - return ctree.Roots(active_collect_env_num, legal_actions) + return tree_gumbel_muzero.Roots(active_collect_env_num, legal_actions) def search(self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, List[Any]] ) -> None: """ Overview: - Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. - Use the cpp tree. + Do MCTS for a batch of roots. Parallel in model inference. \ + Use C++ to implement the tree search. Arguments: - - roots (:obj:`Any`): a batch of expanded root nodes - - latent_state_roots (:obj:`list`): the hidden states of the roots - - to_play_batch (:obj:`list`): the to_play_batch list used in two_player mode board games + - roots (:obj:`Any`): a batch of expanded root nodes. + - latent_state_roots (:obj:`list`): the hidden states of the roots. + - model (:obj:`torch.nn.Module`): The model used for inference. + - to_play (:obj:`list`): the to_play list used in in self-play-mode board games. + + .. note:: + The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in C++. """ with torch.no_grad(): model.eval() diff --git a/lzero/mcts/tree_search/mcts_ctree_sampled.py b/lzero/mcts/tree_search/mcts_ctree_sampled.py index 5d30e626b..0ebbf43ed 100644 --- a/lzero/mcts/tree_search/mcts_ctree_sampled.py +++ b/lzero/mcts/tree_search/mcts_ctree_sampled.py @@ -19,9 +19,15 @@ class SampledEfficientZeroMCTSCtree(object): """ Overview: - MCTSCtree for Sampled EfficientZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in C++. + The C++ implementation of MCTS (batch format) for Sampled EfficientZero. \ + It completes the ``roots``and ``search`` methods by calling functions in module ``ctree_sampled_efficientzero``, \ + which are implemented in C++. 
Interfaces: - __init__, roots, search + ``__init__``, ``roots``, ``search`` + + ..note:: + The benefit of searching for a batch of nodes at the same time is that \ + it can be parallelized during model inference, thus saving time. """ # the default_config for SampledEfficientZeroMCTSCtree. @@ -40,7 +46,15 @@ class SampledEfficientZeroMCTSCtree(object): @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + A class method that returns a default configuration in the form of an EasyDict object. + Returns: + - cfg (:obj:`EasyDict`): The dict of the default configuration. + """ + # Create a deep copy of the `config` attribute of the class. cfg = EasyDict(copy.deepcopy(cls.config)) + # Add a new attribute `cfg_type` to the `cfg` object. cfg.cfg_type = cls.__name__ + 'Dict' return cfg @@ -50,8 +64,12 @@ def __init__(self, cfg: EasyDict = None) -> None: Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key in the default configuration, the user-provided value will override the default configuration. Otherwise, the default configuration will be used. + Arguments: + - cfg (:obj:`EasyDict`): The configuration passed in by the user. """ + # Get the default configuration. default_config = self.default_config() + # Update the default configuration with the values provided by the user in ``cfg``. default_config.update(cfg) self._cfg = default_config self.inverse_scalar_transform_handle = InverseScalarTransform( @@ -65,16 +83,18 @@ def roots( ) -> "ezs_ctree.Roots": """ Overview: - Initialization of CNode with root_num, legal_actions_list, action_space_size, num_of_sampled_actions, continuous_action_space. + Initializes a batch of roots to search parallelly later. Arguments: - - root_num (:obj:'int'): the number of the current root. - - legal_action_list (:obj:'List'): the vector of the legal action of this root. + - root_num (:obj:`int`): the number of the roots in a batch. + - legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots. - action_space_size (:obj:'int'): the size of action space of the current env. - - num_of_sampled_actions (:obj:'int'): the number of sampled actions, i.e. K in the Sampled MuZero papers. + - num_of_sampled_actions (:obj:'int'): the number of sampled actions, i.e. K in the Sampled MuZero paper. - continuous_action_space (:obj:'bool'): whether the action space is continous in current env. + + ..note:: + The initialization is achieved by the ``Roots`` class from the ``ctree_sampled_efficientzero`` module. """ - from lzero.mcts.ctree.ctree_sampled_efficientzero import ezs_tree as ctree - return ctree.Roots( + return tree_efficientzero.Roots( root_num, legal_action_lis, action_space_size, num_of_sampled_actions, continuous_action_space ) @@ -84,14 +104,17 @@ def search( ) -> None: """ Overview: - Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. - Use the cpp ctree. + Do MCTS for a batch of roots. Parallel in model inference. \ + Use C++ to implement the tree search. Arguments: - - roots (:obj:`Any`): a batch of expanded root nodes - - model (:obj:`torch.nn.Module`): Instance of torch.nn.Module. - - latent_state_roots (:obj:`list`): the hidden states of the roots - - reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots - - to_play_batch (:obj:`list`): the to_play_batch list used in in self-play-mode board games + - roots (:obj:`Any`): a batch of expanded root nodes. 
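For the sampled variant, the extra ``action_space_size``/``num_of_sampled_actions``/``continuous_action_space`` arguments are what distinguish root creation from the vanilla tree. A hedged sketch of creating such roots for a continuous-control task (the placeholder legal-action lists and sizes are illustrative, not taken from this patch):

    from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree

    env_num, action_space_size, K = 8, 2, 20
    roots = SampledEfficientZeroMCTSCtree.roots(
        env_num,                       # number of parallel roots (collect envs)
        [[] for _ in range(env_num)],  # legal-action placeholders; unused for continuous envs
        action_space_size,             # dimensionality of the continuous action
        K,                             # K candidate actions sampled at each node
        True,                          # continuous_action_space
    )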
+ - model (:obj:`torch.nn.Module`): The model used for inference. + - latent_state_roots (:obj:`list`): the hidden states of the roots. + - reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots. + - to_play (:obj:`list`): the to_play list used in in self-play-mode board games. + + .. note:: + The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in C++. """ with torch.no_grad(): model.eval() diff --git a/lzero/mcts/tree_search/mcts_ctree_stochastic.py b/lzero/mcts/tree_search/mcts_ctree_stochastic.py index a92249fe4..6b450e5ff 100644 --- a/lzero/mcts/tree_search/mcts_ctree_stochastic.py +++ b/lzero/mcts/tree_search/mcts_ctree_stochastic.py @@ -17,10 +17,15 @@ class StochasticMuZeroMCTSCtree(object): """ Overview: - MCTSCtree for Stochastic MuZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in C++. - + The C++ implementation of MCTS (batch format) for Stochastic MuZero. \ + It completes the ``roots``and ``search`` methods by calling functions in module ``ctree_stochastic_muzero``, \ + which are implemented in C++. Interfaces: - __init__, roots, search + ``__init__``, ``roots``, ``search`` + + ..note:: + The benefit of searching for a batch of nodes at the same time is that \ + it can be parallelized during model inference, thus saving time. """ config = dict( @@ -38,7 +43,15 @@ class StochasticMuZeroMCTSCtree(object): @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + A class method that returns a default configuration in the form of an EasyDict object. + Returns: + - cfg (:obj:`EasyDict`): The dict of the default configuration. + """ + # Create a deep copy of the `config` attribute of the class. cfg = EasyDict(copy.deepcopy(cls.config)) + # Add a new attribute `cfg_type` to the `cfg` object. cfg.cfg_type = cls.__name__ + 'Dict' return cfg @@ -48,8 +61,12 @@ def __init__(self, cfg: EasyDict = None) -> None: Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key in the default configuration, the user-provided value will override the default configuration. Otherwise, the default configuration will be used. + Arguments: + - cfg (:obj:`EasyDict`): The configuration passed in by the user. """ + # Get the default configuration. default_config = self.default_config() + # Update the default configuration with the values provided by the user in ``cfg``. default_config.update(cfg) self._cfg = default_config self.inverse_scalar_transform_handle = InverseScalarTransform( @@ -61,13 +78,15 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any], chance_space_size: int = 2) -> "stochastic_mz_tree.Roots": """ Overview: - The initialization of CRoots with root num and legal action lists. + Initializes a batch of roots to search parallelly later. Arguments: - - root_num (:obj:`int`): the number of the current root. - - legal_action_list (:obj:`list`): the vector of the legal action of this root. + - root_num (:obj:`int`): the number of the roots in a batch. + - legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots. + + ..note:: + The initialization is achieved by the ``Roots`` class from the ``ctree_stochastic_muzero`` module. 
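+
+        Example (an illustrative sketch; the batch size, the legal actions and the chance space size are hypothetical):
+            >>> legal_actions = [[0, 1] for _ in range(8)]
+            >>> roots = StochasticMuZeroMCTSCtree.roots(8, legal_actions, chance_space_size=2)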
""" - from lzero.mcts.ctree.ctree_stochastic_muzero import stochastic_mz_tree as ctree - return ctree.Roots(active_collect_env_num, legal_actions, chance_space_size) + return stochastic_mz_tree.Roots(active_collect_env_num, legal_actions, chance_space_size) def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, @@ -75,12 +94,16 @@ def search( ) -> None: """ Overview: - Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. - Use the cpp ctree. + Do MCTS for a batch of roots. Parallel in model inference. \ + Use C++ to implement the tree search. Arguments: - - roots (:obj:`Any`): a batch of expanded root nodes - - latent_state_roots (:obj:`list`): the hidden states of the roots - - to_play_batch (:obj:`list`): the to_play_batch list used in in self-play-mode board games + - roots (:obj:`Any`): a batch of expanded root nodes. + - latent_state_roots (:obj:`list`): the hidden states of the roots. + - model (:obj:`torch.nn.Module`): The model used for inference. + - to_play (:obj:`list`): the to_play list used in in self-play-mode board games. + + .. note:: + The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in C++. """ with torch.no_grad(): model.eval() diff --git a/lzero/mcts/tree_search/mcts_ptree.py b/lzero/mcts/tree_search/mcts_ptree.py index 4f074b615..9dbb51335 100644 --- a/lzero/mcts/tree_search/mcts_ptree.py +++ b/lzero/mcts/tree_search/mcts_ptree.py @@ -8,6 +8,7 @@ from lzero.mcts.ptree import MinMaxStatsList from lzero.policy import InverseScalarTransform, to_detach_cpu_numpy import lzero.mcts.ptree.ptree_mz as tree_muzero +import lzero.mcts.ptree.ptree_ez as tree_efficientzero if TYPE_CHECKING: import lzero.mcts.ptree.ptree_ez as ez_ptree @@ -16,15 +17,20 @@ # ============================================================== # EfficientZero # ============================================================== -import lzero.mcts.ptree.ptree_ez as tree class EfficientZeroMCTSPtree(object): """ Overview: - MCTSPtree for EfficientZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in python. + The Python implementation of MCTS (batch format) for EfficientZero. \ + It completes the ``roots``and ``search`` methods by calling functions in module ``ptree_ez``, \ + which are implemented in Python. Interfaces: - __init__, roots, search + ``__init__``, ``roots``, ``search`` + + ..note:: + The benefit of searching for a batch of nodes at the same time is that \ + it can be parallelized during model inference, thus saving time. """ # the default_config for EfficientZeroMCTSPtree. @@ -43,7 +49,15 @@ class EfficientZeroMCTSPtree(object): @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + A class method that returns a default configuration in the form of an EasyDict object. + Returns: + - cfg (:obj:`EasyDict`): The dict of the default configuration. + """ + # Create a deep copy of the `config` attribute of the class. cfg = EasyDict(copy.deepcopy(cls.config)) + # Add a new attribute `cfg_type` to the `cfg` object. cfg.cfg_type = cls.__name__ + 'Dict' return cfg @@ -53,8 +67,12 @@ def __init__(self, cfg: EasyDict = None) -> None: Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key in the default configuration, the user-provided value will override the default configuration. Otherwise, the default configuration will be used. 
+ Arguments: + - cfg (:obj:`EasyDict`): The configuration passed in by the user. """ + # Get the default configuration. default_config = self.default_config() + # Update the default configuration with the values provided by the user in ``cfg``. default_config.update(cfg) self._cfg = default_config self.inverse_scalar_transform_handle = InverseScalarTransform( @@ -65,13 +83,15 @@ def __init__(self, cfg: EasyDict = None) -> None: def roots(cls: int, root_num: int, legal_actions: List[Any]) -> "ez_ptree.Roots": """ Overview: - The initialization of CRoots with root num and legal action lists. + Initializes a batch of roots to search parallelly later. Arguments: - - root_num: the number of the current root. - - legal_action_list: the vector of the legal action of this root. + - root_num (:obj:`int`): the number of the roots in a batch. + - legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots. + + ..note:: + The initialization is achieved by the ``Roots`` class from the ``ptree_ez`` module. """ - import lzero.mcts.ptree.ptree_ez as ptree - return ptree.Roots(root_num, legal_actions) + return tree_efficientzero.Roots(root_num, legal_actions) def search( self, @@ -83,13 +103,17 @@ def search( ) -> None: """ Overview: - Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. - Use the python ctree. + Do MCTS for a batch of roots. Parallel in model inference. \ + Use Python to implement the tree search. Arguments: - - roots (:obj:`Any`): a batch of expanded root nodes - - latent_state_roots (:obj:`list`): the hidden states of the roots - - reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots - - to_play (:obj:`list`): the to_play list used in in self-play-mode board games + - roots (:obj:`Any`): a batch of expanded root nodes. + - latent_state_roots (:obj:`list`): the hidden states of the roots. + - reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots. + - model (:obj:`torch.nn.Module`): The model used for inference. + - to_play (:obj:`list`): the to_play list used in in self-play-mode board games. + + .. note:: + The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in Python. """ with torch.no_grad(): model.eval() @@ -116,7 +140,7 @@ def search( hidden_states_h_reward = [] # prepare a result wrapper to transport results between python and c++ parts - results = tree.SearchResults(num=batch_size) + results = tree_efficientzero.SearchResults(num=batch_size) # latent_state_index_in_search_path: the first index of leaf node states in latent_state_batch_in_search_path, i.e. is current_latent_state_index in one the search. # latent_state_index_in_batch: the second index of leaf node states in latent_state_batch_in_search_path, i.e. the index in the batch, whose maximum is ``batch_size``. @@ -126,7 +150,7 @@ def search( MCTS stage 1: Selection Each simulation starts from the internal root state s0, and finishes when the simulation reaches a leaf node s_l. 
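+                In the batch setting, one leaf node is reached for every root, and the latent states (together with
+                the LSTM reward hidden states) gathered for these leaves are then evaluated by the model in a single
+                batched inference call in the next stage.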
""" - latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play = tree.batch_traverse( + latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, virtual_to_play = tree_efficientzero.batch_traverse( roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst, results, copy.deepcopy(to_play) ) # obtain the search horizon for leaf nodes @@ -195,7 +219,7 @@ def search( # NOTE: simulation_index + 1 is very important, which is the depth of the current leaf node. current_latent_state_index = simulation_index + 1 - tree.batch_backpropagate( + tree_efficientzero.batch_backpropagate( current_latent_state_index, discount_factor, value_prefix_batch, value_batch, policy_logits_batch, min_max_stats_lst, results, is_reset_list, virtual_to_play ) @@ -209,9 +233,15 @@ def search( class MuZeroMCTSPtree(object): """ Overview: - MCTSPtree for MuZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in python. + The Python implementation of MCTS (batch format) for MuZero. \ + It completes the ``roots``and ``search`` methods by calling functions in module ``ptree_mz``, \ + which are implemented in Python. Interfaces: - __init__, roots, search + ``__init__``, ``roots``, ``search`` + + ..note:: + The benefit of searching for a batch of nodes at the same time is that \ + it can be parallelized during model inference, thus saving time. """ # the default_config for MuZeroMCTSPtree. @@ -230,18 +260,30 @@ class MuZeroMCTSPtree(object): @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + A class method that returns a default configuration in the form of an EasyDict object. + Returns: + - cfg (:obj:`EasyDict`): The dict of the default configuration. + """ + # Create a deep copy of the `config` attribute of the class. cfg = EasyDict(copy.deepcopy(cls.config)) + # Add a new attribute `cfg_type` to the `cfg` object. cfg.cfg_type = cls.__name__ + 'Dict' return cfg def __init__(self, cfg: EasyDict = None) -> None: """ Overview: - Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key - in the default configuration, the user-provided value will override the default configuration. Otherwise, + Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key \ + in the default configuration, the user-provided value will override the default configuration. Otherwise, \ the default configuration will be used. + Arguments: + - cfg (:obj:`EasyDict`): The configuration passed in by the user. """ + # Get the default configuration. default_config = self.default_config() + # Update the default configuration with the values provided by the user in ``cfg``. default_config.update(cfg) self._cfg = default_config self.inverse_scalar_transform_handle = InverseScalarTransform( @@ -252,13 +294,15 @@ def __init__(self, cfg: EasyDict = None) -> None: def roots(cls: int, root_num: int, legal_actions: List[Any]) -> "mz_ptree.Roots": """ Overview: - The initialization of CRoots with root num and legal action lists. + Initializes a batch of roots to search parallelly later. Arguments: - - root_num: the number of the current root. - - legal_action_list: the vector of the legal action of this root. + - root_num (:obj:`int`): the number of the roots in a batch. + - legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots. + + ..note:: + The initialization is achieved by the ``Roots`` class from the ``ptree_mz`` module. 
""" - import lzero.mcts.ptree.ptree_mz as ptree - return ptree.Roots(root_num, legal_actions) + return tree_muzero.Roots(root_num, legal_actions) def search( self, @@ -269,12 +313,16 @@ def search( ) -> None: """ Overview: - Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. - Use the python ctree. + Do MCTS for a batch of roots. Parallel in model inference. \ + Use Python to implement the tree search. Arguments: - - roots (:obj:`Any`): a batch of expanded root nodes - - latent_state_roots (:obj:`list`): the hidden states of the roots - - to_play (:obj:`list`): the to_play list used in in self-play-mode board games + - roots (:obj:`Any`): a batch of expanded root nodes. + - latent_state_roots (:obj:`list`): the hidden states of the roots. + - model (:obj:`torch.nn.Module`): The model used for inference. + - to_play (:obj:`list`): the to_play list used in in self-play-mode board games. + + .. note:: + The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in Python. """ with torch.no_grad(): model.eval() diff --git a/lzero/mcts/tree_search/mcts_ptree_sampled.py b/lzero/mcts/tree_search/mcts_ptree_sampled.py index 7567c1f50..812de99c7 100644 --- a/lzero/mcts/tree_search/mcts_ptree_sampled.py +++ b/lzero/mcts/tree_search/mcts_ptree_sampled.py @@ -20,9 +20,15 @@ class SampledEfficientZeroMCTSPtree(object): """ Overview: - MCTSPtree for Sampled EfficientZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in python. + The Python implementation of MCTS (batch format) for Sampled EfficientZero. \ + It completes the ``roots``and ``search`` methods by calling functions in module ``ptree_sez``, \ + which are implemented in Python. Interfaces: - __init__, roots, search + ``__init__``, ``roots``, ``search`` + + ..note:: + The benefit of searching for a batch of nodes at the same time is that \ + it can be parallelized during model inference, thus saving time. """ # the default_config for SampledEfficientZeroMCTSPtree. @@ -41,7 +47,15 @@ class SampledEfficientZeroMCTSPtree(object): @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + A class method that returns a default configuration in the form of an EasyDict object. + Returns: + - cfg (:obj:`EasyDict`): The dict of the default configuration. + """ + # Create a deep copy of the `config` attribute of the class. cfg = EasyDict(copy.deepcopy(cls.config)) + # Add a new attribute `cfg_type` to the `cfg` object. cfg.cfg_type = cls.__name__ + 'Dict' return cfg @@ -51,8 +65,12 @@ def __init__(self, cfg: EasyDict = None) -> None: Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key in the default configuration, the user-provided value will override the default configuration. Otherwise, the default configuration will be used. + Arguments: + - cfg (:obj:`EasyDict`): The configuration passed in by the user. """ + # Get the default configuration. default_config = self.default_config() + # Update the default configuration with the values provided by the user in ``cfg``. default_config.update(cfg) self._cfg = default_config self.inverse_scalar_transform_handle = InverseScalarTransform( @@ -66,16 +84,18 @@ def roots( ) -> "ptree.Roots": """ Overview: - Initialization of CNode with root_num, legal_actions_list, action_space_size, num_of_sampled_actions, continuous_action_space. + Initializes a batch of roots to search parallelly later. 
Arguments: - - root_num (:obj:'int'): the number of the current root. - - legal_action_lis (:obj:'List'): the vector of the legal action of this root. + - root_num (:obj:`int`): the number of the roots in a batch. + - legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots. - action_space_size (:obj:'int'): the size of action space of the current env. - - num_of_sampled_actions (:obj:'int'): the number of sampled actions, i.e. K in the Sampled MuZero papers. + - num_of_sampled_actions (:obj:'int'): The number of sampled actions, i.e. K in the Sampled MuZero paper. - continuous_action_space (:obj:'bool'): whether the action space is continous in current env. + + ..note:: + The initialization is achieved by the ``Roots`` class from the ``ptree_sez`` module. """ - import lzero.mcts.ptree.ptree_sez as ptree - return ptree.Roots( + return tree_sez.Roots( root_num, legal_action_lis, action_space_size, num_of_sampled_actions, continuous_action_space ) @@ -89,13 +109,17 @@ def search( ) -> None: """ Overview: - Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference. - Use the python ctree. + Do MCTS for a batch of roots. Parallel in model inference. \ + Use Python to implement the tree search. Arguments: - - roots (:obj:`Any`): a batch of expanded root nodes - - latent_state_roots (:obj:`list`): the hidden states of the roots - - reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots - - to_play (:obj:`list`): the to_play list used in in self-play-mode board games + - roots (:obj:`Any`): a batch of expanded root nodes. + - model (:obj:`torch.nn.Module`): The model used for inference. + - latent_state_roots (:obj:`list`): the hidden states of the roots. + - reward_hidden_state_roots (:obj:`list`): the value prefix hidden states in LSTM of the roots. + - to_play (:obj:`list`): the to_play list used in in self-play-mode board games. + + .. note:: + The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in Python. """ with torch.no_grad(): model.eval() diff --git a/lzero/mcts/tree_search/mcts_ptree_stochastic.py b/lzero/mcts/tree_search/mcts_ptree_stochastic.py index eb6f5f4b3..b3be91156 100644 --- a/lzero/mcts/tree_search/mcts_ptree_stochastic.py +++ b/lzero/mcts/tree_search/mcts_ptree_stochastic.py @@ -21,9 +21,15 @@ class StochasticMuZeroMCTSPtree(object): """ Overview: - MCTSPtree for MuZero. The core ``batch_traverse`` and ``batch_backpropagate`` function is implemented in python. + The Python implementation of MCTS (batch format) for Stochastic MuZero. \ + It completes the ``roots``and ``search`` methods by calling functions in module ``ptree_stochastic_mz``, \ + which are implemented in Python. Interfaces: - __init__, search + ``__init__``, ``roots``, ``search`` + + ..note:: + The benefit of searching for a batch of nodes at the same time is that \ + it can be parallelized during model inference, thus saving time. """ # the default_config for MuZeroMCTSPtree. @@ -42,7 +48,15 @@ class StochasticMuZeroMCTSPtree(object): @classmethod def default_config(cls: type) -> EasyDict: + """ + Overview: + A class method that returns a default configuration in the form of an EasyDict object. + Returns: + - cfg (:obj:`EasyDict`): The dict of the default configuration. + """ + # Create a deep copy of the `config` attribute of the class. cfg = EasyDict(copy.deepcopy(cls.config)) + # Add a new attribute `cfg_type` to the `cfg` object. 
         cfg.cfg_type = cls.__name__ + 'Dict'
         return cfg
 
@@ -52,8 +66,12 @@ def __init__(self, cfg: EasyDict = None) -> None:
             Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
             in the default configuration, the user-provided value will override the default configuration. Otherwise,
             the default configuration will be used.
+        Arguments:
+            - cfg (:obj:`EasyDict`): The configuration passed in by the user.
         """
+        # Get the default configuration.
         default_config = self.default_config()
+        # Update the default configuration with the values provided by the user in ``cfg``.
         default_config.update(cfg)
         self._cfg = default_config
         self.inverse_scalar_transform_handle = InverseScalarTransform(
@@ -64,13 +82,15 @@ def __init__(self, cfg: EasyDict = None) -> None:
     def roots(cls: int, root_num: int, legal_actions: List[Any]) -> "stochastic_mz_ptree.Roots":
         """
         Overview:
-            The initialization of CRoots with root num and legal action lists.
+            Initialize a batch of roots so that they can be searched in parallel later.
         Arguments:
-            - root_num: the number of the current root.
-            - legal_action_list: the vector of the legal action of this root.
+            - root_num (:obj:`int`): the number of the roots in a batch.
+            - legal_action_list (:obj:`List[Any]`): the vector of the legal actions for the roots.
+
+        .. note::
+            The initialization is achieved by the ``Roots`` class from the ``ptree_stochastic_mz`` module.
         """
-        import lzero.mcts.ptree.ptree_stochastic_mz as ptree
-        return ptree.Roots(root_num, legal_actions)
+        return tree_stochastic_muzero.Roots(root_num, legal_actions)
 
     def search(
             self,
@@ -81,12 +101,16 @@ def search(
     ) -> None:
         """
         Overview:
-            Do MCTS for the roots (a batch of root nodes in parallel). Parallel in model inference.
-            Use the python ctree.
+            Do MCTS for a batch of roots. Parallel in model inference. \
+            Use Python to implement the tree search.
         Arguments:
-            - roots (:obj:`Any`): a batch of expanded root nodes
-            - latent_state_roots (:obj:`list`): the hidden states of the roots
-            - to_play (:obj:`list`): the to_play list used in two_player mode board games
+            - roots (:obj:`Any`): a batch of expanded root nodes.
+            - latent_state_roots (:obj:`list`): the hidden states of the roots.
+            - model (:obj:`torch.nn.Module`): The model used for inference.
+            - to_play (:obj:`list`): the to_play list used in self-play-mode board games.
+
+        .. note::
+            The core functions ``batch_traverse`` and ``batch_backpropagate`` are implemented in Python.
         """
         with torch.no_grad():
             model.eval()
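+            # NOTE: the whole search runs inside ``torch.no_grad()`` with the model in eval mode,
+            # because MCTS only performs forward inference and needs no gradient computation.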