polish(pu): polish comments in worker files
puyuan1996 committed Feb 8, 2024
1 parent 6ffcc4d commit 541f5bf
Showing 4 changed files with 184 additions and 186 deletions.
92 changes: 44 additions & 48 deletions lzero/worker/alphazero_collector.py
@@ -1,11 +1,11 @@
from collections import namedtuple
from typing import Optional, Any, List, Dict
from typing import Optional, Any, List

import numpy as np
from ding.envs import BaseEnvManager
from ding.torch_utils import to_ndarray
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \
broadcast_object_list, allreduce_data
from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, get_rank, get_world_size, \
allreduce_data
from ding.worker.collector.base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, \
to_tensor_transitions

@@ -14,9 +14,10 @@
class AlphaZeroCollector(ISerialCollector):
"""
Overview:
AlphaZero collector (n_episode).
AlphaZero collector for collecting episodes of experience during self-play or playing against an opponent.
This collector is specifically designed for the AlphaZero algorithm.
Interfaces:
__init__, reset, reset_env, reset_policy, collect, close
``__init__``, ``reset``, ``reset_env``, ``reset_policy``, ``collect``, ``close``
Property:
envstep
"""
@@ -35,18 +36,17 @@ def __init__(
env_config=None,
) -> None:
"""
Overview:
Init the AlphaZero collector according to input arguments.
Arguments:
- collect_print_freq (:obj:`int`): collect_print_frequency in terms of training_steps.
- env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
its derivatives are supported.
- policy (:obj:`Policy`): The policy to be collected.
- tb_logger (:obj:`SummaryWriter`): Logger, defaultly set as 'SummaryWriter' for model summary.
- instance_name (:obj:`Optional[str]`): Name of this instance.
- exp_name (:obj:`str`): Experiment name, which is used to indicate output directory.
- env_config: Config of environment
"""
Overview:
Initialize the AlphaZero collector with the provided environment, policy, and configurations.
Arguments:
- collect_print_freq (:obj:`int`): Frequency of printing collection statistics (in training steps).
- env (:obj:`Optional[BaseEnvManager]`): Environment manager for managing multiple environments.
- policy (:obj:`Optional[namedtuple]`): Policy used for making decisions during collection.
- tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger for logging statistics.
- exp_name (:obj:`str`): Name of the experiment for logging purposes.
- instance_name (:obj:`str`): Unique identifier for this collector instance.
- env_config (:obj:`Optional[dict]`): Configuration for the environment.
"""
self._exp_name = exp_name
self._instance_name = instance_name
self._collect_print_freq = collect_print_freq
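For orientation, a minimal construction sketch of this collector follows; `collector_env`, `policy`, and `cfg` are placeholders, not names taken from this commit:

    from tensorboardX import SummaryWriter

    tb_logger = SummaryWriter('./my_experiment/log/')
    collector = AlphaZeroCollector(
        collect_print_freq=100,
        env=collector_env,            # a launched BaseEnvManager (or subclass)
        policy=policy.collect_mode,   # collect-mode namedtuple of an AlphaZero policy
        tb_logger=tb_logger,
        exp_name='my_experiment',
        instance_name='collector',
        env_config=cfg.env,           # plain config of the environment
    )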
@@ -79,13 +79,12 @@ def __init__(
def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset the environment.
Reset or replace the environment in the collector.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the collector with the new passed \
in environment and launch.
Arguments:
- env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
env_manager(BaseEnvManager)
- _env (:obj:`Optional[BaseEnvManager]`): New environment to replace the existing one, if provided.
"""
if _env is not None:
self._env = _env
@@ -97,11 +96,11 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
"""
Overview:
Reset the policy.
Reset or replace the policy in the collector.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the collector with the new passed in policy.
Arguments:
- policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
- _policy (:obj:`Optional[namedtuple]`): New policy to replace the existing one, if provided.
"""
assert hasattr(self, '_env'), "please set env first"
if _policy is not None:
@@ -119,16 +118,15 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset the environment and policy.
Reset the environment and policy within the collector.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the collector with the new passed \
in environment and launch.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the collector with the new passed in policy.
Arguments:
- policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
- env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
env_manager(BaseEnvManager)
- _policy (:obj:`Optional[namedtuple]`): New policy to replace the existing one, if provided.
- _env (:obj:`Optional[BaseEnvManager]`): New environment to replace the existing one, if provided.
"""
if _env is not None:
self.reset_env(_env)
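The reset contract documented above, as a brief usage sketch (`new_policy` and `new_env` are hypothetical):

    collector.reset()                     # re-initialize the current env and policy
    collector.reset(_policy=new_policy)   # swap in a new collect-mode policy only
    collector.reset(_policy=new_policy, _env=new_env)   # replace both; the new env is relaunched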
@@ -151,9 +149,10 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana
def _reset_stat(self, env_id: int) -> None:
"""
Overview:
Reset the collector's state. Including reset the traj_buffer, obs_pool, policy_output_pool\
and env_info. Reset these states according to env_id. You can refer to base_serial_collector\
to get more messages.
Reset the statistics for a specific environment.
This includes resetting the traj_buffer, obs_pool, policy_output_pool and env_info
for the given env_id.
Refer to base_serial_collector for more details.
Arguments:
- env_id (:obj:`int`): the id where we need to reset the collector's state
"""
@@ -165,8 +164,8 @@ def _reset_stat(self, env_id: int) -> None:
def close(self) -> None:
"""
Overview:
Close the collector. If end_flag is False, close the environment, flush the tb_logger\
and close the tb_logger.
Close the collector. If end_flag is False, close the environment, flush the tb_logger
and close the tb_logger.
"""
if self._end_flag:
return
Expand All @@ -182,13 +181,13 @@ def collect(self,
policy_kwargs: Optional[dict] = None) -> List[Any]:
"""
Overview:
Collect `n_episode` data with policy_kwargs, which is already trained `train_iter` iterations
Collect experience data for a specified number of episodes using the current policy.
Arguments:
- n_episode (:obj:`int`): the number of collecting data episode
- train_iter (:obj:`int`): the number of training iteration
- policy_kwargs (:obj:`dict`): the keyword args for policy forward
- n_episode (:obj:`Optional[int]`): Number of episodes to collect. Defaults to a pre-set value if None.
- train_iter (:obj:`int`): Current training iteration.
- policy_kwargs (:obj:`Optional[dict]`): Additional keyword arguments for the policy.
Returns:
- return_data (:obj:`List`): A list containing collected episodes.
- return_data (:obj:`List[Any]`): A list of collected experience episodes.
"""
if n_episode is None:
if self._default_n_episode is None:
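A hedged usage sketch of collect; `learner.train_iter` is a placeholder, and the `temperature` key is an assumption (the accepted keys depend on the concrete policy):

    new_data = collector.collect(
        n_episode=8,                         # falls back to the configured default when None
        train_iter=learner.train_iter,       # current training iteration, used for logging
        policy_kwargs={'temperature': 1.0},  # forwarded to the policy's forward pass
    )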
@@ -295,17 +294,16 @@ def collect(self,
def envstep(self) -> int:
"""
Overview:
Print the total envstep count.
Return:
- envstep (:obj:`int`): the total envstep count
Get the total number of environment steps taken by the collector.
Returns:
- envstep (:obj:`int`): Total count of environment steps.
"""
return self._total_envstep_count

def close(self) -> None:
"""
Overview:
Close the collector. If end_flag is False, close the environment, flush the tb_logger\
and close the tb_logger.
Close the collector and clean up resources such as environment and logger.
"""
if self._end_flag:
return
@@ -318,18 +316,16 @@ def close(self) -> None:
def __del__(self) -> None:
"""
Overview:
Execute the close command and close the collector. __del__ is automatically called to \
destroy the collector instance when the collector finishes its work
Destructor method that is called when the collector object is being destroyed.
"""
self.close()

def _output_log(self, train_iter: int) -> None:
"""
Overview:
Print the output log information. You can refer to Docs/Best Practice/How to understand\
training generated folders/Serial mode/log/collector for more details.
Output logging information for the current collection phase.
Arguments:
- train_iter (:obj:`int`): the number of training iteration.
- train_iter (:obj:`int`): Current training iteration for logging purposes.
"""
if self._rank != 0:
return
@@ -368,9 +364,9 @@ def _output_log(self, train_iter: int) -> None:
def reward_shaping(self, transitions, eval_episode_return):
"""
Overview:
Shape the reward according to the player.
Shape the rewards in the collected transitions based on the outcome of the episode.
Return:
- transitions: data transitions.
- transitions (:obj:`List[dict]`): List of data transitions.
"""
reward = transitions[-1]['reward']
to_play = transitions[-1]['obs']['to_play']
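The body of reward_shaping is truncated here; a minimal sketch of the usual two-player convention follows, assuming each transition stores obs['to_play'] and the last transition carries the episode outcome (not necessarily the exact code behind this diff):

    def reward_shaping(transitions, eval_episode_return=None):
        # credit the final outcome to each player from their own perspective
        reward = transitions[-1]['reward']
        to_play = transitions[-1]['obs']['to_play']
        for t in transitions:
            t['reward'] = reward if t['obs']['to_play'] == to_play else -reward
        return transitions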
77 changes: 37 additions & 40 deletions lzero/worker/alphazero_evaluator.py
@@ -1,11 +1,11 @@
from collections import namedtuple
from typing import Optional, Callable, Tuple
import torch

import numpy as np
import torch
from ding.envs import BaseEnv
from ding.envs import BaseEnvManager
from ding.torch_utils import to_tensor, to_item

from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY
from ding.utils import get_world_size, get_rank, broadcast_object_list
from ding.worker.collector.base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor
@@ -15,9 +15,9 @@
class AlphaZeroEvaluator(ISerialEvaluator):
"""
Overview:
AlphaZero Evaluator.
AlphaZero Evaluator class which handles the evaluation of the trained policy.
Interfaces:
__init__, reset, reset_policy, reset_env, close, should_eval, eval
``__init__``, ``reset``, ``reset_policy``, ``reset_env``, ``close``, ``should_eval``, ``eval``
Property:
env, policy
"""
@@ -36,17 +36,17 @@ def __init__(
) -> None:
"""
Overview:
Init the AlphaZero evaluator according to input arguments.
Initialize the AlphaZero evaluator with the given parameters.
Arguments:
- eval_freq (:obj:`int`): evaluation frequency in terms of training steps.
- n_evaluator_episode (:obj:`int`): the number of episodes to eval in total.
- env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
its derivatives are supported.
- policy (:obj:`Policy`): The policy to be collected.
- tb_logger (:obj:`SummaryWriter`): Logger, defaultly set as 'SummaryWriter' for model summary.
- exp_name (:obj:`str`): Experiment name, which is used to indicate output directory.
- instance_name (:obj:`Optional[str]`): Name of this instance.
- env_config: Config of environment
- eval_freq (:obj:`int`): Evaluation frequency in terms of training steps.
- n_evaluator_episode (:obj:`int`): Number of episodes for each evaluation.
- stop_value (:obj:`float`): Reward threshold to stop training if surpassed.
- env (:obj:`Optional[BaseEnvManager]`): Environment manager for managing multiple environments.
- policy (:obj:`Optional[namedtuple]`): Policy to be evaluated.
- tb_logger (:obj:`Optional[SummaryWriter]`): TensorBoard logger for logging statistics.
- exp_name (:obj:`str`): Name of the experiment for logging purposes.
- instance_name (:obj:`str`): Unique identifier for this evaluator instance.
- env_config (:obj:`Optional[dict]`): Configuration for the environment.
"""
self._eval_freq = eval_freq
self._exp_name = exp_name
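A parallel construction sketch for the evaluator; `evaluator_env`, `policy.eval_mode`, and `cfg` are placeholders:

    evaluator = AlphaZeroEvaluator(
        eval_freq=1000,                 # evaluate every 1000 training iterations
        n_evaluator_episode=5,
        stop_value=cfg.env.stop_value,  # e.g. 1 for a solved board game
        env=evaluator_env,              # a launched BaseEnvManager
        policy=policy.eval_mode,        # eval-mode namedtuple of the policy
        tb_logger=tb_logger,
        exp_name='my_experiment',
        instance_name='evaluator',
    )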
@@ -78,14 +78,12 @@ def __init__(
def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset evaluator's environment. In some case, we need evaluator use the same policy in different \
environments. We can use reset_env to reset the environment.
Reset or replace the environment in the evaluator.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the evaluator with the \
new passed in environment and launch.
Arguments:
- env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
env_manager(BaseEnvManager)
- _env (:obj:`Optional[BaseEnvManager]`): New environment to replace the existing one, if provided.
"""
if _env is not None:
self._env = _env
@@ -97,8 +95,7 @@ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
"""
Overview:
Reset evaluator's policy. In some case, we need evaluator work in this same environment but use\
different policy. We can use reset_policy to reset the policy.
Reset or replace the policy in the evaluator.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
Arguments:
@@ -112,16 +109,15 @@ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
"""
Overview:
Reset evaluator's policy and environment. Use new policy and environment to collect data.
Reset the environment and policy within the evaluator.
If _env is None, reset the old environment.
If _env is not None, replace the old environment in the evaluator with the new passed in \
environment and launch.
If _policy is None, reset the old policy.
If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
Arguments:
- policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
- env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
env_manager(BaseEnvManager)
- _policy (:obj:`Optional[namedtuple]`): New policy to replace the existing one, if provided.
- _env (:obj:`Optional[BaseEnvManager]`): New environment to replace the existing one, if provided.
"""
if _env is not None:
self.reset_env(_env)
@@ -134,8 +130,8 @@ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvMana
def close(self) -> None:
"""
Overview:
Close the evaluator. If end_flag is False, close the environment, flush the tb_logger\
and close the tb_logger.
Close the evaluator and clean up resources such as environment and logger.
If end_flag is False, close the environment, flush the tb_logger and close the tb_logger.
"""
if self._end_flag:
return
@@ -148,18 +144,18 @@ def close(self) -> None:
def __del__(self) -> None:
"""
Overview:
Execute the close command and close the evaluator. __del__ is automatically called \
to destroy the evaluator instance when the evaluator finishes its work
Destructor method that is called when the evaluator object is being destroyed.
__del__ is automatically called to destroy the evaluator instance when the evaluator finishes its work.
"""
self.close()

def should_eval(self, train_iter: int) -> bool:
"""
Overview:
Determine whether you need to start the evaluation mode, if the number of training has reached\
the maximum number of times to start the evaluator, return True
Arguments:
- train_iter (:obj:`int`): Current training iteration.
Check if it is time to evaluate the policy based on the training iteration count.
Return True once enough training iterations have elapsed since the last evaluation.
Returns:
- (:obj:`bool`): Flag indicating whether evaluation should be performed.
"""
if train_iter == self._last_eval_iter:
return False
@@ -178,17 +174,18 @@ def eval(
) -> Tuple[bool, dict]:
"""
Overview:
Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
Execute the evaluation of the policy and determine if the stopping condition has been met.
Arguments:
- save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward.
- train_iter (:obj:`int`): Current training iteration.
- envstep (:obj:`int`): Current env interaction step.
- n_episode (:obj:`int`): Number of evaluation episodes.
- save_ckpt_fn (:obj:`Optional[Callable]`): Callback function to save a checkpoint.
- train_iter (:obj:`int`): Current number of training iterations completed.
- envstep (:obj:`int`): Current number of environment steps completed.
- n_episode (:obj:`Optional[int]`): Number of episodes to evaluate. Defaults to preset if None.
- force_render (:obj:`bool`): Force rendering of the environment, if applicable.
Returns:
- stop_flag (:obj:`bool`): Whether this training program can be ended.
- return_info (:obj:`dict`): Current evaluation return information.
- stop_flag (:obj:`bool`): Whether the training process should stop based on evaluation results.
- return_info (:obj:`dict`): Information about the evaluation results.
"""
# evaluator only work on rank0
# the evaluator only works on rank0
stop_flag, return_info = False, []
if get_rank() == 0:
if n_episode is None:
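Putting the two workers together, the documented interfaces suggest a training loop roughly like the following sketch (`learner` and `save_checkpoint` are placeholders):

    while True:
        if evaluator.should_eval(learner.train_iter):
            stop_flag, return_info = evaluator.eval(
                save_ckpt_fn=save_checkpoint,   # triggered on a new best reward
                train_iter=learner.train_iter,
                envstep=collector.envstep,
            )
            if stop_flag:                       # stop_value surpassed; end training
                break
        new_data = collector.collect(train_iter=learner.train_iter)
        learner.train(new_data, collector.envstep)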