From 541f5bf5d5bcd643c472fac683b08780907eae21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Thu, 8 Feb 2024 12:53:51 +0800 Subject: [PATCH] polish(pu): polish comments in worker files --- lzero/worker/alphazero_collector.py | 92 +++++++++++----------- lzero/worker/alphazero_evaluator.py | 77 +++++++++---------- lzero/worker/muzero_collector.py | 114 +++++++++++++++------------- lzero/worker/muzero_evaluator.py | 87 ++++++++++----------- 4 files changed, 184 insertions(+), 186 deletions(-) diff --git a/lzero/worker/alphazero_collector.py b/lzero/worker/alphazero_collector.py index ced7727d0..8b43b1e00 100644 --- a/lzero/worker/alphazero_collector.py +++ b/lzero/worker/alphazero_collector.py @@ -1,11 +1,11 @@ from collections import namedtuple -from typing import Optional, Any, List, Dict +from typing import Optional, Any, List import numpy as np from ding.envs import BaseEnvManager from ding.torch_utils import to_ndarray -from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \ - broadcast_object_list, allreduce_data +from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size, \ + allreduce_data from ding.worker.collector.base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, \ to_tensor_transitions @@ -14,9 +14,10 @@ class AlphaZeroCollector(ISerialCollector): """ Overview: - AlphaZero collector (n_episode). + AlphaZero collector for collecting episodes of experience during self-play or playing against an opponent. + This collector is specifically designed for the AlphaZero algorithm. Interfaces: - __init__, reset, reset_env, reset_policy, collect, close + ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``collect``, ``close`` Property: envstep """ @@ -35,18 +36,17 @@ def __init__( env_config=None, ) -> None: """ - Overview: - Init the AlphaZero collector according to input arguments. - Arguments: - - collect_print_freq (:obj:`int`): collect_print_frequency in terms of training_steps. - - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \ - its derivatives are supported. - - policy (:obj:`Policy`): The policy to be collected. - - tb_logger (:obj:`SummaryWriter`): Logger, defaultly set as 'SummaryWriter' for model summary. - - instance_name (:obj:`Optional[str]`): Name of this instance. - - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory. - - env_config: Config of environment - """ + Overview: + Initialize the AlphaZero collector with the provided environment, policy, and configurations. + Arguments: + - collect_print_freq (:obj:`int`): Frequency of printing collection statistics (in training steps). + - env (:obj:`Optional[BaseEnvManager]`): Environment manager for managing multiple environments. + - policy (:obj:`Optional[namedtuple]`): Policy used for making decisions during collection. + - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger for logging statistics. + - exp_name (:obj:`str`): Name of the experiment for logging purposes. + - instance_name (:obj:`str`): Unique identifier for this collector instance. + - env_config (:obj:`Optional[dict]`): Configuration for the environment. 
+ """ self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -79,13 +79,12 @@ def __init__( def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the environment. + Reset or replace the environment in the collector. If _env is None, reset the old environment. If _env is not None, replace the old environment in the collector with the new passed \ in environment and launch. Arguments: - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _env (:obj:`Optional[BaseEnvManager]`): New environment to replace the existing one, if provided. """ if _env is not None: self._env = _env @@ -97,11 +96,11 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset the policy. + Reset or replace the policy in the collector. If _policy is None, reset the old policy. If _policy is not None, replace the old policy in the collector with the new passed in policy. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy + - _policy (:obj:`Optional[namedtuple]`): New policy to replace the existing one, if provided. """ assert hasattr(self, '_env'), "please set env first" if _policy is not None: @@ -119,16 +118,15 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the environment and policy. + Reset the environment and policy within the collector. If _env is None, reset the old environment. If _env is not None, replace the old environment in the collector with the new passed \ in environment and launch. If _policy is None, reset the old policy. If _policy is not None, replace the old policy in the collector with the new passed in policy. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _policy (:obj:`Optional[namedtuple]`): New policy to replace the existing one, if provided. + - _env (:obj:`Optional[BaseEnvManager]`): New environment to replace the existing one, if provided. """ if _env is not None: self.reset_env(_env) @@ -151,9 +149,10 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana def _reset_stat(self, env_id: int) -> None: """ Overview: - Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool\ - and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ - to get more messages. + Reset the statistics for a specific environment. + Including reset the traj_buffer, obs_pool, policy_output_pool and env_info. + Reset these states according to env_id. + You can refer to base_serial_collector to get more messages. Arguments: - env_id (:obj:`int`): the id where we need to reset the collector's state """ @@ -165,8 +164,8 @@ def _reset_stat(self, env_id: int) -> None: def close(self) -> None: """ Overview: - Close the collector. If end_flag is False, close the environment, flush the tb_logger\ - and close the tb_logger. + Close the collector. If end_flag is False, close the environment, flush the tb_logger + and close the tb_logger. 
""" if self._end_flag: return @@ -182,13 +181,13 @@ def collect(self, policy_kwargs: Optional[dict] = None) -> List[Any]: """ Overview: - Collect `n_episode` data with policy_kwargs, which is already trained `train_iter` iterations + Collect experience data for a specified number of episodes using the current policy. Arguments: - - n_episode (:obj:`int`): the number of collecting data episode - - train_iter (:obj:`int`): the number of training iteration - - policy_kwargs (:obj:`dict`): the keyword args for policy forward + - n_episode (:obj:`Optional[int]`): Number of episodes to collect. Defaults to a pre-set value if None. + - train_iter (:obj:`int`): Current training iteration. + - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. Returns: - - return_data (:obj:`List`): A list containing collected episodes. + - return_data (:obj:`List[Any]`): A list of collected experience episodes. """ if n_episode is None: if self._default_n_episode is None: @@ -295,17 +294,16 @@ def collect(self, def envstep(self) -> int: """ Overview: - Print the total envstep count. - Return: - - envstep (:obj:`int`): the total envstep count + Get the total number of environment steps taken by the collector. + Returns: + - envstep (:obj:`int`): Total count of environment steps. """ return self._total_envstep_count def close(self) -> None: """ Overview: - Close the collector. If end_flag is False, close the environment, flush the tb_logger\ - and close the tb_logger. + Close the collector and clean up resources such as environment and logger. """ if self._end_flag: return @@ -318,18 +316,16 @@ def close(self) -> None: def __del__(self) -> None: """ Overview: - Execute the close command and close the collector. __del__ is automatically called to \ - destroy the collector instance when the collector finishes its work + Destructor method that is called when the collector object is being destroyed. """ self.close() def _output_log(self, train_iter: int) -> None: """ Overview: - Print the output log information. You can refer to Docs/Best Practice/How to understand\ - training generated folders/Serial mode/log/collector for more details. + Output logging information for the current collection phase. Arguments: - - train_iter (:obj:`int`): the number of training iteration. + - train_iter (:obj:`int`): Current training iteration for logging purposes. """ if self._rank != 0: return @@ -368,9 +364,9 @@ def _output_log(self, train_iter: int) -> None: def reward_shaping(self, transitions, eval_episode_return): """ Overview: - Shape the reward according to the player. + Shape the rewards in the collected transitions based on the outcome of the episode. Return: - - transitions: data transitions. + - transitions (:obj:`List[dict]`): List of data transitions. 
""" reward = transitions[-1]['reward'] to_play = transitions[-1]['obs']['to_play'] diff --git a/lzero/worker/alphazero_evaluator.py b/lzero/worker/alphazero_evaluator.py index c9eb8f650..66e61170b 100644 --- a/lzero/worker/alphazero_evaluator.py +++ b/lzero/worker/alphazero_evaluator.py @@ -1,11 +1,11 @@ from collections import namedtuple from typing import Optional, Callable, Tuple -import torch + import numpy as np +import torch from ding.envs import BaseEnv from ding.envs import BaseEnvManager from ding.torch_utils import to_tensor, to_item - from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY from ding.utils import get_world_size, get_rank, broadcast_object_list from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor @@ -15,9 +15,9 @@ class AlphaZeroEvaluator(ISerialEvaluator): """ Overview: - AlphaZero Evaluator. + AlphaZero Evaluator class which handles the evaluation of the trained policy. Interfaces: - __init__, reset, reset_policy, reset_env, close, should_eval, eval + ``__init__``, ``reset``, ``reset_policy``, ``reset_env``, ``close``, ``should_eval``, ``eval`` Property: env, policy """ @@ -36,17 +36,17 @@ def __init__( ) -> None: """ Overview: - Init the AlphaZero evaluator according to input arguments. + Initialize the AlphaZero evaluator with the given parameters. Arguments: - - eval_freq (:obj:`int`): evaluation frequency in terms of training steps. - - n_evaluator_episode (:obj:`int`): the number of episodes to eval in total. - - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \ - its derivatives are supported. - - policy (:obj:`Policy`): The policy to be collected. - - tb_logger (:obj:`SummaryWriter`): Logger, defaultly set as 'SummaryWriter' for model summary. - - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory. - - instance_name (:obj:`Optional[str]`): Name of this instance. - - env_config: Config of environment + - eval_freq (:obj:`int`): Evaluation frequency in terms of training steps. + - n_evaluator_episode (:obj:`int`): Number of episodes for each evaluation. + - stop_value (:obj:`float`): Reward threshold to stop training if surpassed. + - env (:obj:`Optional[BaseEnvManager]`): Environment manager for managing multiple environments. + - policy (:obj:`Optional[namedtuple]`): Policy to be evaluated. + - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger for logging statistics. + - exp_name (:obj:`str`): Name of the experiment for logging purposes. + - instance_name (:obj:`str`): Unique identifier for this evaluator instance. + - env_config (:obj:`Optional[dict]`): Configuration for the environment. """ self._eval_freq = eval_freq self._exp_name = exp_name @@ -78,14 +78,12 @@ def __init__( def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset evaluator's environment. In some case, we need evaluator use the same policy in different \ - environments. We can use reset_env to reset the environment. + Reset or replace the environment in the evaluator. If _env is None, reset the old environment. If _env is not None, replace the old environment in the evaluator with the \ new passed in environment and launch. Arguments: - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _env (:obj:`Optional[BaseEnvManager]`): New environment to replace the existing one, if provided. 
""" if _env is not None: self._env = _env @@ -97,8 +95,7 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset evaluator's policy. In some case, we need evaluator work in this same environment but use\ - different policy. We can use reset_policy to reset the policy. + Reset or replace the policy in the evaluator. If _policy is None, reset the old policy. If _policy is not None, replace the old policy in the evaluator with the new passed in policy. Arguments: @@ -112,16 +109,15 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset evaluator's policy and environment. Use new policy and environment to collect data. + Reset the environment and policy within the evaluator. If _env is None, reset the old environment. If _env is not None, replace the old environment in the evaluator with the new passed in \ environment and launch. If _policy is None, reset the old policy. If _policy is not None, replace the old policy in the evaluator with the new passed in policy. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _policy (:obj:`Optional[namedtuple]`): New policy to replace the existing one, if provided. + - _env (:obj:`Optional[BaseEnvManager]`): New environment to replace the existing one, if provided. """ if _env is not None: self.reset_env(_env) @@ -134,8 +130,8 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana def close(self) -> None: """ Overview: - Close the evaluator. If end_flag is False, close the environment, flush the tb_logger\ - and close the tb_logger. + Close the evaluator and clean up resources such as environment and logger. + If end_flag is False, close the environment, flush the tb_logger and close the tb_logger. """ if self._end_flag: return @@ -148,18 +144,18 @@ def close(self) -> None: def __del__(self) -> None: """ Overview: - Execute the close command and close the evaluator. __del__ is automatically called \ - to destroy the evaluator instance when the evaluator finishes its work + Destructor method that is called when the evaluator object is being destroyed. + __del__ is automatically called to destroy the evaluator instance when the evaluator finishes its work. """ self.close() def should_eval(self, train_iter: int) -> bool: """ Overview: - Determine whether you need to start the evaluation mode, if the number of training has reached\ - the maximum number of times to start the evaluator, return True - Arguments: - - train_iter (:obj:`int`): Current training iteration. + Check if it is time to evaluate the policy based on the training iteration count. + If the amount of training has reached the maximum number of times to start the evaluator, return True. + Returns: + - (:obj:`bool`): Flag indicating whether evaluation should be performed. """ if train_iter == self._last_eval_iter: return False @@ -178,17 +174,18 @@ def eval( ) -> Tuple[bool, dict]: """ Overview: - Evaluate policy and store the best policy based on whether it reaches the highest historical reward. + Execute the evaluation of the policy and determine if the stopping condition has been met. 
Arguments: - - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward. - - train_iter (:obj:`int`): Current training iteration. - - envstep (:obj:`int`): Current env interaction step. - - n_episode (:obj:`int`): Number of evaluation episodes. + - save_ckpt_fn (:obj:`Optional[Callable]`): Callback function to save a checkpoint. + - train_iter (:obj:`int`): Current number of training iterations completed. + - envstep (:obj:`int`): Current number of environment steps completed. + - n_episode (:obj:`Optional[int]`): Number of episodes to evaluate. Defaults to preset if None. + - force_render (:obj:`bool`): Force rendering of the environment, if applicable. Returns: - - stop_flag (:obj:`bool`): Whether this training program can be ended. - - return_info (:obj:`dict`): Current evaluation return information. + - stop_flag (:obj:`bool`): Whether the training process should stop based on evaluation results. + - return_info (:obj:`dict`): Information about the evaluation results. """ - # evaluator only work on rank0 + # the evaluator only works on rank0 stop_flag, return_info = False, [] if get_rank() == 0: if n_episode is None: diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index aca581c47..4cf32cd6b 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -6,8 +6,8 @@ import torch from ding.envs import BaseEnvManager from ding.torch_utils import to_ndarray -from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \ - broadcast_object_list, allreduce_data +from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size, \ + allreduce_data from ding.worker.collector.base_serial_collector import ISerialCollector from torch.nn import L1Loss @@ -20,10 +20,12 @@ class MuZeroCollector(ISerialCollector): """ Overview: The Collector for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel MuZero. + It manages the data collection process for training these algorithms using a serial mechanism. Interfaces: - __init__, reset, reset_env, reset_policy, _reset_stat, envstep, __del__, _compute_priorities, pad_and_save_last_trajectory, collect, _output_log, close - Property: - envstep + ``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``_reset_stat``, ``envstep``, ``__del__``, ``_compute_priorities``, + ``pad_and_save_last_trajectory``, ``collect``, ``_output_log``, ``close`` + Properties: + ``envstep`` """ # TO be compatible with ISerialCollector @@ -41,15 +43,15 @@ def __init__( ) -> None: """ Overview: - Init the collector according to input arguments. + Initialize the MuZeroCollector with the given parameters. Arguments: - - collect_print_freq (:obj:`int`): collect_print_frequency in terms of training_steps. - - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager) - - policy (:obj:`namedtuple`): the api namedtuple of collect_mode policy - - tb_logger (:obj:`SummaryWriter`): tensorboard handle - - instance_name (:obj:`Optional[str]`): Name of this instance. - - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory. - - policy_config: Config of game. + - collect_print_freq (:obj:`int`): Frequency (in training steps) at which to print collection information. + - env (:obj:`Optional[BaseEnvManager]`): Instance of the subclass of vectorized environment manager. 
+ - policy (:obj:`Optional[namedtuple]`): namedtuple of the collection mode policy API. + - tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger instance. + - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. + - instance_name (:obj:`str`): Unique identifier for this collector instance. + - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. """ self._exp_name = exp_name self._instance_name = instance_name @@ -84,13 +86,12 @@ def __init__( def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the environment. + Reset or replace the environment managed by this collector. If _env is None, reset the old environment. If _env is not None, replace the old environment in the collector with the new passed \ in environment and launch. Arguments: - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - env (:obj:`Optional[BaseEnvManager]`): New environment to manage, if provided. """ if _env is not None: self._env = _env @@ -102,7 +103,7 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset the policy. + Reset or replace the policy used by this collector. If _policy is None, reset the old policy. If _policy is not None, replace the old policy in the collector with the new passed in policy. Arguments: @@ -120,7 +121,7 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset the environment and policy. + Reset the collector with the given policy and/or environment. If _env is None, reset the old environment. If _env is not None, replace the old environment in the collector with the new passed \ in environment and launch. @@ -152,9 +153,9 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana def _reset_stat(self, env_id: int) -> None: """ Overview: - Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool\ - and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ - to get more messages. + Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool \ + and env_info. Reset these states according to env_id. You can refer to base_serial_collector\ + to get more messages. Arguments: - env_id (:obj:`int`): the id where we need to reset the collector's state """ @@ -164,17 +165,17 @@ def _reset_stat(self, env_id: int) -> None: def envstep(self) -> int: """ Overview: - Print the total envstep count. - Return: - - envstep (:obj:`int`): the total envstep count + Get the total number of environment steps collected. + Returns: + - envstep (:obj:`int`): Total number of environment steps collected. """ return self._total_envstep_count def close(self) -> None: """ Overview: - Close the collector. If end_flag is False, close the environment, flush the tb_logger\ - and close the tb_logger. + Close the collector. If end_flag is False, close the environment, flush the tb_logger \ + and close the tb_logger. """ if self._end_flag: return @@ -188,21 +189,23 @@ def __del__(self) -> None: """ Overview: Execute the close command and close the collector. 
__del__ is automatically called to \ - destroy the collector instance when the collector finishes its work + destroy the collector instance when the collector finishes its work """ self.close() # ============================================================== # MCTS+RL related core code # ============================================================== - def _compute_priorities(self, i, pred_values_lst, search_values_lst): + def _compute_priorities(self, i: int, pred_values_lst: List[float], search_values_lst: List[float]) -> np.ndarray: """ Overview: - obtain the priorities at index i. + Compute the priorities for transitions based on prediction and search value discrepancies. Arguments: - - i: index. - - pred_values_lst: The list of value being predicted. - - search_values_lst: The list of value obtained through search. + - i (:obj:`int`): Index of the values in the list to compute the priority for. + - pred_values_lst (:obj:`List[float]`): List of predicted values. + - search_values_lst (:obj:`List[float]`): List of search values obtained from MCTS. + Returns: + - priorities (:obj:`np.ndarray`): Array of computed priorities. """ if self.policy_config.use_priority: # Calculate priorities. The priorities are the L1 losses between the predicted @@ -222,14 +225,18 @@ def _compute_priorities(self, i, pred_values_lst, search_values_lst): return priorities - def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_priorities, game_segments, done) -> None: + def pad_and_save_last_trajectory(self, i: int, last_game_segments: List[GameSegment], + last_game_priorities: List[np.ndarray], + game_segments: List[GameSegment], done: np.ndarray) -> None: """ Overview: - put the last game segment into the pool if the current game is finished + Save the game segment to the pool if the current game is finished, padding it if necessary. Arguments: - - last_game_segments (:obj:`list`): list of the last game segments - - last_game_priorities (:obj:`list`): list of the last game priorities - - game_segments (:obj:`list`): list of the current game segments + - i (:obj:`int`): Index of the current game segment. + - last_game_segments (:obj:`List[GameSegment]`): List of the last game segments to be padded and saved. + - last_game_priorities (:obj:`List[np.ndarray]`): List of priorities of the last game segments. + - game_segments (:obj:`List[GameSegment]`): List of the current game segments. + - done (:obj:`np.ndarray`): Array indicating whether each game is done. Note: (last_game_segments[i].obs_segment[-4:][j] == game_segments[i].obs_segment[:4][j]).all() is True """ @@ -237,7 +244,8 @@ def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_prioriti beg_index = self.policy_config.model.frame_stack_num end_index = beg_index + self.policy_config.num_unroll_steps - # the start obs is init zero obs, so we take the [ : +] obs as the pad obs + # the start obs is init zero obs, so we take the + # [ : +] obs as the pad obs # e.g. 
the start 4 obs is init zero obs, the num_unroll_steps is 5, so we take the [4:9] obs as the pad obs pad_obs_lst = game_segments[i].obs_segment[beg_index:end_index] pad_child_visits_lst = game_segments[i].child_visit_segment[:self.policy_config.num_unroll_steps] @@ -262,10 +270,12 @@ def pad_and_save_last_trajectory(self, i, last_game_segments, last_game_prioriti # pad over and save if self.policy_config.gumbel_algo: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, next_segment_improved_policy = pad_improved_policy_prob) + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, + next_segment_improved_policy=pad_improved_policy_prob) else: if self.policy_config.use_ture_chance_label_in_chance_encoder: - last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, next_chances = chance_lst) + last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst, + next_chances=chance_lst) else: last_game_segments[i].pad_over(pad_obs_lst, pad_reward_lst, pad_root_values_lst, pad_child_visits_lst) """ @@ -297,13 +307,13 @@ def collect(self, policy_kwargs: Optional[dict] = None) -> List[Any]: """ Overview: - Collect `n_episode` data with policy_kwargs, which is already trained `train_iter` iterations. + Collect `n_episode` episodes of data with policy_kwargs, trained for `train_iter` iterations. Arguments: - - n_episode (:obj:`int`): the number of collecting data episode. - - train_iter (:obj:`int`): the number of training iteration. - - policy_kwargs (:obj:`dict`): the keyword args for policy forward. + - n_episode (:obj:`Optional[int]`): Number of episodes to collect. + - train_iter (:obj:`int`): Number of training iterations completed so far. + - policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy. Returns: - - return_data (:obj:`List`): A list containing collected game_segments + - return_data (:obj:`List[Any]`): Collected data in the form of a list. """ if n_episode is None: if self._default_n_episode is None: @@ -325,7 +335,7 @@ def collect(self, retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to # len(self._env.ready_obs), especially in tictactoe env. self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states)) @@ -464,7 +474,7 @@ def collect(self, with self._timer: if timestep.info.get('abnormal', False): # If there is an abnormal timestep, reset all the related variables(including this env). 
- # suppose there is no reset param, just reset this env + # suppose there is no reset param, reset this env self._env.reset({env_id: None}) self._policy.reset([env_id]) self._reset_stat(env_id) @@ -477,11 +487,12 @@ def collect(self, distributions_dict[env_id], value_dict[env_id], root_sampled_actions_dict[env_id] ) elif self.policy_config.gumbel_algo: - game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], improved_policy = improved_policy_dict[env_id]) + game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id], + improved_policy=improved_policy_dict[env_id]) else: game_segments[env_id].store_search_stats(distributions_dict[env_id], value_dict[env_id]) # append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t} - # in ``game_segments[env_id].init``, we have append o_{t} in ``self.obs_segment`` + # in ``game_segments[env_id].init``, we have appended o_{t} in ``self.obs_segment`` if self.policy_config.use_ture_chance_label_in_chance_encoder: game_segments[env_id].append( actions[env_id], to_ndarray(obs['observation']), reward, action_mask_dict[env_id], @@ -689,10 +700,9 @@ def collect(self, def _output_log(self, train_iter: int) -> None: """ Overview: - Print the output log information. You can refer to Docs/Best Practice/How to understand\ - training generated folders/Serial mode/log/collector for more details. + Log the collector's data and output the log information. Arguments: - - train_iter (:obj:`int`): the number of training iteration. + - train_iter (:obj:`int`): Current training iteration number for logging context. """ if self._rank != 0: return diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index 313a07e07..3c11ab665 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -1,7 +1,7 @@ import copy import time from collections import namedtuple -from typing import Optional, Callable, Tuple +from typing import Optional, Callable, Tuple, Dict, Any import numpy as np import torch @@ -19,10 +19,10 @@ class MuZeroEvaluator(ISerialEvaluator): """ Overview: - The Evaluator for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. + The Evaluator class for MCTS+RL algorithms, such as MuZero, EfficientZero, and Sampled EfficientZero. Interfaces: __init__, reset, reset_policy, reset_env, close, should_eval, eval - Property: + Properties: env, policy """ @@ -30,10 +30,10 @@ class MuZeroEvaluator(ISerialEvaluator): def default_config(cls: type) -> EasyDict: """ Overview: - Get evaluator's default config. We merge evaluator's default config with other default configs\ - and user's config to get the final config. - Return: - cfg (:obj:`EasyDict`): evaluator's default config + Retrieve the default configuration for the evaluator by merging evaluator-specific defaults with other + defaults and any user-provided configuration. + Returns: + - cfg (:obj:`EasyDict`): The default configuration for the evaluator. """ cfg = EasyDict(copy.deepcopy(cls.config)) cfg.cfg_type = cls.__name__ + 'Dict' @@ -58,17 +58,17 @@ def __init__( ) -> None: """ Overview: - Init method. Load config and use ``self._cfg`` setting to build common serial evaluator components, - e.g. logger helper, timer. + Initialize the evaluator with configuration settings for various components such as logger helper and timer. Arguments: - - eval_freq (:obj:`int`): evaluation frequency in terms of training steps. 
- - n_evaluator_episode (:obj:`int`): the number of episodes to eval in total. - - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager) - - policy (:obj:`namedtuple`): the api namedtuple of collect_mode policy - - tb_logger (:obj:`SummaryWriter`): tensorboard handle - - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory. - - instance_name (:obj:`Optional[str]`): Name of this instance. - - policy_config: Config of game. + - eval_freq (:obj:`int`): Evaluation frequency in terms of training steps. + - n_evaluator_episode (:obj:`int`): Number of episodes to evaluate in total. + - stop_value (:obj:`float`): A reward threshold above which the training is considered converged. + - env (:obj:`Optional[BaseEnvManager]`): An optional instance of a subclass of BaseEnvManager. + - policy (:obj:`Optional[namedtuple]`): An optional API namedtuple defining the policy for evaluation. + - tb_logger (:obj:`Optional[SummaryWriter]`): Optional TensorBoard logger instance. + - exp_name (:obj:`str`): Name of the experiment, used to determine output directory. + - instance_name (:obj:`str`): Name of this evaluator instance. + - policy_config (:obj:`Optional[dict]`): Optional configuration for the game policy. """ self._eval_freq = eval_freq self._exp_name = exp_name @@ -103,14 +103,11 @@ def __init__( def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset evaluator's environment. In some case, we need evaluator use the same policy in different \ - environments. We can use reset_env to reset the environment. - If _env is None, reset the old environment. - If _env is not None, replace the old environment in the evaluator with the \ - new passed in environment and launch. + Reset the environment for the evaluator, optionally replacing it with a new environment. + If _env is None, reset the old environment. If _env is not None, replace the old environment + in the evaluator with the new passed in environment and launch. Arguments: - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one. """ if _env is not None: self._env = _env @@ -122,12 +119,11 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None: def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: """ Overview: - Reset evaluator's policy. In some case, we need evaluator work in this same environment but use\ - different policy. We can use reset_policy to reset the policy. + Reset the policy for the evaluator, optionally replacing it with a new policy. If _policy is None, reset the old policy. If _policy is not None, replace the old policy in the evaluator with the new passed in policy. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy + - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one. """ assert hasattr(self, '_env'), "please set env first" if _policy is not None: @@ -137,16 +133,15 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None: def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None: """ Overview: - Reset evaluator's policy and environment. Use new policy and environment to collect data. + Reset both the policy and environment for the evaluator, optionally replacing them. 
If _env is None, reset the old environment. If _env is not None, replace the old environment in the evaluator with the new passed in \ environment and launch. If _policy is None, reset the old policy. If _policy is not None, replace the old policy in the evaluator with the new passed in policy. Arguments: - - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy - - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \ - env_manager(BaseEnvManager) + - _policy (:obj:`Optional[namedtuple]`): An optional new policy namedtuple to replace the existing one. + - _env (:obj:`Optional[BaseEnvManager]`): An optional new environment instance to replace the existing one. """ if _env is not None: self.reset_env(_env) @@ -159,8 +154,7 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana def close(self) -> None: """ Overview: - Close the evaluator. If end_flag is False, close the environment, flush the tb_logger\ - and close the tb_logger. + Close the evaluator, the environment, flush and close the TensorBoard logger if applicable. """ if self._end_flag: return @@ -181,10 +175,11 @@ def __del__(self): def should_eval(self, train_iter: int) -> bool: """ Overview: - Determine whether you need to start the evaluation mode, if the number of training has reached\ - the maximum number of times to start the evaluator, return True + Determine whether to initiate evaluation based on the training iteration count and evaluation frequency. Arguments: - - train_iter (:obj:`int`): Current training iteration. + - train_iter (:obj:`int`): The current count of training iterations. + Returns: + - (:obj:`bool`): `True` if evaluation should be initiated, otherwise `False`. """ if train_iter == self._last_eval_iter: return False @@ -199,20 +194,20 @@ def eval( train_iter: int = -1, envstep: int = -1, n_episode: Optional[int] = None, - ) -> Tuple[bool, float]: + ) -> Tuple[bool, Dict[str, Any]]: """ Overview: - Evaluate policy and store the best policy based on whether it reaches the highest historical reward. + Evaluate the current policy, storing the best policy if it achieves the highest historical reward. Arguments: - - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward. - - train_iter (:obj:`int`): Current training iteration. - - envstep (:obj:`int`): Current env interaction step. - - n_episode (:obj:`int`): Number of evaluation episodes. + - save_ckpt_fn (:obj:`Optional[Callable]`): Optional function to save a checkpoint when a new best reward is achieved. + - train_iter (:obj:`int`): The current training iteration count. + - envstep (:obj:`int`): The current environment step count. + - n_episode (:obj:`Optional[int]`): Optional number of evaluation episodes; defaults to the evaluator's setting. Returns: - - stop_flag (:obj:`bool`): Whether this training program can be ended. - - episode_info (:obj:`Dict[str, List]`): Current evaluation episode information. + - stop_flag (:obj:`bool`): Indicates whether the training can be stopped based on the stop value. + - episode_info (:obj:`Dict[str, Any]`): A dictionary containing information about the evaluation episodes. 
""" - # evaluator only work on rank0 + # the evaluator only works on rank0 episode_info = None stop_flag = False if get_rank() == 0: @@ -231,7 +226,7 @@ def eval( retry_waiting_time = 0.001 while len(init_obs.keys()) != self._env_num: - # In order to be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to + # To be compatible with subprocess env_manager, in which sometimes self._env_num is not equal to # len(self._env.ready_obs), especially in tictactoe env. self._logger.info('The current init_obs.keys() is {}'.format(init_obs.keys())) self._logger.info('Before sleeping, the _env_states is {}'.format(self._env._env_states))