diff --git a/docs/source/tutorials/envs/customize_envs.md b/docs/source/tutorials/envs/customize_envs.md index 5d914307a..bd697b391 100644 --- a/docs/source/tutorials/envs/customize_envs.md +++ b/docs/source/tutorials/envs/customize_envs.md @@ -203,12 +203,12 @@ class LightZeroEnvWrapper(gym.Wrapper): Specifically, use the following function to wrap a gym environment into the format required by LightZero using `LightZeroEnvWrapper`. The `get_wrappered_env` function returns an anonymous function that generates a `DingEnvWrapper` instance each time it is called. This instance takes `LightZeroEnvWrapper` as an anonymous function and internally wraps the original environment into the format required by LightZero. ```python -def get_wrappered_env(wrapper_cfg: EasyDict, env_name: str): +def get_wrappered_env(wrapper_cfg: EasyDict, env_id: str): # overview comments ... if wrapper_cfg.manually_discretization: return lambda: DingEnvWrapper( - gym.make(env_name), + gym.make(env_id), cfg={ 'env_wrapper': [ lambda env: ActionDiscretizationEnvWrapper(env, wrapper_cfg), lambda env: @@ -218,7 +218,7 @@ def get_wrappered_env(wrapper_cfg: EasyDict, env_name: str): ) else: return lambda: DingEnvWrapper( - gym.make(env_name), cfg={'env_wrapper': [lambda env: LightZeroEnvWrapper(env, wrapper_cfg)]} + gym.make(env_id), cfg={'env_wrapper': [lambda env: LightZeroEnvWrapper(env, wrapper_cfg)]} ) ``` diff --git a/docs/source/tutorials/envs/customize_envs_zh.md b/docs/source/tutorials/envs/customize_envs_zh.md index c995f9fe7..89945dc43 100644 --- a/docs/source/tutorials/envs/customize_envs_zh.md +++ b/docs/source/tutorials/envs/customize_envs_zh.md @@ -216,12 +216,12 @@ class LightZeroEnvWrapper(gym.Wrapper): 具体使用时,使用下面的函数,将一个 gym 环境,通过 `LightZeroEnvWrapper` 包装成 LightZero 所需要的环境格式。 `get_wrappered_env` 会返回一个匿名函数,该匿名函数每次调用都会产生一个 `DingEnvWrapper` 实例,该实例会将 `LightZeroEnvWrapper` 作为匿名函数传入,并在实例内部将原始环境封装成 LightZero 所需的格式。 ```Python -def get_wrappered_env(wrapper_cfg: EasyDict, env_name: str): +def get_wrappered_env(wrapper_cfg: EasyDict, env_id: str): # overview comments ... 
if wrapper_cfg.manually_discretization: return lambda: DingEnvWrapper( - gym.make(env_name), + gym.make(env_id), cfg={ 'env_wrapper': [ lambda env: ActionDiscretizationEnvWrapper(env, wrapper_cfg), lambda env: @@ -231,7 +231,7 @@ def get_wrappered_env(wrapper_cfg: EasyDict, env_name: str): ) else: return lambda: DingEnvWrapper( - gym.make(env_name), cfg={'env_wrapper': [lambda env: LightZeroEnvWrapper(env, wrapper_cfg)]} + gym.make(env_id), cfg={'env_wrapper': [lambda env: LightZeroEnvWrapper(env, wrapper_cfg)]} ) ``` diff --git a/lzero/agent/__init__.py b/lzero/agent/__init__.py index a6cbb38df..5e22f7294 100644 --- a/lzero/agent/__init__.py +++ b/lzero/agent/__init__.py @@ -1 +1,6 @@ +from .alphazero import AlphaZeroAgent +from .efficientzero import EfficientZeroAgent +from .gumbel_muzero import GumbelMuZeroAgent from .muzero import MuZeroAgent +from .sampled_alphazero import SampledAlphaZeroAgent +from .sampled_efficientzero import SampledEfficientZeroAgent diff --git a/lzero/agent/alphazero.py b/lzero/agent/alphazero.py new file mode 100644 index 000000000..d31e17bb9 --- /dev/null +++ b/lzero/agent/alphazero.py @@ -0,0 +1,381 @@ +import os +from functools import partial +from typing import Optional, Union, List +import copy +import numpy as np +import torch +from ding.bonus.common import TrainingReturn, EvalReturn +from ding.config import save_config_py, compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner, create_buffer +from ditk import logging +from easydict import EasyDict +from tensorboardX import SummaryWriter + +from lzero.agent.config.alphazero import supported_env_cfg +from lzero.policy import visit_count_temperature +from lzero.policy.alphazero import AlphaZeroPolicy +from lzero.worker import AlphaZeroCollector as Collector +from lzero.worker import AlphaZeroEvaluator as Evaluator + + +class AlphaZeroAgent: + """ + Overview: + Agent class for executing AlphaZero algorithms which include methods for training, deployment, and batch evaluation. + Interfaces: + ``__init__``, ``train``, ``deploy``, ``batch_evaluate`` + Properties: + ``best`` + + .. note:: + This agent class is tailored for use with the HuggingFace Model Zoo for LightZero + (e.g. https://huggingface.co/OpenDILabCommunity/CartPole-v0-AlphaZero), + and provides methods such as "train" and "deploy". + """ + + supported_env_list = list(supported_env_cfg.keys()) + + def __init__( + self, + env_id: str = None, + seed: int = 0, + exp_name: str = None, + model: Optional[torch.nn.Module] = None, + cfg: Optional[Union[EasyDict, dict]] = None, + policy_state_dict: str = None, + ) -> None: + """ + Overview: + Initialize the AlphaZeroAgent instance with environment parameters, model, and configuration. + Arguments: + - env_id (:obj:`str`): Identifier for the environment to be used, registered in gym. + - seed (:obj:`int`): Random seed for reproducibility. Defaults to 0. + - exp_name (:obj:`Optional[str]`): Name for the experiment. Defaults to None. + - model (:obj:`Optional[torch.nn.Module]`): PyTorch module to be used as the model. If None, a default model is created. Defaults to None. + - cfg (:obj:`Optional[Union[EasyDict, dict]]`): Configuration for the agent. If None, default configuration will be used. Defaults to None. + - policy_state_dict (:obj:`Optional[str]`): Path to a pre-trained model state dictionary. If provided, state dict will be loaded. 
Defaults to None. + + .. note:: + - If `env_id` is not specified, it must be included in `cfg`. + - The `supported_env_list` contains all the environment IDs that are supported by this agent. + """ + assert env_id is not None or cfg["main_config"]["env_id"] is not None, "Please specify env_id or cfg." + + if cfg is not None and not isinstance(cfg, EasyDict): + cfg = EasyDict(cfg) + + if env_id is not None: + assert env_id in AlphaZeroAgent.supported_env_list, "Please use supported envs: {}".format( + AlphaZeroAgent.supported_env_list + ) + if cfg is None: + cfg = supported_env_cfg[env_id] + else: + assert cfg.main_config.env.env_id == env_id, "env_id in cfg should be the same as env_id in args." + else: + assert hasattr(cfg.main_config.env, "env_id"), "Please specify env_id in cfg." + assert cfg.main_config.env.env_id in AlphaZeroAgent.supported_env_list, "Please use supported envs: {}".format( + AlphaZeroAgent.supported_env_list + ) + default_policy_config = EasyDict({"policy": AlphaZeroPolicy.default_config()}) + default_policy_config.policy.update(cfg.main_config.policy) + cfg.main_config.policy = default_policy_config.policy + + if exp_name is not None: + cfg.main_config.exp_name = exp_name + self.origin_cfg = cfg + self.cfg = compile_config( + cfg.main_config, seed=seed, env=None, auto=True, policy=AlphaZeroPolicy, create_cfg=cfg.create_config + ) + self.exp_name = self.cfg.exp_name + + logging.getLogger().setLevel(logging.INFO) + self.seed = seed + set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda) + if not os.path.exists(self.exp_name): + os.makedirs(self.exp_name) + save_config_py(cfg, os.path.join(self.exp_name, 'policy_config.py')) + if model is None: + from lzero.model.alphazero_model import AlphaZeroModel + model = AlphaZeroModel(**self.cfg.policy.model) + + if self.cfg.policy.cuda and torch.cuda.is_available(): + self.cfg.policy.device = 'cuda' + else: + self.cfg.policy.device = 'cpu' + self.policy = create_policy(self.cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + if policy_state_dict is not None: + self.policy.learn_mode.load_state_dict(policy_state_dict) + self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt") + + self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env) + + def train( + self, + step: int = int(1e7), + ) -> TrainingReturn: + """ + Overview: + Train the agent through interactions with the environment. + Arguments: + - step (:obj:`int`): Total number of environment steps to train for. Defaults to 10 million (1e7). + Returns: + - A `TrainingReturn` object containing training information, such as logs and potentially a URL to a training dashboard. + .. note:: + The method involves interacting with the environment, collecting experience, and optimizing the model. + """ + + collector_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.collector_env_cfg] + ) + evaluator_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] + ) + + collector_env.seed(self.cfg.seed) + evaluator_env.seed(self.cfg.seed, dynamic_seed=False) + set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. 
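+        # The workers created below drive a standard LightZero collect-train cycle:
+        # periodically evaluate the policy (stopping early once the configured stop_value
+        # is reached), collect new episodes with a visit-count-based temperature, push the
+        # transitions into the replay buffer, and then run ``update_per_collect`` gradient
+        # steps of ``batch_size`` samples each, until ``collector.envstep`` exceeds ``step``.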
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(self.cfg.exp_name), 'serial') + ) if get_rank() == 0 else None + learner = BaseLearner( + self.cfg.policy.learn.learner, self.policy.learn_mode, tb_logger, exp_name=self.cfg.exp_name + ) + replay_buffer = create_buffer(self.cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=self.cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = self.cfg.policy + batch_size = policy_config.batch_size + collector = Collector( + env=collector_env, + policy=self.policy.collect_mode, + tb_logger=tb_logger, + exp_name=self.cfg.exp_name, + ) + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=self.cfg.env.n_evaluator_episode, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + tb_logger=tb_logger, + exp_name=self.cfg.exp_name, + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. + learner.call_hook('before_run') + + if self.cfg.policy.update_per_collect is not None: + update_per_collect = self.cfg.policy.update_per_collect + + while True: + collect_kwargs = {} + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + # Evaluate policy performance. + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + new_data = sum(new_data, []) + + if self.cfg.policy.update_per_collect is None: + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + collected_transitions_num = len(new_data) + update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio) + replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) + + # Learn policy from collected data + for i in range(update_per_collect): + # Learner will train ``update_per_collect`` times in one iteration. + train_data = replay_buffer.sample(batch_size, learner.train_iter) + if train_data is None: + logging.warning( + 'The data in replay_buffer is not sufficient to sample a mini-batch.' + 'continue to collect now ....' + ) + break + + learner.train(train_data, collector.envstep) + + if collector.envstep >= step: + break + + # Learner's after_run hook. + learner.call_hook('after_run') + + return TrainingReturn(wandb_url=None) + + def deploy( + self, + enable_save_replay: bool = False, + concatenate_all_replay: bool = False, + replay_save_path: str = None, + seed: Optional[Union[int, List]] = None, + debug: bool = False + ) -> EvalReturn: + """ + Overview: + Deploy the agent for evaluation in the environment, with optional replay saving. The performance of the + agent will be evaluated. Average return and standard deviation of the return will be returned. + If `enable_save_replay` is True, replay videos are saved in the specified `replay_save_path`. 
+ Arguments: + - enable_save_replay (:obj:`bool`): Flag to enable saving of replay footage. Defaults to False. + - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one file. Defaults to False. + - replay_save_path (:obj:`Optional[str]`): Directory path to save replay videos. Defaults to None, which sets a default path. + - seed (:obj:`Optional[Union[int, List[int]]]`): Seed or list of seeds for environment reproducibility. Defaults to None. + - debug (:obj:`bool`): Whether to enable the debug mode. Default to False. + Returns: + - An `EvalReturn` object containing evaluation metrics such as mean and standard deviation of returns. + """ + + deply_configs = [copy.deepcopy(self.evaluator_env_cfg[0])] + + if type(seed) == int: + seed_list = [seed] + elif seed: + seed_list = seed + else: + seed_list = [0] + + reward_list = [] + + if enable_save_replay: + replay_save_path = replay_save_path if replay_save_path is not None else os.path.join( + self.exp_name, 'videos' + ) + deply_configs[0]['replay_path'] = replay_save_path + deply_configs[0]['save_replay'] = True + + for seed in seed_list: + + evaluator_env = create_env_manager(self.cfg.env.manager, [partial(self.env_fn, cfg=deply_configs[0])]) + + evaluator_env.seed(seed if seed is not None else self.cfg.seed, dynamic_seed=False) + set_pkg_seed(seed if seed is not None else self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=1, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + exp_name=self.cfg.exp_name, + ) + + # ============================================================== + # Main loop + # ============================================================== + + stop, reward = evaluator.eval() + reward_list.extend(reward['eval_episode_return']) + + if enable_save_replay: + if not os.path.exists(replay_save_path): + os.makedirs(replay_save_path) + files = os.listdir(replay_save_path) + files = [file for file in files if file.endswith('.mp4')] + files.sort() + if concatenate_all_replay: + # create a file named 'files.txt' to store the names of all mp4 files + with open(os.path.join(replay_save_path, 'files.txt'), 'w') as f: + for file in files: + f.write("file '{}'\n".format(file)) + + # combine all the mp4 files into one mp4 file, rename it as 'deploy.mp4' + os.system( + 'ffmpeg -f concat -safe 0 -i {} -c copy {}/deploy.mp4'.format( + os.path.join(replay_save_path, 'files.txt'), replay_save_path + ) + ) + + return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list)) + + def batch_evaluate( + self, + n_evaluator_episode: int = None, + ) -> EvalReturn: + """ + Overview: + Perform a batch evaluation of the agent over a specified number of episodes: ``n_evaluator_episode``. + Arguments: + - n_evaluator_episode (:obj:`Optional[int]`): Number of episodes to run the evaluation. + If None, uses default value from configuration. Defaults to None. + Returns: + - An `EvalReturn` object with evaluation results such as mean and standard deviation of returns. + + .. note:: + This method evaluates the agent's performance across multiple episodes to gauge its effectiveness. 
+ """ + evaluator_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] + ) + + evaluator_env.seed(self.cfg.seed, dynamic_seed=False) + set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=self.cfg.env.n_evaluator_episode + if n_evaluator_episode is None else n_evaluator_episode, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + exp_name=self.cfg.exp_name, + ) + + # ============================================================== + # Main loop + # ============================================================== + + stop, reward = evaluator.eval() + + return EvalReturn( + eval_value=np.mean(reward['eval_episode_return']), eval_value_std=np.std(reward['eval_episode_return']) + ) + + @property + def best(self): + """ + Overview: + Provides access to the best model according to evaluation metrics. + Returns: + - The agent with the best model loaded. + + .. note:: + The best model is saved in the path `./exp_name/ckpt/ckpt_best.pth.tar`. + When this property is accessed, the agent instance will load the best model state. + """ + + best_model_file_path = os.path.join(self.checkpoint_save_dir, "ckpt_best.pth.tar") + # Load best model if it exists + if os.path.exists(best_model_file_path): + policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu")) + self.policy.learn_mode.load_state_dict(policy_state_dict) + return self diff --git a/lzero/agent/config/alphazero/__init__.py b/lzero/agent/config/alphazero/__init__.py new file mode 100644 index 000000000..f18651a6f --- /dev/null +++ b/lzero/agent/config/alphazero/__init__.py @@ -0,0 +1,10 @@ +from easydict import EasyDict +from . import gomoku_play_with_bot +from . 
import tictactoe_play_with_bot + +supported_env_cfg = { + gomoku_play_with_bot.cfg.main_config.env.env_id: gomoku_play_with_bot.cfg, + tictactoe_play_with_bot.cfg.main_config.env.env_id: tictactoe_play_with_bot.cfg, +} + +supported_env_cfg = EasyDict(supported_env_cfg) diff --git a/lzero/agent/config/alphazero/gomoku_play_with_bot.py b/lzero/agent/config/alphazero/gomoku_play_with_bot.py new file mode 100644 index 000000000..909d4dc19 --- /dev/null +++ b/lzero/agent/config/alphazero/gomoku_play_with_bot.py @@ -0,0 +1,104 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +board_size = 6 # default_size is 15 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 50 +batch_size = 256 +max_env_step = int(5e5) +prob_random_action_in_bot = 0.5 +mcts_ctree = False +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='Gomoku-play-with-bot-AlphaZero', + seed=0, + env=dict( + env_id='Gomoku-play-with-bot', + battle_mode='play_with_bot_mode', + replay_format='mp4', + board_size=board_size, + bot_action_type='v1', + prob_random_action_in_bot=prob_random_action_in_bot, + channel_last=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # ============================================================== + # for the creation of simulation env + agent_vs_human=False, + prob_random_agent=0, + prob_expert_agent=0, + scale=True, + screen_scaling=9, + render_mode='image_savefile_mode', + replay_path=None, + alphazero_mcts_ctree=mcts_ctree, + # ============================================================== + ), + policy=dict( + mcts_ctree=mcts_ctree, + # ============================================================== + # for the creation of simulation env + simulation_env_id='gomoku', + simulation_env_config_type='play_with_bot', + # ============================================================== + torch_compile=False, + tensor_float_32=False, + model=dict( + observation_shape=(3, board_size, board_size), + action_space_size=int(1 * board_size * board_size), + num_res_blocks=1, + num_channels=32, + ), + cuda=True, + board_size=board_size, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + value_weight=1.0, + entropy_weight=0.0, + n_episode=n_episode, + eval_freq=int(2e3), + mcts=dict(num_simulations=num_simulations), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='gomoku', + import_names=['zoo.board_games.gomoku.envs.gomoku_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='alphazero', + import_names=['lzero.policy.alphazero'], + ), + collector=dict( + type='episode_alphazero', + import_names=['lzero.worker.alphazero_collector'], + ), + evaluator=dict( + type='alphazero', + import_names=['lzero.worker.alphazero_evaluator'], + ), + 
) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/alphazero/tictactoe_play_with_bot.py b/lzero/agent/config/alphazero/tictactoe_play_with_bot.py new file mode 100644 index 000000000..845dc0ee7 --- /dev/null +++ b/lzero/agent/config/alphazero/tictactoe_play_with_bot.py @@ -0,0 +1,100 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 5 +num_simulations = 25 +update_per_collect = 50 +batch_size = 256 +max_env_step = int(2e5) +mcts_ctree = False +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='TicTacToe-play-with-bot-AlphaZero', + seed=0, + env=dict( + env_id='TicTacToe-play-with-bot', + board_size=3, + battle_mode='play_with_bot_mode', + bot_action_type='v0', # {'v0', 'alpha_beta_pruning'} + channel_last=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # ============================================================== + # for the creation of simulation env + agent_vs_human=False, + prob_random_agent=0, + prob_expert_agent=0, + scale=True, + alphazero_mcts_ctree=mcts_ctree, + save_replay_gif=False, + replay_path_gif='./replay_gif', + # ============================================================== + ), + policy=dict( + mcts_ctree=mcts_ctree, + # ============================================================== + # for the creation of simulation env + simulation_env_id='tictactoe', + simulation_env_config_type='play_with_bot', + # ============================================================== + model=dict( + observation_shape=(3, 3, 3), + action_space_size=int(1 * 3 * 3), + # We use the small size model for tictactoe. + num_res_blocks=1, + num_channels=16, + fc_value_layers=[8], + fc_policy_layers=[8], + ), + cuda=True, + board_size=3, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + value_weight=1.0, + entropy_weight=0.0, + n_episode=n_episode, + eval_freq=int(2e3), + mcts=dict(num_simulations=num_simulations), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='tictactoe', + import_names=['zoo.board_games.tictactoe.envs.tictactoe_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='alphazero', + import_names=['lzero.policy.alphazero'], + ), + collector=dict( + type='episode_alphazero', + import_names=['lzero.worker.alphazero_collector'], + ), + evaluator=dict( + type='alphazero', + import_names=['lzero.worker.alphazero_evaluator'], + ) + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/efficientzero/__init__.py b/lzero/agent/config/efficientzero/__init__.py new file mode 100644 index 000000000..bbb3176a4 --- /dev/null +++ b/lzero/agent/config/efficientzero/__init__.py @@ -0,0 +1,18 @@ +from easydict import EasyDict +from . import gym_breakoutnoframeskip_v4 +from . 
import gym_cartpole_v0 +from . import gym_lunarlander_v2 +from . import gym_mspacmannoframeskip_v4 +from . import gym_pendulum_v1 +from . import gym_pongnoframeskip_v4 + +supported_env_cfg = { + gym_breakoutnoframeskip_v4.cfg.main_config.env.env_id: gym_breakoutnoframeskip_v4.cfg, + gym_cartpole_v0.cfg.main_config.env.env_id: gym_cartpole_v0.cfg, + gym_lunarlander_v2.cfg.main_config.env.env_id: gym_lunarlander_v2.cfg, + gym_mspacmannoframeskip_v4.cfg.main_config.env.env_id: gym_mspacmannoframeskip_v4.cfg, + gym_pendulum_v1.cfg.main_config.env.env_id: gym_pendulum_v1.cfg, + gym_pongnoframeskip_v4.cfg.main_config.env.env_id: gym_pongnoframeskip_v4.cfg, +} + +supported_env_cfg = EasyDict(supported_env_cfg) diff --git a/lzero/agent/config/efficientzero/gym_breakoutnoframeskip_v4.py b/lzero/agent/config/efficientzero/gym_breakoutnoframeskip_v4.py new file mode 100644 index 000000000..b426c93ee --- /dev/null +++ b/lzero/agent/config/efficientzero/gym_breakoutnoframeskip_v4.py @@ -0,0 +1,85 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0. +eps_greedy_exploration_in_collect = True +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='BreakoutNoFrameskip-v4-EfficientZero', + seed=0, + env=dict( + env_id='BreakoutNoFrameskip-v4', + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(2e4), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=4, + downsample=True, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + # need to dynamically adjust the number of decay steps according to the characteristics of the environment and the algorithm + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
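+            # Note: collector_env_num / evaluator_env_num are repeated below even though they
+            # also appear in the env section above; the two copies are presumably read by
+            # different components, so keep them in sync when changing the parallelism.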
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/efficientzero/gym_cartpole_v0.py b/lzero/agent/config/efficientzero/gym_cartpole_v0.py new file mode 100644 index 000000000..663316603 --- /dev/null +++ b/lzero/agent/config/efficientzero/gym_cartpole_v0.py @@ -0,0 +1,73 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(1e5) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config = dict( + exp_name='CartPole-v0-EfficientZero', + env=dict( + env_id='CartPole-v0', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=4, + action_space_size=2, + model_type='mlp', + lstm_hidden_size=128, + latent_state_dim=128, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e2), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config=dict( + env=dict( + type='cartpole_lightzero', + import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), + ), +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/efficientzero/gym_lunarlander_v2.py b/lzero/agent/config/efficientzero/gym_lunarlander_v2.py new file mode 100644 index 000000000..f6fb30cc7 --- /dev/null +++ b/lzero/agent/config/efficientzero/gym_lunarlander_v2.py @@ -0,0 +1,76 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 200 +batch_size = 256 +max_env_step = int(5e6) +reanalyze_ratio = 0. 
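+# reanalyze_ratio is the fraction of sampled data whose value/policy targets are recomputed
+# with the latest model (MuZero Reanalyze); 0. keeps the targets computed at collection time.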
+# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='LunarLander-v2-EfficientZero', + seed=0, + env=dict( + env_id='LunarLander-v2', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=8, + action_space_size=4, + model_type='mlp', + lstm_hidden_size=256, + latent_state_dim=256, + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(1e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config=dict( + env=dict( + type='lunarlander', + import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), + ), +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/efficientzero/gym_mspacmannoframeskip_v4.py b/lzero/agent/config/efficientzero/gym_mspacmannoframeskip_v4.py new file mode 100644 index 000000000..2372ff362 --- /dev/null +++ b/lzero/agent/config/efficientzero/gym_mspacmannoframeskip_v4.py @@ -0,0 +1,84 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0. 
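+# eps_greedy_exploration_in_collect (defined below) is expected to mix epsilon-greedy random
+# actions into MCTS-based collection, following the linear `eps` schedule in the policy config;
+# it is disabled for this environment.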
+ +eps_greedy_exploration_in_collect = False +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='MsPacmanNoFrameskip-v4-EfficientZero', + seed=0, + env=dict( + env_id='MsPacmanNoFrameskip-v4', + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=9, + downsample=True, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + # need to dynamically adjust the number of decay steps according to the characteristics of the environment and the algorithm + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/efficientzero/gym_pendulum_v1.py b/lzero/agent/config/efficientzero/gym_pendulum_v1.py new file mode 100644 index 000000000..5bf4c3ddf --- /dev/null +++ b/lzero/agent/config/efficientzero/gym_pendulum_v1.py @@ -0,0 +1,73 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 200 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='CartPole-v0-EfficientZero', + seed=0, + env=dict( + env_id='Pendulum-v1', + continuous=False, + manually_discretization=True, + each_dim_disc_size=11, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=3, + action_space_size=11, + model_type='mlp', + lstm_hidden_size=128, + latent_state_dim=128, + ), + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + 
lr_piecewise_constant_decay=False, + learning_rate=0.003, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config=dict( + env=dict( + type='pendulum_lightzero', + import_names=['zoo.classic_control.pendulum.envs.pendulum_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), + ), +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/efficientzero/gym_pongnoframeskip_v4.py b/lzero/agent/config/efficientzero/gym_pongnoframeskip_v4.py new file mode 100644 index 000000000..de40742a7 --- /dev/null +++ b/lzero/agent/config/efficientzero/gym_pongnoframeskip_v4.py @@ -0,0 +1,84 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0. + +eps_greedy_exploration_in_collect = False +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='PongNoFrameskip-v4-EfficientZero', + seed=0, + env=dict( + env_id='PongNoFrameskip-v4', + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=6, + downsample=True, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + # need to dynamically adjust the number of decay steps according to the characteristics of the environment and the algorithm + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='efficientzero', + import_names=['lzero.policy.efficientzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/gumbel_muzero/__init__.py b/lzero/agent/config/gumbel_muzero/__init__.py new file mode 100644 index 000000000..0e91bb684 --- /dev/null +++ b/lzero/agent/config/gumbel_muzero/__init__.py @@ -0,0 +1,13 @@ +from easydict import EasyDict +from . import gomoku_play_with_bot +from . import gym_cartpole_v0 +from . import tictactoe_play_with_bot + + +supported_env_cfg = { + gomoku_play_with_bot.cfg.main_config.env.env_id: gomoku_play_with_bot.cfg, + gym_cartpole_v0.cfg.main_config.env.env_id: gym_cartpole_v0.cfg, + tictactoe_play_with_bot.cfg.main_config.env.env_id: tictactoe_play_with_bot.cfg, +} + +supported_env_cfg = EasyDict(supported_env_cfg) diff --git a/lzero/agent/config/gumbel_muzero/gomoku_play_with_bot.py b/lzero/agent/config/gumbel_muzero/gomoku_play_with_bot.py new file mode 100644 index 000000000..33239162b --- /dev/null +++ b/lzero/agent/config/gumbel_muzero/gomoku_play_with_bot.py @@ -0,0 +1,91 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 32 +n_episode = 32 +evaluator_env_num = 5 +num_simulations = 50 +update_per_collect = 50 +reanalyze_ratio = 0. +batch_size = 256 +max_env_step = int(1e6) + +board_size = 6 # default_size is 15 +bot_action_type = 'v0' # options={'v0', 'v1'} +prob_random_action_in_bot = 0.5 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='Gomoku-play-with-bot-GumbelMuZero', + seed=0, + env=dict( + env_id='Gomoku-play-with-bot', + battle_mode='play_with_bot_mode', + render_mode='image_savefile_mode', + replay_format='mp4', + board_size=board_size, + bot_action_type=bot_action_type, + prob_random_action_in_bot=prob_random_action_in_bot, + channel_last=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(3, board_size, board_size), + action_space_size=int(board_size * board_size), + image_channel=3, + num_res_blocks=1, + num_channels=32, + support_scale=10, + reward_support_size=21, + value_support_size=21, + ), + cuda=True, + env_type='board_games', + action_type='varied_action_space', + game_segment_length=int(board_size * board_size / 2), # for battle_mode='play_with_bot_mode' + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + max_num_considered_actions=6, + # NOTE:In board_games, we set large td_steps to make sure the value target is the final outcome. 
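+            # With board_size=6, td_steps below equals game_segment_length (6*6/2 = 18), so the
+            # n-step bootstrap is intended to span the remainder of the game; together with
+            # discount_factor=1 this makes the value target the final game outcome.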
+ td_steps=int(board_size * board_size / 2), # for battle_mode='play_with_bot_mode' + # NOTE:In board_games, we set discount_factor=1. + discount_factor=1, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e5), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='gomoku', + import_names=['zoo.board_games.gomoku.envs.gomoku_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='gumbel_muzero', + import_names=['lzero.policy.gumbel_muzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/gumbel_muzero/gym_cartpole_v0.py b/lzero/agent/config/gumbel_muzero/gym_cartpole_v0.py new file mode 100644 index 000000000..a13303525 --- /dev/null +++ b/lzero/agent/config/gumbel_muzero/gym_cartpole_v0.py @@ -0,0 +1,77 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(1e5) +reanalyze_ratio = 0 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='CartPole-v0-GumbelMuZero', + seed=0, + env=dict( + env_id='CartPole-v0', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=4, + action_space_size=2, + model_type='mlp', + lstm_hidden_size=128, + latent_state_dim=128, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + max_num_considered_actions=2, + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e2), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config=dict( + env=dict( + type='cartpole_lightzero', + import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='gumbel_muzero', + import_names=['lzero.policy.gumbel_muzero'], + ), + ), +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/gumbel_muzero/tictactoe_play_with_bot.py b/lzero/agent/config/gumbel_muzero/tictactoe_play_with_bot.py new file mode 100644 index 000000000..48c7cc528 --- /dev/null +++ b/lzero/agent/config/gumbel_muzero/tictactoe_play_with_bot.py @@ -0,0 +1,86 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 5 +num_simulations = 30 +update_per_collect = 50 +batch_size = 256 +max_env_step = int(2e5) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='TicTacToe-play-with-bot-GumbelMuZero', + seed=0, + env=dict( + env_id='TicTacToe-play-with-bot', + battle_mode='play_with_bot_mode', + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(3, 3, 3), + action_space_size=9, + image_channel=3, + # We use the small size model for tictactoe. + num_res_blocks=1, + num_channels=16, + fc_reward_layers=[8], + fc_value_layers=[8], + fc_policy_layers=[8], + support_scale=10, + reward_support_size=21, + value_support_size=21, + ), + cuda=True, + env_type='board_games', + action_type='varied_action_space', + game_segment_length=5, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + max_num_considered_actions=3, + # NOTE:In board_games, we set large td_steps to make sure the value target is the final outcome. + td_steps=9, + num_unroll_steps=3, + # NOTE:In board_games, we set discount_factor=1. + discount_factor=1, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='tictactoe', + import_names=['zoo.board_games.tictactoe.envs.tictactoe_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='gumbel_muzero', + import_names=['lzero.policy.gumbel_muzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/muzero/__init__.py b/lzero/agent/config/muzero/__init__.py index e8b937645..17770763d 100644 --- a/lzero/agent/config/muzero/__init__.py +++ b/lzero/agent/config/muzero/__init__.py @@ -1,8 +1,22 @@ from easydict import EasyDict +from . import gomoku_play_with_bot +from . 
import gym_breakoutnoframeskip_v4 from . import gym_cartpole_v0 +from . import gym_lunarlander_v2 +from . import gym_mspacmannoframeskip_v4 +from . import gym_pendulum_v1 +from . import gym_pongnoframeskip_v4 +from . import tictactoe_play_with_bot supported_env_cfg = { + gomoku_play_with_bot.cfg.main_config.env.env_id: gomoku_play_with_bot.cfg, + gym_breakoutnoframeskip_v4.cfg.main_config.env.env_id: gym_breakoutnoframeskip_v4.cfg, gym_cartpole_v0.cfg.main_config.env.env_id: gym_cartpole_v0.cfg, + gym_lunarlander_v2.cfg.main_config.env.env_id: gym_lunarlander_v2.cfg, + gym_mspacmannoframeskip_v4.cfg.main_config.env.env_id: gym_mspacmannoframeskip_v4.cfg, + gym_pendulum_v1.cfg.main_config.env.env_id: gym_pendulum_v1.cfg, + gym_pongnoframeskip_v4.cfg.main_config.env.env_id: gym_pongnoframeskip_v4.cfg, + tictactoe_play_with_bot.cfg.main_config.env.env_id: tictactoe_play_with_bot.cfg, } supported_env_cfg = EasyDict(supported_env_cfg) diff --git a/lzero/agent/config/muzero/gomoku_play_with_bot.py b/lzero/agent/config/muzero/gomoku_play_with_bot.py new file mode 100644 index 000000000..541424c26 --- /dev/null +++ b/lzero/agent/config/muzero/gomoku_play_with_bot.py @@ -0,0 +1,90 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 32 +n_episode = 32 +evaluator_env_num = 5 +num_simulations = 50 +update_per_collect = 50 +reanalyze_ratio = 0. +batch_size = 256 +max_env_step = int(1e6) + +board_size = 6 # default_size is 15 +bot_action_type = 'v0' # options={'v0', 'v1'} +prob_random_action_in_bot = 0.5 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='Gomoku-play-with-bot-MuZero', + seed=0, + env=dict( + env_id='Gomoku-play-with-bot', + battle_mode='play_with_bot_mode', + render_mode='image_savefile_mode', + replay_format='mp4', + board_size=board_size, + bot_action_type=bot_action_type, + prob_random_action_in_bot=prob_random_action_in_bot, + channel_last=True, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(3, board_size, board_size), + action_space_size=int(board_size * board_size), + image_channel=3, + num_res_blocks=1, + num_channels=32, + support_scale=10, + reward_support_size=21, + value_support_size=21, + ), + cuda=True, + env_type='board_games', + action_type='varied_action_space', + game_segment_length=int(board_size * board_size / 2), # for battle_mode='play_with_bot_mode' + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + # NOTE:In board_games, we set large td_steps to make sure the value target is the final outcome. + td_steps=int(board_size * board_size / 2), # for battle_mode='play_with_bot_mode' + # NOTE:In board_games, we set discount_factor=1. 
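+            # Board games only yield a terminal win/draw/loss reward, so an undiscounted return
+            # keeps the value target equal to the eventual game outcome.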
+ discount_factor=1, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e5), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='gomoku', + import_names=['zoo.board_games.gomoku.envs.gomoku_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/muzero/gym_breakoutnoframeskip_v4.py b/lzero/agent/config/muzero/gym_breakoutnoframeskip_v4.py new file mode 100644 index 000000000..e02dba881 --- /dev/null +++ b/lzero/agent/config/muzero/gym_breakoutnoframeskip_v4.py @@ -0,0 +1,89 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0. +eps_greedy_exploration_in_collect = True +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='BreakoutNoFrameskip-v4-MuZero', + seed=0, + env=dict( + stop_value=int(1e6), + env_id='BreakoutNoFrameskip-v4', + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + collect_max_episode_steps=int(5e3), + eval_max_episode_steps=int(2e4), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=4, + downsample=True, + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + # need to dynamically adjust the number of decay steps + # according to the characteristics of the environment and the algorithm + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=2, # default is 0 + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/muzero/gym_lunarlander_v2.py b/lzero/agent/config/muzero/gym_lunarlander_v2.py new file mode 100644 index 000000000..6817b9cb7 --- /dev/null +++ b/lzero/agent/config/muzero/gym_lunarlander_v2.py @@ -0,0 +1,83 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 200 +batch_size = 256 +max_env_step = int(5e6) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='LunarLander-v2-MuZero', + seed=0, + env=dict( + env_id='LunarLander-v2', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=8, + action_space_size=4, + model_type='mlp', + lstm_hidden_size=256, + latent_state_dim=256, + self_supervised_learning_loss=True, # NOTE: default is False. + discrete_action_encoding_type='one_hot', + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(1e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
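+            # ssl_loss_weight above only takes effect because the model enables
+            # self_supervised_learning_loss=True; both settings default to off and are
+            # meant to be toggled together.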
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config=dict( + env=dict( + type='lunarlander', + import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + collector=dict( + type='episode_muzero', + get_train_sample=True, + import_names=['lzero.worker.muzero_collector'], + ) + ), +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/muzero/gym_mspacmannoframeskip_v4.py b/lzero/agent/config/muzero/gym_mspacmannoframeskip_v4.py new file mode 100644 index 000000000..ade8c5ca8 --- /dev/null +++ b/lzero/agent/config/muzero/gym_mspacmannoframeskip_v4.py @@ -0,0 +1,87 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0. +eps_greedy_exploration_in_collect = False +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='MsPacmanNoFrameskip-v4-MuZero', + seed=0, + env=dict( + stop_value=int(1e6), + env_id='MsPacmanNoFrameskip-v4', + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=9, + downsample=True, + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + # need to dynamically adjust the number of decay steps + # according to the characteristics of the environment and the algorithm + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=2, # default is 0 + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
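`observation_shape=(4, 96, 96)` together with `frame_stack_num=4` means the networks see the last four downsampled 96x96 frames stacked along the channel axis; the Atari wrapper in `zoo.atari.envs` handles this internally. Purely as an illustration of the tensor layout:

```python
from collections import deque

import numpy as np

frame_stack_num = 4
stack = deque(maxlen=frame_stack_num)

# Start from four copies of the first frame, then slide the window as new frames arrive.
first_frame = np.zeros((96, 96), dtype=np.float32)
for _ in range(frame_stack_num):
    stack.append(first_frame)

new_frame = np.random.rand(96, 96).astype(np.float32)
stack.append(new_frame)  # the oldest frame is dropped automatically by the deque

obs = np.stack(list(stack), axis=0)
assert obs.shape == (4, 96, 96)  # matches observation_shape in the config above
```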
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/muzero/gym_pendulum_v1.py b/lzero/agent/config/muzero/gym_pendulum_v1.py new file mode 100644 index 000000000..3d5b6bf99 --- /dev/null +++ b/lzero/agent/config/muzero/gym_pendulum_v1.py @@ -0,0 +1,75 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 200 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0 +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='Pendulum-v1-MuZero', + seed=0, + env=dict( + env_id='Pendulum-v1', + continuous=False, + manually_discretization=True, + each_dim_disc_size=11, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=3, + action_space_size=11, + model_type='mlp', + lstm_hidden_size=128, + latent_state_dim=128, + self_supervised_learning_loss=True, # NOTE: default is False. + ), + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + ssl_loss_weight=2, # NOTE: default is 0. + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config=dict( + env=dict( + type='pendulum_lightzero', + import_names=['zoo.classic_control.pendulum.envs.pendulum_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + ), +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/muzero/gym_pongnoframeskip_v4.py b/lzero/agent/config/muzero/gym_pongnoframeskip_v4.py new file mode 100644 index 000000000..26e631f45 --- /dev/null +++ b/lzero/agent/config/muzero/gym_pongnoframeskip_v4.py @@ -0,0 +1,87 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0.
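The Pendulum-v1 MuZero config above relies on `manually_discretization=True` with `each_dim_disc_size=11`, i.e. the single continuous torque dimension is binned into 11 discrete actions (hence `action_space_size=11`). A hypothetical helper showing the mapping back from a discrete index to a torque, assuming Pendulum's usual [-2, 2] action range; the LightZero pendulum env performs this conversion internally:

```python
import numpy as np

each_dim_disc_size = 11
# Assumption: Pendulum-v1's single action dimension spans roughly [-2, 2]; check env.action_space.
torques = np.linspace(-2.0, 2.0, each_dim_disc_size)

def discrete_to_continuous(action_index: int) -> np.ndarray:
    """Map a discrete action index in {0, ..., 10} back to a 1-D torque vector."""
    return np.array([torques[action_index]], dtype=np.float32)

assert abs(discrete_to_continuous(5)[0] - 0.0) < 1e-6    # middle bin is zero torque
assert abs(discrete_to_continuous(0)[0] - (-2.0)) < 1e-6  # extreme bins hit the action bounds
```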
+eps_greedy_exploration_in_collect = False +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='PongNoFrameskip-v4-MuZero', + seed=0, + env=dict( + stop_value=int(1e6), + env_id='PongNoFrameskip-v4', + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=6, + downsample=True, + self_supervised_learning_loss=True, # default is False + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + random_collect_episode_num=0, + eps=dict( + eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect, + # need to dynamically adjust the number of decay steps + # according to the characteristics of the environment and the algorithm + type='linear', + start=1., + end=0.05, + decay=int(1e5), + ), + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + ssl_loss_weight=2, # default is 0 + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/muzero/tictactoe_play_with_bot.py b/lzero/agent/config/muzero/tictactoe_play_with_bot.py new file mode 100644 index 000000000..74a0ff0b5 --- /dev/null +++ b/lzero/agent/config/muzero/tictactoe_play_with_bot.py @@ -0,0 +1,86 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 5 +num_simulations = 25 +update_per_collect = 50 +batch_size = 256 +max_env_step = int(2e5) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='TicTacToe-play-with-bot-MuZero', + seed=0, + env=dict( + env_id='TicTacToe-play-with-bot', + battle_mode='play_with_bot_mode', + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(3, 3, 3), + action_space_size=9, + image_channel=3, + # We use the small size model for tictactoe. 
+ num_res_blocks=1, + num_channels=16, + fc_reward_layers=[8], + fc_value_layers=[8], + fc_policy_layers=[8], + support_scale=10, + reward_support_size=21, + value_support_size=21, + norm_type='BN', + ), + cuda=True, + env_type='board_games', + action_type='varied_action_space', + game_segment_length=5, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + # NOTE:In board_games, we set large td_steps to make sure the value target is the final outcome. + td_steps=9, + num_unroll_steps=3, + # NOTE:In board_games, we set discount_factor=1. + discount_factor=1, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e4), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='tictactoe', + import_names=['zoo.board_games.tictactoe.envs.tictactoe_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='muzero', + import_names=['lzero.policy.muzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/sampled_alphazero/__init__.py b/lzero/agent/config/sampled_alphazero/__init__.py new file mode 100644 index 000000000..f18651a6f --- /dev/null +++ b/lzero/agent/config/sampled_alphazero/__init__.py @@ -0,0 +1,10 @@ +from easydict import EasyDict +from . import gomoku_play_with_bot +from . import tictactoe_play_with_bot + +supported_env_cfg = { + gomoku_play_with_bot.cfg.main_config.env.env_id: gomoku_play_with_bot.cfg, + tictactoe_play_with_bot.cfg.main_config.env.env_id: tictactoe_play_with_bot.cfg, +} + +supported_env_cfg = EasyDict(supported_env_cfg) diff --git a/lzero/agent/config/sampled_alphazero/gomoku_play_with_bot.py b/lzero/agent/config/sampled_alphazero/gomoku_play_with_bot.py new file mode 100644 index 000000000..a209c954d --- /dev/null +++ b/lzero/agent/config/sampled_alphazero/gomoku_play_with_bot.py @@ -0,0 +1,114 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +board_size = 6 +num_simulations = 100 +update_per_collect = 50 +# board_size = 9 +# num_simulations = 200 +# update_per_collect = 100 +num_of_sampled_actions = 20 +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +batch_size = 256 +max_env_step = int(10e6) +prob_random_action_in_bot = 0.5 +mcts_ctree = False +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='Gomoku-play-with-bot-SampledAlphaZero', + seed=0, + env=dict( + env_id='Gomoku-play-with-bot', + battle_mode='play_with_bot_mode', + replay_format='mp4', + stop_value=2, + board_size=board_size, + bot_action_type='v0', + prob_random_action_in_bot=prob_random_action_in_bot, + channel_last=False, + use_katago_bot=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # ============================================================== + # for the 
creation of simulation env + agent_vs_human=False, + prob_random_agent=0, + prob_expert_agent=0, + scale=True, + check_action_to_connect4_in_bot_v0=False, + simulation_env_id="gomoku", + screen_scaling=9, + render_mode='image_savefile_mode', + replay_path=None, + alphazero_mcts_ctree=mcts_ctree, + # ============================================================== + ), + policy=dict( + # ============================================================== + # for the creation of simulation env + simulation_env_id='gomoku', + simulation_env_config_type='sampled_play_with_bot', + # ============================================================== + torch_compile=False, + tensor_float_32=False, + model=dict( + observation_shape=(3, board_size, board_size), + action_space_size=int(1 * board_size * board_size), + num_res_blocks=1, + num_channels=64, + ), + sampled_algo=True, + mcts_ctree=mcts_ctree, + policy_loss_type='KL', + cuda=True, + board_size=board_size, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + value_weight=1.0, + entropy_weight=0.0, + n_episode=n_episode, + eval_freq=int(2e3), + mcts=dict(num_simulations=num_simulations, num_of_sampled_actions=num_of_sampled_actions), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='gomoku', + import_names=['zoo.board_games.gomoku.envs.gomoku_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_alphazero', + import_names=['lzero.policy.sampled_alphazero'], + ), + collector=dict( + type='episode_alphazero', + get_train_sample=False, + import_names=['lzero.worker.alphazero_collector'], + ), + evaluator=dict( + type='alphazero', + import_names=['lzero.worker.alphazero_evaluator'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/sampled_alphazero/tictactoe_play_with_bot.py b/lzero/agent/config/sampled_alphazero/tictactoe_play_with_bot.py new file mode 100644 index 000000000..aeba8dab8 --- /dev/null +++ b/lzero/agent/config/sampled_alphazero/tictactoe_play_with_bot.py @@ -0,0 +1,102 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 5 +num_simulations = 25 +update_per_collect = 50 +batch_size = 256 +max_env_step = int(2e5) +num_of_sampled_actions = 5 +mcts_ctree = False +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='TicTacToe-play-with-bot-SampledAlphaZero', + seed=0, + env=dict( + env_id='TicTacToe-play-with-bot', + battle_mode='play_with_bot_mode', + bot_action_type='v0', # {'v0', 'alpha_beta_pruning'} + channel_last=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + # ============================================================== + # for the creation of simulation env + agent_vs_human=False, + prob_random_agent=0, + prob_expert_agent=0, + scale=True, + 
alphazero_mcts_ctree=mcts_ctree, + save_replay_gif=False, + replay_path_gif='./replay_gif', + # ============================================================== + ), + policy=dict( + # ============================================================== + # for the creation of simulation env + simulation_env_id='tictactoe', + simulation_env_config_type='play_with_bot', + # ============================================================== + model=dict( + observation_shape=(3, 3, 3), + action_space_size=int(1 * 3 * 3), + # We use the small size model for tictactoe. + num_res_blocks=1, + num_channels=16, + fc_value_layers=[8], + fc_policy_layers=[8], + ), + sampled_algo=True, + mcts_ctree=mcts_ctree, + policy_loss_type='KL', + cuda=True, + board_size=3, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + value_weight=1.0, + entropy_weight=0.0, + n_episode=n_episode, + eval_freq=int(2e3), + mcts=dict(num_simulations=num_simulations, num_of_sampled_actions=num_of_sampled_actions), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='tictactoe', + import_names=['zoo.board_games.tictactoe.envs.tictactoe_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_alphazero', + import_names=['lzero.policy.sampled_alphazero'], + ), + collector=dict( + type='episode_alphazero', + import_names=['lzero.worker.alphazero_collector'], + ), + evaluator=dict( + type='alphazero', + import_names=['lzero.worker.alphazero_evaluator'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/sampled_efficientzero/__init__.py b/lzero/agent/config/sampled_efficientzero/__init__.py new file mode 100644 index 000000000..38267b4a3 --- /dev/null +++ b/lzero/agent/config/sampled_efficientzero/__init__.py @@ -0,0 +1,18 @@ +from easydict import EasyDict +from . import gym_breakoutnoframeskip_v4 +from . import gym_cartpole_v0 +from . import gym_lunarlandercontinuous_v2 +from . import gym_mspacmannoframeskip_v4 +from . import gym_pendulum_v1 +from . 
import gym_pongnoframeskip_v4 + +supported_env_cfg = { + gym_breakoutnoframeskip_v4.cfg.main_config.env.env_id: gym_breakoutnoframeskip_v4.cfg, + gym_cartpole_v0.cfg.main_config.env.env_id: gym_cartpole_v0.cfg, + gym_lunarlandercontinuous_v2.cfg.main_config.env.env_id: gym_lunarlandercontinuous_v2.cfg, + gym_mspacmannoframeskip_v4.cfg.main_config.env.env_id: gym_mspacmannoframeskip_v4.cfg, + gym_pendulum_v1.cfg.main_config.env.env_id: gym_pendulum_v1.cfg, + gym_pongnoframeskip_v4.cfg.main_config.env.env_id: gym_pongnoframeskip_v4.cfg, +} + +supported_env_cfg = EasyDict(supported_env_cfg) diff --git a/lzero/agent/config/sampled_efficientzero/gym_breakoutnoframeskip_v4.py b/lzero/agent/config/sampled_efficientzero/gym_breakoutnoframeskip_v4.py new file mode 100644 index 000000000..d1fbeb89e --- /dev/null +++ b/lzero/agent/config/sampled_efficientzero/gym_breakoutnoframeskip_v4.py @@ -0,0 +1,78 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +continuous_action_space = False +K = 5 # num_of_sampled_actions +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='BreakoutNoFrameskip-v4-SampledEfficientZero', + seed=0, + env=dict( + env_id='BreakoutNoFrameskip-v4', + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=4, + downsample=True, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + policy_loss_type='cross_entropy', + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
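Compared with the MuZero configs, the Sampled EfficientZero ones add `continuous_action_space` and `num_of_sampled_actions` (`K`). For a discrete Atari action space the idea is to propose K candidate actions from the current policy prior and let MCTS expand only those candidates; a toy illustration of the sampling step (not LightZero's actual implementation):

```python
import numpy as np

rng = np.random.default_rng(0)

K = 5                     # num_of_sampled_actions in the config above
action_space_size = 4     # Breakout's minimal action set
policy_prior = np.array([0.55, 0.25, 0.15, 0.05])  # made-up policy output, sums to 1

# Draw K candidate actions from the prior (duplicates are possible); the search tree
# then only branches over these candidates instead of the full action set.
candidates = rng.choice(action_space_size, size=K, p=policy_prior)
print(candidates)  # five action indices drawn from the prior
```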
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_efficientzero', + import_names=['lzero.policy.sampled_efficientzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/sampled_efficientzero/gym_cartpole_v0.py b/lzero/agent/config/sampled_efficientzero/gym_cartpole_v0.py new file mode 100644 index 000000000..6496729ab --- /dev/null +++ b/lzero/agent/config/sampled_efficientzero/gym_cartpole_v0.py @@ -0,0 +1,77 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +continuous_action_space = False +K = 2 # num_of_sampled_actions +num_simulations = 25 +update_per_collect = 100 +batch_size = 256 +max_env_step = int(1e5) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config = dict( + exp_name='CartPole-v0-SampledEfficientZero', + env=dict( + env_id='CartPole-v0', + continuous=False, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=4, + action_space_size=2, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + model_type='mlp', + lstm_hidden_size=128, + latent_state_dim=128, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e2), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
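`n_episode`, `update_per_collect` and `batch_size` jointly set the data throughput: each collection phase gathers `n_episode` episodes, after which the learner runs `update_per_collect` gradient steps on mini-batches of `batch_size` transitions drawn from the game buffer. A back-of-the-envelope sketch for the CartPole config above (the ~100-step average episode length is an assumption, not part of the config):

```python
# Rough bookkeeping for the CartPole-v0 Sampled EfficientZero config above.
n_episode = 8
update_per_collect = 100
batch_size = 256

# If an average CartPole-v0 episode lasts ~100 steps (assumption), one collection phase
# adds roughly this many transitions to the buffer ...
collected_transitions = n_episode * 100                  # ~800

# ... and is followed by this many sampled transitions' worth of training
# (sampled with replacement from the buffer, so reuse is expected).
trained_transitions = update_per_collect * batch_size    # 25600
print(collected_transitions, trained_transitions)
```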
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config=dict( + env=dict( + type='cartpole_lightzero', + import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_efficientzero', + import_names=['lzero.policy.sampled_efficientzero'], + ), + ), +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/sampled_efficientzero/gym_lunarlandercontinuous_v2.py b/lzero/agent/config/sampled_efficientzero/gym_lunarlandercontinuous_v2.py new file mode 100644 index 000000000..2060c76e0 --- /dev/null +++ b/lzero/agent/config/sampled_efficientzero/gym_lunarlandercontinuous_v2.py @@ -0,0 +1,84 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +continuous_action_space = True +K = 20 # num_of_sampled_actions +num_simulations = 50 +update_per_collect = 200 +batch_size = 256 +max_env_step = int(5e6) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='LunarLanderContinuous-v2-SampledEfficientZero', + seed=0, + env=dict( + env_id='LunarLanderContinuous-v2', + continuous=True, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + mcts_ctree=True, + model=dict( + observation_shape=8, + action_space_size=2, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + sigma_type='conditioned', + model_type='mlp', + lstm_hidden_size=256, + latent_state_dim=256, + res_connection_in_dynamics=True, + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=200, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + grad_clip_value=0.5, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + random_collect_episode_num=0, + # NOTE: for continuous gaussian policy, we use the policy_entropy_loss as in the original Sampled MuZero paper. + policy_entropy_loss_weight=5e-3, + n_episode=n_episode, + eval_freq=int(1e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
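In the continuous LunarLanderContinuous-v2 config above (`continuous_action_space=True`, `K=20`, `sigma_type='conditioned'`), the policy head predicts a mean and a state-conditioned standard deviation per action dimension, and K candidate actions are drawn from that Gaussian for the search. A heavily simplified sketch of the sampling step, with made-up head outputs (not the library's actual code):

```python
import torch

K = 20            # num_of_sampled_actions
action_dim = 2    # LunarLanderContinuous-v2 has a 2-dimensional action

# Stand-in policy-head outputs for a single state; in practice these come from the model.
mu = torch.zeros(action_dim)
log_sigma = torch.full((action_dim,), -0.5)   # 'conditioned' sigma: predicted per state
sigma = log_sigma.exp()

dist = torch.distributions.Normal(mu, sigma)
sampled_actions = dist.sample((K,))           # K candidate actions, shape (K, action_dim)
assert sampled_actions.shape == (20, 2)
```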
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config=dict( + env=dict( + type='lunarlander', + import_names=['zoo.box2d.lunarlander.envs.lunarlander_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_efficientzero', + import_names=['lzero.policy.sampled_efficientzero'], + ), + ), +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/sampled_efficientzero/gym_mspacmannoframeskip_v4.py b/lzero/agent/config/sampled_efficientzero/gym_mspacmannoframeskip_v4.py new file mode 100644 index 000000000..38860a76e --- /dev/null +++ b/lzero/agent/config/sampled_efficientzero/gym_mspacmannoframeskip_v4.py @@ -0,0 +1,78 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +continuous_action_space = False +K = 5 # num_of_sampled_actions +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='MsPacmanNoFrameskip-v4-SampledEfficientZero', + seed=0, + env=dict( + env_id='MsPacmanNoFrameskip-v4', + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=9, + downsample=True, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + policy_loss_type='cross_entropy', + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
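Two sizing knobs recur in all of these configs: `game_segment_length`, because collected episodes are stored as fixed-length game segments, and `replay_buffer_size`, which is counted in transitions rather than episodes (the buffer drops its oldest data once that budget is exceeded, as the agent code later in this diff notes). A toy sketch of the segmentation, assuming segments are plain contiguous chunks of a trajectory:

```python
game_segment_length = 400  # value used by the Atari configs above

def to_segments(trajectory, segment_length=game_segment_length):
    """Split one episode's list of transitions into contiguous segments (toy version)."""
    return [trajectory[i:i + segment_length] for i in range(0, len(trajectory), segment_length)]

episode = list(range(1000))            # stand-in for 1000 collected transitions
segments = to_segments(episode)
print([len(s) for s in segments])      # [400, 400, 200]
# replay_buffer_size=int(1e6) then bounds the total number of such transitions kept in the buffer.
```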
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_efficientzero', + import_names=['lzero.policy.sampled_efficientzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/sampled_efficientzero/gym_pendulum_v1.py b/lzero/agent/config/sampled_efficientzero/gym_pendulum_v1.py new file mode 100644 index 000000000..122796abe --- /dev/null +++ b/lzero/agent/config/sampled_efficientzero/gym_pendulum_v1.py @@ -0,0 +1,79 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +continuous_action_space = True +K = 20 # num_of_sampled_actions +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 200 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='Pendulum-v1-SampledEfficientZero', + seed=0, + env=dict( + env_id='Pendulum-v1', + continuous=True, + manually_discretization=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=3, + action_space_size=11, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + sigma_type='conditioned', + model_type='mlp', + lstm_hidden_size=128, + latent_state_dim=128, + ), + cuda=True, + env_type='not_board_games', + game_segment_length=50, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='Adam', + lr_piecewise_constant_decay=False, + learning_rate=0.003, + # NOTE: for continuous gaussian policy, we use the policy_entropy_loss as in the original Sampled MuZero paper. + policy_entropy_loss_weight=5e-3, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions.
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config=dict( + env=dict( + type='pendulum_lightzero', + import_names=['zoo.classic_control.pendulum.envs.pendulum_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_efficientzero', + import_names=['lzero.policy.sampled_efficientzero'], + ), + ), +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/config/sampled_efficientzero/gym_pongnoframeskip_v4.py b/lzero/agent/config/sampled_efficientzero/gym_pongnoframeskip_v4.py new file mode 100644 index 000000000..a6bad1eb1 --- /dev/null +++ b/lzero/agent/config/sampled_efficientzero/gym_pongnoframeskip_v4.py @@ -0,0 +1,78 @@ +from easydict import EasyDict + +# ============================================================== +# begin of the most frequently changed config specified by the user +# ============================================================== +continuous_action_space = False +K = 5 # num_of_sampled_actions +collector_env_num = 8 +n_episode = 8 +evaluator_env_num = 3 +num_simulations = 50 +update_per_collect = 1000 +batch_size = 256 +max_env_step = int(1e6) +reanalyze_ratio = 0. +# ============================================================== +# end of the most frequently changed config specified by the user +# ============================================================== + +cfg = dict( + main_config=dict( + exp_name='PongNoFrameskip-v4-SampledEfficientZero', + seed=0, + env=dict( + env_id='PongNoFrameskip-v4', + obs_shape=(4, 96, 96), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + ), + policy=dict( + model=dict( + observation_shape=(4, 96, 96), + frame_stack_num=4, + action_space_size=6, + downsample=True, + continuous_action_space=continuous_action_space, + num_of_sampled_actions=K, + discrete_action_encoding_type='one_hot', + norm_type='BN', + ), + cuda=True, + env_type='not_board_games', + game_segment_length=400, + use_augmentation=True, + update_per_collect=update_per_collect, + batch_size=batch_size, + optim_type='SGD', + lr_piecewise_constant_decay=True, + learning_rate=0.2, + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + policy_loss_type='cross_entropy', + n_episode=n_episode, + eval_freq=int(2e3), + replay_buffer_size=int(1e6), # the size/capacity of replay_buffer, in the terms of transitions. 
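The Atari configs pair `optim_type='SGD'`, `learning_rate=0.2` and `lr_piecewise_constant_decay=True`, i.e. the learning rate is held constant and dropped by fixed factors at predefined training-step thresholds rather than annealed smoothly. A toy schedule with hypothetical thresholds (the real thresholds come from the policy defaults, not from this config):

```python
def piecewise_constant_lr(train_iter: int) -> float:
    """Toy piecewise-constant schedule: 0.2 -> 0.02 -> 0.002 at two hypothetical iteration thresholds."""
    if train_iter < int(5e4):
        return 0.2
    if train_iter < int(1e5):
        return 0.02
    return 0.002

assert piecewise_constant_lr(0) == 0.2
assert piecewise_constant_lr(int(7e4)) == 0.02
assert piecewise_constant_lr(int(2e5)) == 0.002
```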
+ collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + wandb_logger=dict( + gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False + ), + ), + create_config = dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='sampled_efficientzero', + import_names=['lzero.policy.sampled_efficientzero'], + ), + ) +) + +cfg = EasyDict(cfg) diff --git a/lzero/agent/efficientzero.py b/lzero/agent/efficientzero.py new file mode 100644 index 000000000..421cea881 --- /dev/null +++ b/lzero/agent/efficientzero.py @@ -0,0 +1,425 @@ +import os +from functools import partial +from typing import Optional, Union, List + +import numpy as np +import torch +from ding.bonus.common import TrainingReturn, EvalReturn +from ding.config import save_config_py, compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from ditk import logging +from easydict import EasyDict +from tensorboardX import SummaryWriter + +from lzero.agent.config.efficientzero import supported_env_cfg +from lzero.entry.utils import log_buffer_memory_usage, random_collect +from lzero.mcts import EfficientZeroGameBuffer +from lzero.policy import visit_count_temperature +from lzero.policy.efficientzero import EfficientZeroPolicy +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator + + +class EfficientZeroAgent: + """ + Overview: + Agent class for executing EfficientZero algorithms which include methods for training, deployment, and batch evaluation. + Interfaces: + ``__init__``, ``train``, ``deploy``, ``batch_evaluate`` + Properties: + ``best`` + + .. note:: + This agent class is tailored for use with the HuggingFace Model Zoo for LightZero + (e.g. https://huggingface.co/OpenDILabCommunity/CartPole-v0-EfficientZero), + and provides methods such as "train" and "deploy". + """ + + supported_env_list = list(supported_env_cfg.keys()) + + def __init__( + self, + env_id: str = None, + seed: int = 0, + exp_name: str = None, + model: Optional[torch.nn.Module] = None, + cfg: Optional[Union[EasyDict, dict]] = None, + policy_state_dict: str = None, + ) -> None: + """ + Overview: + Initialize the EfficientZeroAgent instance with environment parameters, model, and configuration. + Arguments: + - env_id (:obj:`str`): Identifier for the environment to be used, registered in gym. + - seed (:obj:`int`): Random seed for reproducibility. Defaults to 0. + - exp_name (:obj:`Optional[str]`): Name for the experiment. Defaults to None. + - model (:obj:`Optional[torch.nn.Module]`): PyTorch module to be used as the model. If None, a default model is created. Defaults to None. + - cfg (:obj:`Optional[Union[EasyDict, dict]]`): Configuration for the agent. If None, default configuration will be used. Defaults to None. + - policy_state_dict (:obj:`Optional[str]`): Path to a pre-trained model state dictionary. If provided, state dict will be loaded. Defaults to None. + + .. note:: + - If `env_id` is not specified, it must be included in `cfg`. + - The `supported_env_list` contains all the environment IDs that are supported by this agent. 
+ """ + assert env_id is not None or cfg["main_config"]["env_id"] is not None, "Please specify env_id or cfg." + + if cfg is not None and not isinstance(cfg, EasyDict): + cfg = EasyDict(cfg) + + if env_id is not None: + assert env_id in EfficientZeroAgent.supported_env_list, "Please use supported envs: {}".format( + EfficientZeroAgent.supported_env_list + ) + if cfg is None: + cfg = supported_env_cfg[env_id] + else: + assert cfg.main_config.env.env_id == env_id, "env_id in cfg should be the same as env_id in args." + else: + assert hasattr(cfg.main_config.env, "env_id"), "Please specify env_id in cfg." + assert cfg.main_config.env.env_id in EfficientZeroAgent.supported_env_list, "Please use supported envs: {}".format( + EfficientZeroAgent.supported_env_list + ) + default_policy_config = EasyDict({"policy": EfficientZeroPolicy.default_config()}) + default_policy_config.policy.update(cfg.main_config.policy) + cfg.main_config.policy = default_policy_config.policy + + if exp_name is not None: + cfg.main_config.exp_name = exp_name + self.origin_cfg = cfg + self.cfg = compile_config( + cfg.main_config, seed=seed, env=None, auto=True, policy=EfficientZeroPolicy, create_cfg=cfg.create_config + ) + self.exp_name = self.cfg.exp_name + + logging.getLogger().setLevel(logging.INFO) + self.seed = seed + set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda) + if not os.path.exists(self.exp_name): + os.makedirs(self.exp_name) + save_config_py(cfg, os.path.join(self.exp_name, 'policy_config.py')) + if model is None: + if self.cfg.policy.model.model_type == 'mlp': + from lzero.model.efficientzero_model_mlp import EfficientZeroModelMLP + model = EfficientZeroModelMLP(**self.cfg.policy.model) + elif self.cfg.policy.model.model_type == 'conv': + from lzero.model.efficientzero_model import EfficientZeroModel + model = EfficientZeroModel(**self.cfg.policy.model) + else: + raise NotImplementedError + if self.cfg.policy.cuda and torch.cuda.is_available(): + self.cfg.policy.device = 'cuda' + else: + self.cfg.policy.device = 'cpu' + self.policy = create_policy(self.cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + if policy_state_dict is not None: + self.policy.learn_mode.load_state_dict(policy_state_dict) + self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt") + + self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env) + + def train( + self, + step: int = int(1e7), + ) -> TrainingReturn: + """ + Overview: + Train the agent through interactions with the environment. + Arguments: + - step (:obj:`int`): Total number of environment steps to train for. Defaults to 10 million (1e7). + Returns: + - A `TrainingReturn` object containing training information, such as logs and potentially a URL to a training dashboard. + .. note:: + The method involves interacting with the environment, collecting experience, and optimizing the model. + """ + + collector_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.collector_env_cfg] + ) + evaluator_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] + ) + + collector_env.seed(self.cfg.seed) + evaluator_env.seed(self.cfg.seed, dynamic_seed=False) + set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. 
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(self.cfg.exp_name), 'serial') + ) if get_rank() == 0 else None + learner = BaseLearner( + self.cfg.policy.learn.learner, self.policy.learn_mode, tb_logger, exp_name=self.cfg.exp_name + ) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = self.cfg.policy + batch_size = policy_config.batch_size + # specific game buffer for MCTS+RL algorithms + replay_buffer = EfficientZeroGameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=self.policy.collect_mode, + tb_logger=tb_logger, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=self.cfg.env.n_evaluator_episode, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + tb_logger=tb_logger, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. + learner.call_hook('before_run') + + if self.cfg.policy.update_per_collect is not None: + update_per_collect = self.cfg.policy.update_per_collect + + # The purpose of collecting random data before training: + # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely. + # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms. + if self.cfg.policy.random_collect_episode_num > 0: + random_collect(self.cfg.policy, self.policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + + while True: + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + collect_kwargs = {} + # set temperature for visit count distributions according to the train_iter, + # please refer to Appendix D in MuZero paper for details. + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 + + # Evaluate policy performance. + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + if self.cfg.policy.update_per_collect is None: + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. 
+ collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) + update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio) + # save returned new_data collected by the collector + replay_buffer.push_game_segments(new_data) + # remove the oldest data if the replay buffer is full. + replay_buffer.remove_oldest_data_to_fit() + + # Learn policy from collected data. + for i in range(update_per_collect): + # Learner will train ``update_per_collect`` times in one iteration. + if replay_buffer.get_num_of_transitions() > batch_size: + train_data = replay_buffer.sample(batch_size, self.policy) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, ' + f'{replay_buffer} ' + f'continue to collect now ....' + ) + break + + # The core train steps for MCTS+RL algorithms. + log_vars = learner.train(train_data, collector.envstep) + + if self.cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + if collector.envstep >= step: + break + + # Learner's after_run hook. + learner.call_hook('after_run') + + return TrainingReturn(wandb_url=None) + + def deploy( + self, + enable_save_replay: bool = False, + concatenate_all_replay: bool = False, + replay_save_path: str = None, + seed: Optional[Union[int, List]] = None, + debug: bool = False + ) -> EvalReturn: + """ + Overview: + Deploy the agent for evaluation in the environment, with optional replay saving. The performance of the + agent will be evaluated. Average return and standard deviation of the return will be returned. + If `enable_save_replay` is True, replay videos are saved in the specified `replay_save_path`. + Arguments: + - enable_save_replay (:obj:`bool`): Flag to enable saving of replay footage. Defaults to False. + - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one file. Defaults to False. + - replay_save_path (:obj:`Optional[str]`): Directory path to save replay videos. Defaults to None, which sets a default path. + - seed (:obj:`Optional[Union[int, List[int]]]`): Seed or list of seeds for environment reproducibility. Defaults to None. + - debug (:obj:`bool`): Whether to enable the debug mode. Default to False. + Returns: + - An `EvalReturn` object containing evaluation metrics such as mean and standard deviation of returns. 
+ """ + + deply_configs = [self.evaluator_env_cfg[0]] + + if type(seed) == int: + seed_list = [seed] + elif seed: + seed_list = seed + else: + seed_list = [0] + + reward_list = [] + + if enable_save_replay: + replay_save_path = replay_save_path if replay_save_path is not None else os.path.join( + self.exp_name, 'videos' + ) + deply_configs[0]['replay_path'] = replay_save_path + deply_configs[0]['save_replay'] = True + + for seed in seed_list: + + evaluator_env = create_env_manager(self.cfg.env.manager, [partial(self.env_fn, cfg=deply_configs[0])]) + + evaluator_env.seed(seed if seed is not None else self.cfg.seed, dynamic_seed=False) + set_pkg_seed(seed if seed is not None else self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = self.cfg.policy + + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=1, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + + stop, reward = evaluator.eval() + reward_list.extend(reward['eval_episode_return']) + + if enable_save_replay: + if not os.path.exists(replay_save_path): + os.makedirs(replay_save_path) + files = os.listdir(replay_save_path) + files = [file for file in files if file.endswith('.mp4')] + files.sort() + if concatenate_all_replay: + # create a file named 'files.txt' to store the names of all mp4 files + with open(os.path.join(replay_save_path, 'files.txt'), 'w') as f: + for file in files: + f.write("file '{}'\n".format(file)) + + # combine all the mp4 files into one mp4 file, rename it as 'deploy.mp4' + os.system( + 'ffmpeg -f concat -safe 0 -i {} -c copy {}/deploy.mp4'.format( + os.path.join(replay_save_path, 'files.txt'), replay_save_path + ) + ) + + return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list)) + + def batch_evaluate( + self, + n_evaluator_episode: int = None, + ) -> EvalReturn: + """ + Overview: + Perform a batch evaluation of the agent over a specified number of episodes: ``n_evaluator_episode``. + Arguments: + - n_evaluator_episode (:obj:`Optional[int]`): Number of episodes to run the evaluation. + If None, uses default value from configuration. Defaults to None. + Returns: + - An `EvalReturn` object with evaluation results such as mean and standard deviation of returns. + + .. note:: + This method evaluates the agent's performance across multiple episodes to gauge its effectiveness. 
+ """ + evaluator_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] + ) + + evaluator_env.seed(self.cfg.seed, dynamic_seed=False) + set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = self.cfg.policy + + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=self.cfg.env.n_evaluator_episode + if n_evaluator_episode is None else n_evaluator_episode, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + + stop, reward = evaluator.eval() + + return EvalReturn( + eval_value=np.mean(reward['eval_episode_return']), eval_value_std=np.std(reward['eval_episode_return']) + ) + + @property + def best(self): + """ + Overview: + Provides access to the best model according to evaluation metrics. + Returns: + - The agent with the best model loaded. + + .. note:: + The best model is saved in the path `./exp_name/ckpt/ckpt_best.pth.tar`. + When this property is accessed, the agent instance will load the best model state. + """ + + best_model_file_path = os.path.join(self.checkpoint_save_dir, "ckpt_best.pth.tar") + # Load best model if it exists + if os.path.exists(best_model_file_path): + policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu")) + self.policy.learn_mode.load_state_dict(policy_state_dict) + return self diff --git a/lzero/agent/gumbel_muzero.py b/lzero/agent/gumbel_muzero.py new file mode 100644 index 000000000..0df583ab1 --- /dev/null +++ b/lzero/agent/gumbel_muzero.py @@ -0,0 +1,425 @@ +import os +from functools import partial +from typing import Optional, Union, List + +import numpy as np +import torch +from ding.bonus.common import TrainingReturn, EvalReturn +from ding.config import save_config_py, compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from ditk import logging +from easydict import EasyDict +from tensorboardX import SummaryWriter + +from lzero.agent.config.gumbel_muzero import supported_env_cfg +from lzero.entry.utils import log_buffer_memory_usage, random_collect +from lzero.mcts import GumbelMuZeroGameBuffer +from lzero.policy import visit_count_temperature +from lzero.policy.gumbel_muzero import GumbelMuZeroPolicy +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator + + +class GumbelMuZeroAgent: + """ + Overview: + Agent class for executing Gumbel MuZero algorithms which include methods for training, deployment, and batch evaluation. + Interfaces: + ``__init__``, ``train``, ``deploy``, ``batch_evaluate`` + Properties: + ``best`` + + .. note:: + This agent class is tailored for use with the HuggingFace Model Zoo for LightZero + (e.g. https://huggingface.co/OpenDILabCommunity/CartPole-v0-GumbelMuZero), + and provides methods such as "train" and "deploy". 
+ """ + + supported_env_list = list(supported_env_cfg.keys()) + + def __init__( + self, + env_id: str = None, + seed: int = 0, + exp_name: str = None, + model: Optional[torch.nn.Module] = None, + cfg: Optional[Union[EasyDict, dict]] = None, + policy_state_dict: str = None, + ) -> None: + """ + Overview: + Initialize the GumbelMuZeroAgent instance with environment parameters, model, and configuration. + Arguments: + - env_id (:obj:`str`): Identifier for the environment to be used, registered in gym. + - seed (:obj:`int`): Random seed for reproducibility. Defaults to 0. + - exp_name (:obj:`Optional[str]`): Name for the experiment. Defaults to None. + - model (:obj:`Optional[torch.nn.Module]`): PyTorch module to be used as the model. If None, a default model is created. Defaults to None. + - cfg (:obj:`Optional[Union[EasyDict, dict]]`): Configuration for the agent. If None, default configuration will be used. Defaults to None. + - policy_state_dict (:obj:`Optional[str]`): Path to a pre-trained model state dictionary. If provided, state dict will be loaded. Defaults to None. + + .. note:: + - If `env_id` is not specified, it must be included in `cfg`. + - The `supported_env_list` contains all the environment IDs that are supported by this agent. + """ + assert env_id is not None or cfg["main_config"]["env_id"] is not None, "Please specify env_id or cfg." + + if cfg is not None and not isinstance(cfg, EasyDict): + cfg = EasyDict(cfg) + + if env_id is not None: + assert env_id in GumbelMuZeroAgent.supported_env_list, "Please use supported envs: {}".format( + GumbelMuZeroAgent.supported_env_list + ) + if cfg is None: + cfg = supported_env_cfg[env_id] + else: + assert cfg.main_config.env.env_id == env_id, "env_id in cfg should be the same as env_id in args." + else: + assert hasattr(cfg.main_config.env, "env_id"), "Please specify env_id in cfg." 
+ assert cfg.main_config.env.env_id in GumbelMuZeroAgent.supported_env_list, "Please use supported envs: {}".format( + GumbelMuZeroAgent.supported_env_list + ) + default_policy_config = EasyDict({"policy": GumbelMuZeroPolicy.default_config()}) + default_policy_config.policy.update(cfg.main_config.policy) + cfg.main_config.policy = default_policy_config.policy + + if exp_name is not None: + cfg.main_config.exp_name = exp_name + self.origin_cfg = cfg + self.cfg = compile_config( + cfg.main_config, seed=seed, env=None, auto=True, policy=GumbelMuZeroPolicy, create_cfg=cfg.create_config + ) + self.exp_name = self.cfg.exp_name + + logging.getLogger().setLevel(logging.INFO) + self.seed = seed + set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda) + if not os.path.exists(self.exp_name): + os.makedirs(self.exp_name) + save_config_py(cfg, os.path.join(self.exp_name, 'policy_config.py')) + if model is None: + if self.cfg.policy.model.model_type == 'mlp': + from lzero.model.muzero_model_mlp import MuZeroModelMLP + model = MuZeroModelMLP(**self.cfg.policy.model) + elif self.cfg.policy.model.model_type == 'conv': + from lzero.model.muzero_model import MuZeroModel + model = MuZeroModel(**self.cfg.policy.model) + else: + raise NotImplementedError + if self.cfg.policy.cuda and torch.cuda.is_available(): + self.cfg.policy.device = 'cuda' + else: + self.cfg.policy.device = 'cpu' + self.policy = create_policy(self.cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + if policy_state_dict is not None: + self.policy.learn_mode.load_state_dict(policy_state_dict) + self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt") + + self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env) + + def train( + self, + step: int = int(1e7), + ) -> TrainingReturn: + """ + Overview: + Train the agent through interactions with the environment. + Arguments: + - step (:obj:`int`): Total number of environment steps to train for. Defaults to 10 million (1e7). + Returns: + - A `TrainingReturn` object containing training information, such as logs and potentially a URL to a training dashboard. + .. note:: + The method involves interacting with the environment, collecting experience, and optimizing the model. + """ + + collector_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.collector_env_cfg] + ) + evaluator_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] + ) + + collector_env.seed(self.cfg.seed) + evaluator_env.seed(self.cfg.seed, dynamic_seed=False) + set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. 
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(self.cfg.exp_name), 'serial') + ) if get_rank() == 0 else None + learner = BaseLearner( + self.cfg.policy.learn.learner, self.policy.learn_mode, tb_logger, exp_name=self.cfg.exp_name + ) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = self.cfg.policy + batch_size = policy_config.batch_size + # specific game buffer for MCTS+RL algorithms + replay_buffer = GumbelMuZeroGameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=self.policy.collect_mode, + tb_logger=tb_logger, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=self.cfg.env.n_evaluator_episode, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + tb_logger=tb_logger, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. + learner.call_hook('before_run') + + if self.cfg.policy.update_per_collect is not None: + update_per_collect = self.cfg.policy.update_per_collect + + # The purpose of collecting random data before training: + # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely. + # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms. + if self.cfg.policy.random_collect_episode_num > 0: + random_collect(self.cfg.policy, self.policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + + while True: + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + collect_kwargs = {} + # set temperature for visit count distributions according to the train_iter, + # please refer to Appendix D in MuZero paper for details. + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 + + # Evaluate policy performance. + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + if self.cfg.policy.update_per_collect is None: + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. 
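                # Illustrative arithmetic (the numbers are made up, not defaults from
                # this config): if the collected game segments hold 400 transitions in
                # total and `model_update_ratio` is 0.25, the line below yields
                # update_per_collect = int(400 * 0.25) = 100 gradient steps for this
                # collect-train iteration.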
+ collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) + update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio) + # save returned new_data collected by the collector + replay_buffer.push_game_segments(new_data) + # remove the oldest data if the replay buffer is full. + replay_buffer.remove_oldest_data_to_fit() + + # Learn policy from collected data. + for i in range(update_per_collect): + # Learner will train ``update_per_collect`` times in one iteration. + if replay_buffer.get_num_of_transitions() > batch_size: + train_data = replay_buffer.sample(batch_size, self.policy) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, ' + f'{replay_buffer} ' + f'continue to collect now ....' + ) + break + + # The core train steps for MCTS+RL algorithms. + log_vars = learner.train(train_data, collector.envstep) + + if self.cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + if collector.envstep >= step: + break + + # Learner's after_run hook. + learner.call_hook('after_run') + + return TrainingReturn(wandb_url=None) + + def deploy( + self, + enable_save_replay: bool = False, + concatenate_all_replay: bool = False, + replay_save_path: str = None, + seed: Optional[Union[int, List]] = None, + debug: bool = False + ) -> EvalReturn: + """ + Overview: + Deploy the agent for evaluation in the environment, with optional replay saving. The performance of the + agent will be evaluated. Average return and standard deviation of the return will be returned. + If `enable_save_replay` is True, replay videos are saved in the specified `replay_save_path`. + Arguments: + - enable_save_replay (:obj:`bool`): Flag to enable saving of replay footage. Defaults to False. + - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one file. Defaults to False. + - replay_save_path (:obj:`Optional[str]`): Directory path to save replay videos. Defaults to None, which sets a default path. + - seed (:obj:`Optional[Union[int, List[int]]]`): Seed or list of seeds for environment reproducibility. Defaults to None. + - debug (:obj:`bool`): Whether to enable the debug mode. Default to False. + Returns: + - An `EvalReturn` object containing evaluation metrics such as mean and standard deviation of returns. 
+ """ + + deply_configs = [self.evaluator_env_cfg[0]] + + if type(seed) == int: + seed_list = [seed] + elif seed: + seed_list = seed + else: + seed_list = [0] + + reward_list = [] + + if enable_save_replay: + replay_save_path = replay_save_path if replay_save_path is not None else os.path.join( + self.exp_name, 'videos' + ) + deply_configs[0]['replay_path'] = replay_save_path + deply_configs[0]['save_replay'] = True + + for seed in seed_list: + + evaluator_env = create_env_manager(self.cfg.env.manager, [partial(self.env_fn, cfg=deply_configs[0])]) + + evaluator_env.seed(seed if seed is not None else self.cfg.seed, dynamic_seed=False) + set_pkg_seed(seed if seed is not None else self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = self.cfg.policy + + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=1, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + + stop, reward = evaluator.eval() + reward_list.extend(reward['eval_episode_return']) + + if enable_save_replay: + if not os.path.exists(replay_save_path): + os.makedirs(replay_save_path) + files = os.listdir(replay_save_path) + files = [file for file in files if file.endswith('.mp4')] + files.sort() + if concatenate_all_replay: + # create a file named 'files.txt' to store the names of all mp4 files + with open(os.path.join(replay_save_path, 'files.txt'), 'w') as f: + for file in files: + f.write("file '{}'\n".format(file)) + + # combine all the mp4 files into one mp4 file, rename it as 'deploy.mp4' + os.system( + 'ffmpeg -f concat -safe 0 -i {} -c copy {}/deploy.mp4'.format( + os.path.join(replay_save_path, 'files.txt'), replay_save_path + ) + ) + + return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list)) + + def batch_evaluate( + self, + n_evaluator_episode: int = None, + ) -> EvalReturn: + """ + Overview: + Perform a batch evaluation of the agent over a specified number of episodes: ``n_evaluator_episode``. + Arguments: + - n_evaluator_episode (:obj:`Optional[int]`): Number of episodes to run the evaluation. + If None, uses default value from configuration. Defaults to None. + Returns: + - An `EvalReturn` object with evaluation results such as mean and standard deviation of returns. + + .. note:: + This method evaluates the agent's performance across multiple episodes to gauge its effectiveness. 
+ """ + evaluator_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] + ) + + evaluator_env.seed(self.cfg.seed, dynamic_seed=False) + set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = self.cfg.policy + + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=self.cfg.env.n_evaluator_episode + if n_evaluator_episode is None else n_evaluator_episode, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + + stop, reward = evaluator.eval() + + return EvalReturn( + eval_value=np.mean(reward['eval_episode_return']), eval_value_std=np.std(reward['eval_episode_return']) + ) + + @property + def best(self): + """ + Overview: + Provides access to the best model according to evaluation metrics. + Returns: + - The agent with the best model loaded. + + .. note:: + The best model is saved in the path `./exp_name/ckpt/ckpt_best.pth.tar`. + When this property is accessed, the agent instance will load the best model state. + """ + + best_model_file_path = os.path.join(self.checkpoint_save_dir, "ckpt_best.pth.tar") + # Load best model if it exists + if os.path.exists(best_model_file_path): + policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu")) + self.policy.learn_mode.load_state_dict(policy_state_dict) + return self diff --git a/lzero/agent/muzero.py b/lzero/agent/muzero.py index 4ddb436ae..55dda5d00 100644 --- a/lzero/agent/muzero.py +++ b/lzero/agent/muzero.py @@ -31,9 +31,9 @@ class MuZeroAgent: Overview: Agent class for executing MuZero algorithms which include methods for training, deployment, and batch evaluation. Interfaces: - __init__, train, deploy, batch_evaluate + ``__init__``, ``train``, ``deploy``, ``batch_evaluate`` Properties: - best + ``best`` .. 
note:: This agent class is tailored for use with the HuggingFace Model Zoo for LightZero @@ -303,6 +303,7 @@ def deploy( self.exp_name, 'videos' ) deply_configs[0]['replay_path'] = replay_save_path + deply_configs[0]['save_replay'] = True for seed in seed_list: @@ -334,8 +335,10 @@ def deploy( reward_list.extend(reward['eval_episode_return']) if enable_save_replay: + if not os.path.exists(replay_save_path): + os.makedirs(replay_save_path) files = os.listdir(replay_save_path) - files = [file for file in files if file.endswith('0.mp4')] + files = [file for file in files if file.endswith('.mp4')] files.sort() if concatenate_all_replay: # create a file named 'files.txt' to store the names of all mp4 files diff --git a/lzero/agent/sampled_alphazero.py b/lzero/agent/sampled_alphazero.py new file mode 100644 index 000000000..dc76c16e5 --- /dev/null +++ b/lzero/agent/sampled_alphazero.py @@ -0,0 +1,381 @@ +import os +from functools import partial +from typing import Optional, Union, List +import copy +import numpy as np +import torch +from ding.bonus.common import TrainingReturn, EvalReturn +from ding.config import save_config_py, compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner, create_buffer +from ditk import logging +from easydict import EasyDict +from tensorboardX import SummaryWriter + +from lzero.agent.config.alphazero import supported_env_cfg +from lzero.policy import visit_count_temperature +from lzero.policy.alphazero import AlphaZeroPolicy +from lzero.worker import AlphaZeroCollector as Collector +from lzero.worker import AlphaZeroEvaluator as Evaluator + + +class SampledAlphaZeroAgent: + """ + Overview: + Agent class for executing AlphaZero algorithms which include methods for training, deployment, and batch evaluation. + Interfaces: + ``__init__``, ``train``, ``deploy``, ``batch_evaluate`` + Properties: + ``best`` + + .. note:: + This agent class is tailored for use with the HuggingFace Model Zoo for LightZero + (e.g. https://huggingface.co/OpenDILabCommunity/CartPole-v0-AlphaZero), + and provides methods such as "train" and "deploy". + """ + + supported_env_list = list(supported_env_cfg.keys()) + + def __init__( + self, + env_id: str = None, + seed: int = 0, + exp_name: str = None, + model: Optional[torch.nn.Module] = None, + cfg: Optional[Union[EasyDict, dict]] = None, + policy_state_dict: str = None, + ) -> None: + """ + Overview: + Initialize the SampledAlphaZeroAgent instance with environment parameters, model, and configuration. + Arguments: + - env_id (:obj:`str`): Identifier for the environment to be used, registered in gym. + - seed (:obj:`int`): Random seed for reproducibility. Defaults to 0. + - exp_name (:obj:`Optional[str]`): Name for the experiment. Defaults to None. + - model (:obj:`Optional[torch.nn.Module]`): PyTorch module to be used as the model. If None, a default model is created. Defaults to None. + - cfg (:obj:`Optional[Union[EasyDict, dict]]`): Configuration for the agent. If None, default configuration will be used. Defaults to None. + - policy_state_dict (:obj:`Optional[str]`): Path to a pre-trained model state dictionary. If provided, state dict will be loaded. Defaults to None. + + .. note:: + - If `env_id` is not specified, it must be included in `cfg`. + - The `supported_env_list` contains all the environment IDs that are supported by this agent. 
+ """ + assert env_id is not None or cfg["main_config"]["env_id"] is not None, "Please specify env_id or cfg." + + if cfg is not None and not isinstance(cfg, EasyDict): + cfg = EasyDict(cfg) + + if env_id is not None: + assert env_id in SampledAlphaZeroAgent.supported_env_list, "Please use supported envs: {}".format( + SampledAlphaZeroAgent.supported_env_list + ) + if cfg is None: + cfg = supported_env_cfg[env_id] + else: + assert cfg.main_config.env.env_id == env_id, "env_id in cfg should be the same as env_id in args." + else: + assert hasattr(cfg.main_config.env, "env_id"), "Please specify env_id in cfg." + assert cfg.main_config.env.env_id in SampledAlphaZeroAgent.supported_env_list, "Please use supported envs: {}".format( + SampledAlphaZeroAgent.supported_env_list + ) + default_policy_config = EasyDict({"policy": AlphaZeroPolicy.default_config()}) + default_policy_config.policy.update(cfg.main_config.policy) + cfg.main_config.policy = default_policy_config.policy + + if exp_name is not None: + cfg.main_config.exp_name = exp_name + self.origin_cfg = cfg + self.cfg = compile_config( + cfg.main_config, seed=seed, env=None, auto=True, policy=AlphaZeroPolicy, create_cfg=cfg.create_config + ) + self.exp_name = self.cfg.exp_name + + logging.getLogger().setLevel(logging.INFO) + self.seed = seed + set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda) + if not os.path.exists(self.exp_name): + os.makedirs(self.exp_name) + save_config_py(cfg, os.path.join(self.exp_name, 'policy_config.py')) + if model is None: + from lzero.model.alphazero_model import AlphaZeroModel + model = AlphaZeroModel(**self.cfg.policy.model) + + if self.cfg.policy.cuda and torch.cuda.is_available(): + self.cfg.policy.device = 'cuda' + else: + self.cfg.policy.device = 'cpu' + self.policy = create_policy(self.cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + if policy_state_dict is not None: + self.policy.learn_mode.load_state_dict(policy_state_dict) + self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt") + + self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env) + + def train( + self, + step: int = int(1e7), + ) -> TrainingReturn: + """ + Overview: + Train the agent through interactions with the environment. + Arguments: + - step (:obj:`int`): Total number of environment steps to train for. Defaults to 10 million (1e7). + Returns: + - A `TrainingReturn` object containing training information, such as logs and potentially a URL to a training dashboard. + .. note:: + The method involves interacting with the environment, collecting experience, and optimizing the model. + """ + + collector_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.collector_env_cfg] + ) + evaluator_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] + ) + + collector_env.seed(self.cfg.seed) + evaluator_env.seed(self.cfg.seed, dynamic_seed=False) + set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. 
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(self.cfg.exp_name), 'serial') + ) if get_rank() == 0 else None + learner = BaseLearner( + self.cfg.policy.learn.learner, self.policy.learn_mode, tb_logger, exp_name=self.cfg.exp_name + ) + replay_buffer = create_buffer(self.cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=self.cfg.exp_name) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = self.cfg.policy + batch_size = policy_config.batch_size + collector = Collector( + env=collector_env, + policy=self.policy.collect_mode, + tb_logger=tb_logger, + exp_name=self.cfg.exp_name, + ) + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=self.cfg.env.n_evaluator_episode, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + tb_logger=tb_logger, + exp_name=self.cfg.exp_name, + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. + learner.call_hook('before_run') + + if self.cfg.policy.update_per_collect is not None: + update_per_collect = self.cfg.policy.update_per_collect + + while True: + collect_kwargs = {} + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + # Evaluate policy performance. + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + new_data = sum(new_data, []) + + if self.cfg.policy.update_per_collect is None: + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. + collected_transitions_num = len(new_data) + update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio) + replay_buffer.push(new_data, cur_collector_envstep=collector.envstep) + + # Learn policy from collected data + for i in range(update_per_collect): + # Learner will train ``update_per_collect`` times in one iteration. + train_data = replay_buffer.sample(batch_size, learner.train_iter) + if train_data is None: + logging.warning( + 'The data in replay_buffer is not sufficient to sample a mini-batch.' + 'continue to collect now ....' + ) + break + + learner.train(train_data, collector.envstep) + + if collector.envstep >= step: + break + + # Learner's after_run hook. + learner.call_hook('after_run') + + return TrainingReturn(wandb_url=None) + + def deploy( + self, + enable_save_replay: bool = False, + concatenate_all_replay: bool = False, + replay_save_path: str = None, + seed: Optional[Union[int, List]] = None, + debug: bool = False + ) -> EvalReturn: + """ + Overview: + Deploy the agent for evaluation in the environment, with optional replay saving. The performance of the + agent will be evaluated. Average return and standard deviation of the return will be returned. + If `enable_save_replay` is True, replay videos are saved in the specified `replay_save_path`. 
+ Arguments: + - enable_save_replay (:obj:`bool`): Flag to enable saving of replay footage. Defaults to False. + - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one file. Defaults to False. + - replay_save_path (:obj:`Optional[str]`): Directory path to save replay videos. Defaults to None, which sets a default path. + - seed (:obj:`Optional[Union[int, List[int]]]`): Seed or list of seeds for environment reproducibility. Defaults to None. + - debug (:obj:`bool`): Whether to enable the debug mode. Default to False. + Returns: + - An `EvalReturn` object containing evaluation metrics such as mean and standard deviation of returns. + """ + + deply_configs = [copy.deepcopy(self.evaluator_env_cfg[0])] + + if type(seed) == int: + seed_list = [seed] + elif seed: + seed_list = seed + else: + seed_list = [0] + + reward_list = [] + + if enable_save_replay: + replay_save_path = replay_save_path if replay_save_path is not None else os.path.join( + self.exp_name, 'videos' + ) + deply_configs[0]['replay_path'] = replay_save_path + deply_configs[0]['save_replay'] = True + + for seed in seed_list: + + evaluator_env = create_env_manager(self.cfg.env.manager, [partial(self.env_fn, cfg=deply_configs[0])]) + + evaluator_env.seed(seed if seed is not None else self.cfg.seed, dynamic_seed=False) + set_pkg_seed(seed if seed is not None else self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=1, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + exp_name=self.cfg.exp_name, + ) + + # ============================================================== + # Main loop + # ============================================================== + + stop, reward = evaluator.eval() + reward_list.extend(reward['eval_episode_return']) + + if enable_save_replay: + if not os.path.exists(replay_save_path): + os.makedirs(replay_save_path) + files = os.listdir(replay_save_path) + files = [file for file in files if file.endswith('0.mp4')] + files.sort() + if concatenate_all_replay: + # create a file named 'files.txt' to store the names of all mp4 files + with open(os.path.join(replay_save_path, 'files.txt'), 'w') as f: + for file in files: + f.write("file '{}'\n".format(file)) + + # combine all the mp4 files into one mp4 file, rename it as 'deploy.mp4' + os.system( + 'ffmpeg -f concat -safe 0 -i {} -c copy {}/deploy.mp4'.format( + os.path.join(replay_save_path, 'files.txt'), replay_save_path + ) + ) + + return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list)) + + def batch_evaluate( + self, + n_evaluator_episode: int = None, + ) -> EvalReturn: + """ + Overview: + Perform a batch evaluation of the agent over a specified number of episodes: ``n_evaluator_episode``. + Arguments: + - n_evaluator_episode (:obj:`Optional[int]`): Number of episodes to run the evaluation. + If None, uses default value from configuration. Defaults to None. + Returns: + - An `EvalReturn` object with evaluation results such as mean and standard deviation of returns. + + .. note:: + This method evaluates the agent's performance across multiple episodes to gauge its effectiveness. 
+ """ + evaluator_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] + ) + + evaluator_env.seed(self.cfg.seed, dynamic_seed=False) + set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=self.cfg.env.n_evaluator_episode + if n_evaluator_episode is None else n_evaluator_episode, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + exp_name=self.cfg.exp_name, + ) + + # ============================================================== + # Main loop + # ============================================================== + + stop, reward = evaluator.eval() + + return EvalReturn( + eval_value=np.mean(reward['eval_episode_return']), eval_value_std=np.std(reward['eval_episode_return']) + ) + + @property + def best(self): + """ + Overview: + Provides access to the best model according to evaluation metrics. + Returns: + - The agent with the best model loaded. + + .. note:: + The best model is saved in the path `./exp_name/ckpt/ckpt_best.pth.tar`. + When this property is accessed, the agent instance will load the best model state. + """ + + best_model_file_path = os.path.join(self.checkpoint_save_dir, "ckpt_best.pth.tar") + # Load best model if it exists + if os.path.exists(best_model_file_path): + policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu")) + self.policy.learn_mode.load_state_dict(policy_state_dict) + return self diff --git a/lzero/agent/sampled_efficientzero.py b/lzero/agent/sampled_efficientzero.py new file mode 100644 index 000000000..079bdd11d --- /dev/null +++ b/lzero/agent/sampled_efficientzero.py @@ -0,0 +1,425 @@ +import os +from functools import partial +from typing import Optional, Union, List + +import numpy as np +import torch +from ding.bonus.common import TrainingReturn, EvalReturn +from ding.config import save_config_py, compile_config +from ding.envs import create_env_manager +from ding.envs import get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from ditk import logging +from easydict import EasyDict +from tensorboardX import SummaryWriter + +from lzero.agent.config.sampled_efficientzero import supported_env_cfg +from lzero.entry.utils import log_buffer_memory_usage, random_collect +from lzero.mcts import SampledEfficientZeroGameBuffer +from lzero.policy import visit_count_temperature +from lzero.policy.sampled_efficientzero import SampledEfficientZeroPolicy +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroCollector as Collector +from lzero.worker import MuZeroEvaluator as Evaluator + + +class SampledEfficientZeroAgent: + """ + Overview: + Agent class for executing Sampled EfficientZero algorithms which include methods for training, deployment, and batch evaluation. + Interfaces: + ``__init__``, ``train``, ``deploy``, ``batch_evaluate`` + Properties: + ``best`` + + .. note:: + This agent class is tailored for use with the HuggingFace Model Zoo for LightZero + (e.g. https://huggingface.co/OpenDILabCommunity/CartPole-v0-SampledEfficientZero), + and provides methods such as "train" and "deploy". 
+ """ + + supported_env_list = list(supported_env_cfg.keys()) + + def __init__( + self, + env_id: str = None, + seed: int = 0, + exp_name: str = None, + model: Optional[torch.nn.Module] = None, + cfg: Optional[Union[EasyDict, dict]] = None, + policy_state_dict: str = None, + ) -> None: + """ + Overview: + Initialize the SampledEfficientZeroAgent instance with environment parameters, model, and configuration. + Arguments: + - env_id (:obj:`str`): Identifier for the environment to be used, registered in gym. + - seed (:obj:`int`): Random seed for reproducibility. Defaults to 0. + - exp_name (:obj:`Optional[str]`): Name for the experiment. Defaults to None. + - model (:obj:`Optional[torch.nn.Module]`): PyTorch module to be used as the model. If None, a default model is created. Defaults to None. + - cfg (:obj:`Optional[Union[EasyDict, dict]]`): Configuration for the agent. If None, default configuration will be used. Defaults to None. + - policy_state_dict (:obj:`Optional[str]`): Path to a pre-trained model state dictionary. If provided, state dict will be loaded. Defaults to None. + + .. note:: + - If `env_id` is not specified, it must be included in `cfg`. + - The `supported_env_list` contains all the environment IDs that are supported by this agent. + """ + assert env_id is not None or cfg["main_config"]["env_id"] is not None, "Please specify env_id or cfg." + + if cfg is not None and not isinstance(cfg, EasyDict): + cfg = EasyDict(cfg) + + if env_id is not None: + assert env_id in SampledEfficientZeroAgent.supported_env_list, "Please use supported envs: {}".format( + SampledEfficientZeroAgent.supported_env_list + ) + if cfg is None: + cfg = supported_env_cfg[env_id] + else: + assert cfg.main_config.env.env_id == env_id, "env_id in cfg should be the same as env_id in args." + else: + assert hasattr(cfg.main_config.env, "env_id"), "Please specify env_id in cfg." 
+ assert cfg.main_config.env.env_id in SampledEfficientZeroAgent.supported_env_list, "Please use supported envs: {}".format( + SampledEfficientZeroAgent.supported_env_list + ) + default_policy_config = EasyDict({"policy": SampledEfficientZeroPolicy.default_config()}) + default_policy_config.policy.update(cfg.main_config.policy) + cfg.main_config.policy = default_policy_config.policy + + if exp_name is not None: + cfg.main_config.exp_name = exp_name + self.origin_cfg = cfg + self.cfg = compile_config( + cfg.main_config, seed=seed, env=None, auto=True, policy=SampledEfficientZeroPolicy, create_cfg=cfg.create_config + ) + self.exp_name = self.cfg.exp_name + + logging.getLogger().setLevel(logging.INFO) + self.seed = seed + set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda) + if not os.path.exists(self.exp_name): + os.makedirs(self.exp_name) + save_config_py(cfg, os.path.join(self.exp_name, 'policy_config.py')) + if model is None: + if self.cfg.policy.model.model_type == 'mlp': + from lzero.model.sampled_efficientzero_model_mlp import SampledEfficientZeroModelMLP + model = SampledEfficientZeroModelMLP(**self.cfg.policy.model) + elif self.cfg.policy.model.model_type == 'conv': + from lzero.model.sampled_efficientzero_model import SampledEfficientZeroModel + model = SampledEfficientZeroModel(**self.cfg.policy.model) + else: + raise NotImplementedError + if self.cfg.policy.cuda and torch.cuda.is_available(): + self.cfg.policy.device = 'cuda' + else: + self.cfg.policy.device = 'cpu' + self.policy = create_policy(self.cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + if policy_state_dict is not None: + self.policy.learn_mode.load_state_dict(policy_state_dict) + self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt") + + self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env) + + def train( + self, + step: int = int(1e7), + ) -> TrainingReturn: + """ + Overview: + Train the agent through interactions with the environment. + Arguments: + - step (:obj:`int`): Total number of environment steps to train for. Defaults to 10 million (1e7). + Returns: + - A `TrainingReturn` object containing training information, such as logs and potentially a URL to a training dashboard. + .. note:: + The method involves interacting with the environment, collecting experience, and optimizing the model. + """ + + collector_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.collector_env_cfg] + ) + evaluator_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] + ) + + collector_env.seed(self.cfg.seed) + evaluator_env.seed(self.cfg.seed, dynamic_seed=False) + set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # Create worker components: learner, collector, evaluator, replay buffer, commander. 
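        # Typical calls that exercise this method; 'Pendulum-v1' is an assumption
        # (it must appear in SampledEfficientZeroAgent.supported_env_list) and the
        # step budget is illustrative:
        #
        #     agent = SampledEfficientZeroAgent(env_id='Pendulum-v1', seed=0)
        #     agent.train(step=int(1e5))
        #     agent.deploy(enable_save_replay=True)   # writes .mp4 replays under <exp_name>/videos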
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(self.cfg.exp_name), 'serial') + ) if get_rank() == 0 else None + learner = BaseLearner( + self.cfg.policy.learn.learner, self.policy.learn_mode, tb_logger, exp_name=self.cfg.exp_name + ) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = self.cfg.policy + batch_size = policy_config.batch_size + # specific game buffer for MCTS+RL algorithms + replay_buffer = SampledEfficientZeroGameBuffer(policy_config) + collector = Collector( + env=collector_env, + policy=self.policy.collect_mode, + tb_logger=tb_logger, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=self.cfg.env.n_evaluator_episode, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + tb_logger=tb_logger, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + # Learner's before_run hook. + learner.call_hook('before_run') + + if self.cfg.policy.update_per_collect is not None: + update_per_collect = self.cfg.policy.update_per_collect + + # The purpose of collecting random data before training: + # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely. + # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms. + if self.cfg.policy.random_collect_episode_num > 0: + random_collect(self.cfg.policy, self.policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + + while True: + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + collect_kwargs = {} + # set temperature for visit count distributions according to the train_iter, + # please refer to Appendix D in MuZero paper for details. + collect_kwargs['temperature'] = visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ) + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + else: + collect_kwargs['epsilon'] = 0.0 + + # Evaluate policy performance. + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect data by default config n_sample/n_episode. + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + if self.cfg.policy.update_per_collect is None: + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio. 
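                # `new_data[0]` holds the game segments returned by the collector and
                # `len(game_segment)` is the number of transitions in one segment, so
                # the sum below is the total number of transitions gathered this round.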
+ collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]]) + update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio) + # save returned new_data collected by the collector + replay_buffer.push_game_segments(new_data) + # remove the oldest data if the replay buffer is full. + replay_buffer.remove_oldest_data_to_fit() + + # Learn policy from collected data. + for i in range(update_per_collect): + # Learner will train ``update_per_collect`` times in one iteration. + if replay_buffer.get_num_of_transitions() > batch_size: + train_data = replay_buffer.sample(batch_size, self.policy) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, ' + f'{replay_buffer} ' + f'continue to collect now ....' + ) + break + + # The core train steps for MCTS+RL algorithms. + log_vars = learner.train(train_data, collector.envstep) + + if self.cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + if collector.envstep >= step: + break + + # Learner's after_run hook. + learner.call_hook('after_run') + + return TrainingReturn(wandb_url=None) + + def deploy( + self, + enable_save_replay: bool = False, + concatenate_all_replay: bool = False, + replay_save_path: str = None, + seed: Optional[Union[int, List]] = None, + debug: bool = False + ) -> EvalReturn: + """ + Overview: + Deploy the agent for evaluation in the environment, with optional replay saving. The performance of the + agent will be evaluated. Average return and standard deviation of the return will be returned. + If `enable_save_replay` is True, replay videos are saved in the specified `replay_save_path`. + Arguments: + - enable_save_replay (:obj:`bool`): Flag to enable saving of replay footage. Defaults to False. + - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one file. Defaults to False. + - replay_save_path (:obj:`Optional[str]`): Directory path to save replay videos. Defaults to None, which sets a default path. + - seed (:obj:`Optional[Union[int, List[int]]]`): Seed or list of seeds for environment reproducibility. Defaults to None. + - debug (:obj:`bool`): Whether to enable the debug mode. Default to False. + Returns: + - An `EvalReturn` object containing evaluation metrics such as mean and standard deviation of returns. 
+ """ + + deply_configs = [self.evaluator_env_cfg[0]] + + if type(seed) == int: + seed_list = [seed] + elif seed: + seed_list = seed + else: + seed_list = [0] + + reward_list = [] + + if enable_save_replay: + replay_save_path = replay_save_path if replay_save_path is not None else os.path.join( + self.exp_name, 'videos' + ) + deply_configs[0]['replay_path'] = replay_save_path + deply_configs[0]['save_replay'] = True + + for seed in seed_list: + + evaluator_env = create_env_manager(self.cfg.env.manager, [partial(self.env_fn, cfg=deply_configs[0])]) + + evaluator_env.seed(seed if seed is not None else self.cfg.seed, dynamic_seed=False) + set_pkg_seed(seed if seed is not None else self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = self.cfg.policy + + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=1, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + + stop, reward = evaluator.eval() + reward_list.extend(reward['eval_episode_return']) + + if enable_save_replay: + if not os.path.exists(replay_save_path): + os.makedirs(replay_save_path) + files = os.listdir(replay_save_path) + files = [file for file in files if file.endswith('.mp4')] + files.sort() + if concatenate_all_replay: + # create a file named 'files.txt' to store the names of all mp4 files + with open(os.path.join(replay_save_path, 'files.txt'), 'w') as f: + for file in files: + f.write("file '{}'\n".format(file)) + + # combine all the mp4 files into one mp4 file, rename it as 'deploy.mp4' + os.system( + 'ffmpeg -f concat -safe 0 -i {} -c copy {}/deploy.mp4'.format( + os.path.join(replay_save_path, 'files.txt'), replay_save_path + ) + ) + + return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list)) + + def batch_evaluate( + self, + n_evaluator_episode: int = None, + ) -> EvalReturn: + """ + Overview: + Perform a batch evaluation of the agent over a specified number of episodes: ``n_evaluator_episode``. + Arguments: + - n_evaluator_episode (:obj:`Optional[int]`): Number of episodes to run the evaluation. + If None, uses default value from configuration. Defaults to None. + Returns: + - An `EvalReturn` object with evaluation results such as mean and standard deviation of returns. + + .. note:: + This method evaluates the agent's performance across multiple episodes to gauge its effectiveness. 
+ """ + evaluator_env = create_env_manager( + self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg] + ) + + evaluator_env.seed(self.cfg.seed, dynamic_seed=False) + set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda) + + # ============================================================== + # MCTS+RL algorithms related core code + # ============================================================== + policy_config = self.cfg.policy + + evaluator = Evaluator( + eval_freq=self.cfg.policy.eval_freq, + n_evaluator_episode=self.cfg.env.n_evaluator_episode + if n_evaluator_episode is None else n_evaluator_episode, + stop_value=self.cfg.env.stop_value, + env=evaluator_env, + policy=self.policy.eval_mode, + exp_name=self.cfg.exp_name, + policy_config=policy_config + ) + + # ============================================================== + # Main loop + # ============================================================== + + stop, reward = evaluator.eval() + + return EvalReturn( + eval_value=np.mean(reward['eval_episode_return']), eval_value_std=np.std(reward['eval_episode_return']) + ) + + @property + def best(self): + """ + Overview: + Provides access to the best model according to evaluation metrics. + Returns: + - The agent with the best model loaded. + + .. note:: + The best model is saved in the path `./exp_name/ckpt/ckpt_best.pth.tar`. + When this property is accessed, the agent instance will load the best model state. + """ + + best_model_file_path = os.path.join(self.checkpoint_save_dir, "ckpt_best.pth.tar") + # Load best model if it exists + if os.path.exists(best_model_file_path): + policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu")) + self.policy.learn_mode.load_state_dict(policy_state_dict) + return self diff --git a/lzero/entry/eval_alphazero.py b/lzero/entry/eval_alphazero.py index 486e2e6e5..b4f9259e9 100644 --- a/lzero/entry/eval_alphazero.py +++ b/lzero/entry/eval_alphazero.py @@ -87,7 +87,7 @@ def eval_alphazero( if print_seed_details: print("=" * 20) print(f'In seed {seed}, returns: {returns}') - if cfg.policy.simulation_env_name in ['tictactoe', 'connect4', 'gomoku', 'chess']: + if cfg.policy.simulation_env_id in ['tictactoe', 'connect4', 'gomoku', 'chess']: print( f'win rate: {len(np.where(returns == 1.)[0]) / num_episodes_each_seed}, draw rate: {len(np.where(returns == 0.)[0]) / num_episodes_each_seed}, lose rate: {len(np.where(returns == -1.)[0]) / num_episodes_each_seed}' ) diff --git a/lzero/entry/eval_muzero_with_gym_env.py b/lzero/entry/eval_muzero_with_gym_env.py index 663b4945a..c49e4f406 100644 --- a/lzero/entry/eval_muzero_with_gym_env.py +++ b/lzero/entry/eval_muzero_with_gym_env.py @@ -26,7 +26,7 @@ def eval_muzero_with_gym_env( """ Overview: The eval entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. - We create a gym environment using env_name parameter, and then convert it to the format + We create a gym environment using env_id parameter, and then convert it to the format required by LightZero using LightZeroEnvWrapper class. Please refer to the get_wrappered_env method for more details. 
Arguments: @@ -55,10 +55,10 @@ def eval_muzero_with_gym_env( collector_env_cfg = DingEnvWrapper.create_collector_env_cfg(cfg.env) evaluator_env_cfg = DingEnvWrapper.create_evaluator_env_cfg(cfg.env) collector_env = BaseEnvManager( - [get_wrappered_env(c, cfg.env.env_name) for c in collector_env_cfg], cfg=BaseEnvManager.default_config() + [get_wrappered_env(c, cfg.env.env_id) for c in collector_env_cfg], cfg=BaseEnvManager.default_config() ) evaluator_env = BaseEnvManager( - [get_wrappered_env(c, cfg.env.env_name) for c in evaluator_env_cfg], cfg=BaseEnvManager.default_config() + [get_wrappered_env(c, cfg.env.env_id) for c in evaluator_env_cfg], cfg=BaseEnvManager.default_config() ) collector_env.seed(cfg.seed) evaluator_env.seed(cfg.seed, dynamic_seed=False) diff --git a/lzero/entry/train_muzero_with_gym_env.py b/lzero/entry/train_muzero_with_gym_env.py index 1bfd855c5..3aa3906b9 100644 --- a/lzero/entry/train_muzero_with_gym_env.py +++ b/lzero/entry/train_muzero_with_gym_env.py @@ -27,7 +27,7 @@ def train_muzero_with_gym_env( """ Overview: The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero. - We create a gym environment using env_name parameter, and then convert it to the format required by LightZero using LightZeroEnvWrapper class. + We create a gym environment using env_id parameter, and then convert it to the format required by LightZero using LightZeroEnvWrapper class. Please refer to the get_wrappered_env method for more details. Arguments: - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. @@ -65,10 +65,10 @@ def train_muzero_with_gym_env( collector_env_cfg = DingEnvWrapper.create_collector_env_cfg(cfg.env) evaluator_env_cfg = DingEnvWrapper.create_evaluator_env_cfg(cfg.env) collector_env = BaseEnvManager( - [get_wrappered_env(c, cfg.env.env_name) for c in collector_env_cfg], cfg=BaseEnvManager.default_config() + [get_wrappered_env(c, cfg.env.env_id) for c in collector_env_cfg], cfg=BaseEnvManager.default_config() ) evaluator_env = BaseEnvManager( - [get_wrappered_env(c, cfg.env.env_name) for c in evaluator_env_cfg], cfg=BaseEnvManager.default_config() + [get_wrappered_env(c, cfg.env.env_id) for c in evaluator_env_cfg], cfg=BaseEnvManager.default_config() ) collector_env.seed(cfg.seed) evaluator_env.seed(cfg.seed, dynamic_seed=False) diff --git a/lzero/envs/get_wrapped_env.py b/lzero/envs/get_wrapped_env.py index 41e9262db..6422ec632 100644 --- a/lzero/envs/get_wrapped_env.py +++ b/lzero/envs/get_wrapped_env.py @@ -5,19 +5,19 @@ from lzero.envs.wrappers import ActionDiscretizationEnvWrapper, LightZeroEnvWrapper -def get_wrappered_env(wrapper_cfg: EasyDict, env_name: str): +def get_wrappered_env(wrapper_cfg: EasyDict, env_id: str): """ Overview: Returns a new environment with one or more wrappers applied to it. Arguments: - wrapper_cfg (:obj:`EasyDict`): A dictionary containing configuration settings for the wrappers. - - env_name (:obj:`str`): The name of the environment to create. + - env_id (:obj:`str`): The name of the environment to create. Returns: A callable that creates the wrapped environment. 
""" if wrapper_cfg.manually_discretization: return lambda: DingEnvWrapper( - gym.make(env_name), + gym.make(env_id), cfg={ 'env_wrapper': [ lambda env: ActionDiscretizationEnvWrapper(env, wrapper_cfg), lambda env: @@ -27,5 +27,5 @@ def get_wrappered_env(wrapper_cfg: EasyDict, env_name: str): ) else: return lambda: DingEnvWrapper( - gym.make(env_name), cfg={'env_wrapper': [lambda env: LightZeroEnvWrapper(env, wrapper_cfg)]} + gym.make(env_id), cfg={'env_wrapper': [lambda env: LightZeroEnvWrapper(env, wrapper_cfg)]} ) diff --git a/lzero/envs/tests/test_lightzero_env_wrapper.py b/lzero/envs/tests/test_lightzero_env_wrapper.py index 6440ef848..0ed365ef8 100644 --- a/lzero/envs/tests/test_lightzero_env_wrapper.py +++ b/lzero/envs/tests/test_lightzero_env_wrapper.py @@ -13,7 +13,7 @@ class TestLightZeroEnvWrapper: def test_continuous_pendulum(self): env_cfg = EasyDict( dict( - env_name='Pendulum-v1', + env_id='Pendulum-v1', manually_discretization=False, continuous=True, each_dim_disc_size=None, @@ -22,7 +22,7 @@ def test_continuous_pendulum(self): ) lightzero_env = DingEnvWrapper( - gym.make(env_cfg.env_name), cfg={'env_wrapper': [ + gym.make(env_cfg.env_id), cfg={'env_wrapper': [ lambda env: LightZeroEnvWrapper(env, env_cfg), ]} ) @@ -43,7 +43,7 @@ def test_continuous_pendulum(self): def test_discretization_pendulum(self): env_cfg = EasyDict( dict( - env_name='Pendulum-v1', + env_id='Pendulum-v1', manually_discretization=True, continuous=False, each_dim_disc_size=11, @@ -52,7 +52,7 @@ def test_discretization_pendulum(self): ) lightzero_env = DingEnvWrapper( - gym.make(env_cfg.env_name), + gym.make(env_cfg.env_id), cfg={ 'env_wrapper': [ lambda env: ActionDiscretizationEnvWrapper(env, env_cfg), @@ -77,7 +77,7 @@ def test_discretization_pendulum(self): def test_continuous_bipedalwalker(self): env_cfg = EasyDict( dict( - env_name='BipedalWalker-v3', + env_id='BipedalWalker-v3', manually_discretization=False, continuous=True, each_dim_disc_size=4, @@ -86,7 +86,7 @@ def test_continuous_bipedalwalker(self): ) lightzero_env = DingEnvWrapper( - gym.make(env_cfg.env_name), cfg={'env_wrapper': [ + gym.make(env_cfg.env_id), cfg={'env_wrapper': [ lambda env: LightZeroEnvWrapper(env, env_cfg), ]} ) @@ -107,7 +107,7 @@ def test_continuous_bipedalwalker(self): def test_discretization_bipedalwalker(self): env_cfg = EasyDict( dict( - env_name='BipedalWalker-v3', + env_id='BipedalWalker-v3', manually_discretization=True, continuous=False, each_dim_disc_size=4, @@ -116,7 +116,7 @@ def test_discretization_bipedalwalker(self): ) lightzero_env = DingEnvWrapper( - gym.make(env_cfg.env_name), + gym.make(env_cfg.env_id), cfg={ 'env_wrapper': [ lambda env: ActionDiscretizationEnvWrapper(env, env_cfg), diff --git a/lzero/envs/wrappers/action_discretization_env_wrapper.py b/lzero/envs/wrappers/action_discretization_env_wrapper.py index efd1e0fe5..433dc79aa 100644 --- a/lzero/envs/wrappers/action_discretization_env_wrapper.py +++ b/lzero/envs/wrappers/action_discretization_env_wrapper.py @@ -34,7 +34,7 @@ def __init__(self, env: gym.Env, cfg: EasyDict) -> None: assert 'is_train' in cfg, '`is_train` flag must set in the config of env' self.is_train = cfg.is_train self.cfg = cfg - self.env_name = cfg.env_name + self.env_id = cfg.env_id self.continuous = cfg.continuous def reset(self, **kwargs): diff --git a/lzero/envs/wrappers/lightzero_env_wrapper.py b/lzero/envs/wrappers/lightzero_env_wrapper.py index 82318eb05..47e37be3a 100644 --- a/lzero/envs/wrappers/lightzero_env_wrapper.py +++ 
b/lzero/envs/wrappers/lightzero_env_wrapper.py @@ -30,7 +30,7 @@ def __init__(self, env: gym.Env, cfg: EasyDict) -> None: assert 'is_train' in cfg, '`is_train` flag must set in the config of env' self.is_train = cfg.is_train self.cfg = cfg - self.env_name = cfg.env_name + self.env_id = cfg.env_id self.continuous = cfg.continuous def reset(self, **kwargs): diff --git a/lzero/mcts/tests/config/atari_efficientzero_config_for_test.py b/lzero/mcts/tests/config/atari_efficientzero_config_for_test.py index 1da966cc9..f751e8db2 100644 --- a/lzero/mcts/tests/config/atari_efficientzero_config_for_test.py +++ b/lzero/mcts/tests/config/atari_efficientzero_config_for_test.py @@ -1,6 +1,6 @@ from easydict import EasyDict -env_name = 'PongNoFrameskip-v4' +env_id = 'PongNoFrameskip-v4' action_space_size = 6 # ============================================================== @@ -21,7 +21,7 @@ atari_efficientzero_config = dict( exp_name='data_ez_ctree/efficientzero_seed0', env=dict( - env_name=env_name, + env_id=env_id, obs_shape=(4, 96, 96), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py b/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py index fbcc4a2de..7e6414895 100644 --- a/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py +++ b/lzero/mcts/tests/config/tictactoe_muzero_bot_mode_config_for_test.py @@ -23,7 +23,7 @@ evaluator_env_num=evaluator_env_num, n_evaluator_episode=evaluator_env_num, manager=dict(shared_memory=False, ), - env_name="TicTacToe", + env_id="TicTacToe", mcts_mode='self_play_mode', # only used in AlphaZero bot_action_type='v0', # {'v0', 'alpha_beta_pruning'} agent_vs_human=False, diff --git a/lzero/policy/alphazero.py b/lzero/policy/alphazero.py index e1b96ee64..4b57534a4 100644 --- a/lzero/policy/alphazero.py +++ b/lzero/policy/alphazero.py @@ -334,7 +334,7 @@ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]: return output def _get_simulation_env(self): - if self._cfg.simulation_env_name == 'tictactoe': + if self._cfg.simulation_env_id == 'tictactoe': from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv if self._cfg.simulation_env_config_type == 'play_with_bot': from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \ @@ -346,7 +346,7 @@ def _get_simulation_env(self): raise NotImplementedError self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env) - elif self._cfg.simulation_env_name == 'gomoku': + elif self._cfg.simulation_env_id == 'gomoku': from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv if self._cfg.simulation_env_config_type == 'play_with_bot': from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config @@ -355,7 +355,7 @@ def _get_simulation_env(self): else: raise NotImplementedError self.simulate_env = GomokuEnv(gomoku_alphazero_config.env) - elif self._cfg.simulation_env_name == 'connect4': + elif self._cfg.simulation_env_id == 'connect4': from zoo.board_games.connect4.envs.connect4_env import Connect4Env if self._cfg.simulation_env_config_type == 'play_with_bot': from zoo.board_games.connect4.config.connect4_alphazero_bot_mode_config import connect4_alphazero_config diff --git a/lzero/policy/sampled_alphazero.py b/lzero/policy/sampled_alphazero.py index 31a445fd8..6c1dbf708 100644 --- a/lzero/policy/sampled_alphazero.py +++ b/lzero/policy/sampled_alphazero.py @@ -460,10 +460,10 @@ def _forward_eval(self, obs: Dict) -> 
Dict[str, torch.Tensor]: return output def _get_simulation_env(self): - assert self._cfg.simulation_env_name in ['tictactoe', 'gomoku', 'go'], self._cfg.simulation_env_name + assert self._cfg.simulation_env_id in ['tictactoe', 'gomoku', 'go'], self._cfg.simulation_env_id assert self._cfg.simulation_env_config_type in ['play_with_bot', 'self_play', 'league', 'sampled_play_with_bot'], self._cfg.simulation_env_config_type - if self._cfg.simulation_env_name == 'tictactoe': + if self._cfg.simulation_env_id == 'tictactoe': from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv if self._cfg.simulation_env_config_type == 'play_with_bot': from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \ @@ -480,7 +480,7 @@ def _get_simulation_env(self): self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env) - elif self._cfg.simulation_env_name == 'gomoku': + elif self._cfg.simulation_env_id == 'gomoku': from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv if self._cfg.simulation_env_config_type == 'play_with_bot': from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config @@ -493,7 +493,7 @@ def _get_simulation_env(self): gomoku_sampled_alphazero_config as gomoku_alphazero_config self.simulate_env = GomokuEnv(gomoku_alphazero_config.env) - elif self._cfg.simulation_env_name == 'go': + elif self._cfg.simulation_env_id == 'go': from zoo.board_games.go.envs.go_env import GoEnv if self._cfg.simulation_env_config_type == 'play_with_bot': from zoo.board_games.go.config.go_alphazero_bot_mode_config import go_alphazero_config diff --git a/lzero/policy/tests/config/atari_muzero_config_for_test.py b/lzero/policy/tests/config/atari_muzero_config_for_test.py index 05c64d419..18c45cbc7 100644 --- a/lzero/policy/tests/config/atari_muzero_config_for_test.py +++ b/lzero/policy/tests/config/atari_muzero_config_for_test.py @@ -1,16 +1,16 @@ from easydict import EasyDict -env_name = 'PongNoFrameskip-v4' +env_id = 'PongNoFrameskip-v4' -if env_name == 'PongNoFrameskip-v4': +if env_id == 'PongNoFrameskip-v4': action_space_size = 6 -elif env_name == 'QbertNoFrameskip-v4': +elif env_id == 'QbertNoFrameskip-v4': action_space_size = 6 -elif env_name == 'MsPacmanNoFrameskip-v4': +elif env_id == 'MsPacmanNoFrameskip-v4': action_space_size = 9 -elif env_name == 'SpaceInvadersNoFrameskip-v4': +elif env_id == 'SpaceInvadersNoFrameskip-v4': action_space_size = 6 -elif env_name == 'BreakoutNoFrameskip-v4': +elif env_id == 'BreakoutNoFrameskip-v4': action_space_size = 4 # ============================================================== @@ -31,10 +31,10 @@ atari_muzero_config = dict( exp_name= - f'data_mz_ctree/{env_name[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( stop_value=int(1e6), - env_name=env_name, + env_id=env_id, obs_shape=(4, 96, 96), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/lzero/policy/tests/config/cartpole_muzero_config_for_test.py b/lzero/policy/tests/config/cartpole_muzero_config_for_test.py index b7584d19f..deaf1fddd 100644 --- a/lzero/policy/tests/config/cartpole_muzero_config_for_test.py +++ b/lzero/policy/tests/config/cartpole_muzero_config_for_test.py @@ -18,7 +18,7 @@ cartpole_muzero_config = dict( exp_name=f'data_mz_ctree/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( 
- env_name='CartPole-v0', + env_id='CartPole-v0', continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/requirements.txt b/requirements.txt index c4f701bd6..5c4ca8c33 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ DI-engine>=0.4.7 gymnasium[atari] +moviepy numpy>=1.22.4 pympler bsuite diff --git a/zoo/atari/config/atari_efficientzero_config.py b/zoo/atari/config/atari_efficientzero_config.py index aedd1edbd..1db811387 100644 --- a/zoo/atari/config/atari_efficientzero_config.py +++ b/zoo/atari/config/atari_efficientzero_config.py @@ -1,17 +1,17 @@ from easydict import EasyDict # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} -env_name = 'PongNoFrameskip-v4' +env_id = 'PongNoFrameskip-v4' -if env_name == 'PongNoFrameskip-v4': +if env_id == 'PongNoFrameskip-v4': action_space_size = 6 -elif env_name == 'QbertNoFrameskip-v4': +elif env_id == 'QbertNoFrameskip-v4': action_space_size = 6 -elif env_name == 'MsPacmanNoFrameskip-v4': +elif env_id == 'MsPacmanNoFrameskip-v4': action_space_size = 9 -elif env_name == 'SpaceInvadersNoFrameskip-v4': +elif env_id == 'SpaceInvadersNoFrameskip-v4': action_space_size = 6 -elif env_name == 'BreakoutNoFrameskip-v4': +elif env_id == 'BreakoutNoFrameskip-v4': action_space_size = 4 # ============================================================== @@ -33,9 +33,9 @@ atari_efficientzero_config = dict( exp_name= - f'data_ez_ctree/{env_name[:-14]}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_ez_ctree/{env_id[:-14]}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name=env_name, + env_id=env_id, obs_shape=(4, 96, 96), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py b/zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py index 876b8c012..633e471b3 100644 --- a/zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py +++ b/zoo/atari/config/atari_efficientzero_multigpu_ddp_config.py @@ -1,17 +1,17 @@ from easydict import EasyDict # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} -env_name = 'PongNoFrameskip-v4' +env_id = 'PongNoFrameskip-v4' -if env_name == 'PongNoFrameskip-v4': +if env_id == 'PongNoFrameskip-v4': action_space_size = 6 -elif env_name == 'QbertNoFrameskip-v4': +elif env_id == 'QbertNoFrameskip-v4': action_space_size = 6 -elif env_name == 'MsPacmanNoFrameskip-v4': +elif env_id == 'MsPacmanNoFrameskip-v4': action_space_size = 9 -elif env_name == 'SpaceInvadersNoFrameskip-v4': +elif env_id == 'SpaceInvadersNoFrameskip-v4': action_space_size = 6 -elif env_name == 'BreakoutNoFrameskip-v4': +elif env_id == 'BreakoutNoFrameskip-v4': action_space_size = 4 # ============================================================== @@ -42,9 +42,9 @@ atari_efficientzero_config = dict( exp_name= - f'data_ez_ctree/{env_name[:-14]}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_ddp_{gpu_num}gpu_seed0', + f'data_ez_ctree/{env_id[:-14]}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_ddp_{gpu_num}gpu_seed0', env=dict( - env_name=env_name, + env_id=env_id, obs_shape=(4, 96, 96), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, @@ -125,6 
+125,6 @@ with DDPContext(): # Each iteration uses a different seed for training # Change exp_name according to current seed - main_config.exp_name = f'data_ez_ctree/{env_name[:-14]}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_ddp_{gpu_num}gpu_seed{seed}' + main_config.exp_name = f'data_ez_ctree/{env_id[:-14]}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_ddp_{gpu_num}gpu_seed{seed}' main_config = lz_to_ddp_config(main_config) train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step) \ No newline at end of file diff --git a/zoo/atari/config/atari_gumbel_muzero_config.py b/zoo/atari/config/atari_gumbel_muzero_config.py index 4d39eeeda..918a93236 100644 --- a/zoo/atari/config/atari_gumbel_muzero_config.py +++ b/zoo/atari/config/atari_gumbel_muzero_config.py @@ -1,17 +1,17 @@ from easydict import EasyDict # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} -env_name = 'PongNoFrameskip-v4' +env_id = 'PongNoFrameskip-v4' -if env_name == 'PongNoFrameskip-v4': +if env_id == 'PongNoFrameskip-v4': action_space_size = 6 -elif env_name == 'QbertNoFrameskip-v4': +elif env_id == 'QbertNoFrameskip-v4': action_space_size = 6 -elif env_name == 'MsPacmanNoFrameskip-v4': +elif env_id == 'MsPacmanNoFrameskip-v4': action_space_size = 9 -elif env_name == 'SpaceInvadersNoFrameskip-v4': +elif env_id == 'SpaceInvadersNoFrameskip-v4': action_space_size = 6 -elif env_name == 'BreakoutNoFrameskip-v4': +elif env_id == 'BreakoutNoFrameskip-v4': action_space_size = 4 # ============================================================== @@ -31,10 +31,10 @@ atari_gumbel_muzero_config = dict( exp_name= - f'data_mz_ctree/{env_name[:-14]}_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_mz_ctree/{env_id[:-14]}_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( stop_value=int(1e6), - env_name=env_name, + env_id=env_id, obs_shape=(4, 96, 96), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/zoo/atari/config/atari_muzero_config.py b/zoo/atari/config/atari_muzero_config.py index 64ed37eaa..d6225b21e 100644 --- a/zoo/atari/config/atari_muzero_config.py +++ b/zoo/atari/config/atari_muzero_config.py @@ -1,17 +1,17 @@ from easydict import EasyDict # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} -env_name = 'PongNoFrameskip-v4' +env_id = 'PongNoFrameskip-v4' -if env_name == 'PongNoFrameskip-v4': +if env_id == 'PongNoFrameskip-v4': action_space_size = 6 -elif env_name == 'QbertNoFrameskip-v4': +elif env_id == 'QbertNoFrameskip-v4': action_space_size = 6 -elif env_name == 'MsPacmanNoFrameskip-v4': +elif env_id == 'MsPacmanNoFrameskip-v4': action_space_size = 9 -elif env_name == 'SpaceInvadersNoFrameskip-v4': +elif env_id == 'SpaceInvadersNoFrameskip-v4': action_space_size = 6 -elif env_name == 'BreakoutNoFrameskip-v4': +elif env_id == 'BreakoutNoFrameskip-v4': action_space_size = 4 # ============================================================== @@ -32,10 +32,10 @@ atari_muzero_config = dict( exp_name= - f'data_mz_ctree/{env_name[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( 
stop_value=int(1e6), - env_name=env_name, + env_id=env_id, obs_shape=(4, 96, 96), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/zoo/atari/config/atari_muzero_multigpu_ddp_config.py b/zoo/atari/config/atari_muzero_multigpu_ddp_config.py index b632ba7e0..91ef9b188 100644 --- a/zoo/atari/config/atari_muzero_multigpu_ddp_config.py +++ b/zoo/atari/config/atari_muzero_multigpu_ddp_config.py @@ -1,17 +1,17 @@ from easydict import EasyDict # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} -env_name = 'PongNoFrameskip-v4' +env_id = 'PongNoFrameskip-v4' -if env_name == 'PongNoFrameskip-v4': +if env_id == 'PongNoFrameskip-v4': action_space_size = 6 -elif env_name == 'QbertNoFrameskip-v4': +elif env_id == 'QbertNoFrameskip-v4': action_space_size = 6 -elif env_name == 'MsPacmanNoFrameskip-v4': +elif env_id == 'MsPacmanNoFrameskip-v4': action_space_size = 9 -elif env_name == 'SpaceInvadersNoFrameskip-v4': +elif env_id == 'SpaceInvadersNoFrameskip-v4': action_space_size = 6 -elif env_name == 'BreakoutNoFrameskip-v4': +elif env_id == 'BreakoutNoFrameskip-v4': action_space_size = 4 # ============================================================== @@ -33,10 +33,10 @@ atari_muzero_config = dict( exp_name= - f'data_mz_ctree/{env_name[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_ddp_{gpu_num}gpu_seed0', + f'data_mz_ctree/{env_id[:-14]}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_ddp_{gpu_num}gpu_seed0', env=dict( stop_value=int(1e6), - env_name=env_name, + env_id=env_id, obs_shape=(4, 96, 96), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/zoo/atari/config/atari_sampled_efficientzero_config.py b/zoo/atari/config/atari_sampled_efficientzero_config.py index 682f85009..442cd54a0 100644 --- a/zoo/atari/config/atari_sampled_efficientzero_config.py +++ b/zoo/atari/config/atari_sampled_efficientzero_config.py @@ -1,17 +1,17 @@ from easydict import EasyDict # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} -env_name = 'PongNoFrameskip-v4' +env_id = 'PongNoFrameskip-v4' -if env_name == 'PongNoFrameskip-v4': +if env_id == 'PongNoFrameskip-v4': action_space_size = 6 -elif env_name == 'QbertNoFrameskip-v4': +elif env_id == 'QbertNoFrameskip-v4': action_space_size = 6 -elif env_name == 'MsPacmanNoFrameskip-v4': +elif env_id == 'MsPacmanNoFrameskip-v4': action_space_size = 9 -elif env_name == 'SpaceInvadersNoFrameskip-v4': +elif env_id == 'SpaceInvadersNoFrameskip-v4': action_space_size = 6 -elif env_name == 'BreakoutNoFrameskip-v4': +elif env_id == 'BreakoutNoFrameskip-v4': action_space_size = 4 # ============================================================== @@ -33,9 +33,9 @@ atari_sampled_efficientzero_config = dict( exp_name= - f'data_sez_ctree/{env_name[:-14]}_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_sez_ctree/{env_id[:-14]}_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name=env_name, + env_id=env_id, obs_shape=(4, 96, 96), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/zoo/atari/config/atari_stochastic_muzero_config.py b/zoo/atari/config/atari_stochastic_muzero_config.py index 91dfe45c7..ab674c27d 100644 --- 
a/zoo/atari/config/atari_stochastic_muzero_config.py +++ b/zoo/atari/config/atari_stochastic_muzero_config.py @@ -1,17 +1,17 @@ from easydict import EasyDict # options={'PongNoFrameskip-v4', 'QbertNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SpaceInvadersNoFrameskip-v4', 'BreakoutNoFrameskip-v4', ...} -env_name = 'PongNoFrameskip-v4' +env_id = 'PongNoFrameskip-v4' -if env_name == 'PongNoFrameskip-v4': +if env_id == 'PongNoFrameskip-v4': action_space_size = 6 -elif env_name == 'QbertNoFrameskip-v4': +elif env_id == 'QbertNoFrameskip-v4': action_space_size = 6 -elif env_name == 'MsPacmanNoFrameskip-v4': +elif env_id == 'MsPacmanNoFrameskip-v4': action_space_size = 9 -elif env_name == 'SpaceInvadersNoFrameskip-v4': +elif env_id == 'SpaceInvadersNoFrameskip-v4': action_space_size = 6 -elif env_name == 'BreakoutNoFrameskip-v4': +elif env_id == 'BreakoutNoFrameskip-v4': action_space_size = 4 # ============================================================== @@ -43,10 +43,10 @@ atari_stochastic_muzero_config = dict( exp_name= - f'data_stochastic_mz_ctree/{env_name[:-14]}_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_chance{chance_space_size}_seed0', + f'data_stochastic_mz_ctree/{env_id[:-14]}_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_chance{chance_space_size}_seed0', env=dict( stop_value=int(1e6), - env_name=env_name, + env_id=env_id, obs_shape=(4, 96, 96), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index e2b19f7be..34dacce4a 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -31,7 +31,7 @@ class AtariLightZeroEnv(BaseEnv): # (int) The number of episodes to evaluate during each evaluation period. n_evaluator_episode=3, # (str) The name of the Atari game environment. - env_name='PongNoFrameskip-v4', + # env_id='PongNoFrameskip-v4', # (str) The type of the environment, here it's Atari. env_type='Atari', # (tuple) The shape of the observation space, which is a stacked frame of 4 images each of 96x96 pixels. @@ -211,7 +211,7 @@ def reward_space(self) -> gym.spaces.Space: return self._reward_space def __repr__(self) -> str: - return "LightZero Atari Env({})".format(self.cfg.env_name) + return "LightZero Atari Env({})".format(self.cfg.env_id) @staticmethod def create_collector_env_cfg(cfg: dict) -> List[dict]: diff --git a/zoo/atari/envs/atari_wrappers.py b/zoo/atari/envs/atari_wrappers.py index 1ac63122e..00ec07ca3 100644 --- a/zoo/atari/envs/atari_wrappers.py +++ b/zoo/atari/envs/atari_wrappers.py @@ -91,11 +91,12 @@ def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> - env (:obj:`gym.Env`): The wrapped Atari environment with the given configurations. 
""" if config.render_mode_human: - env = gymnasium.make(config.env_name, render_mode='human') + env = gymnasium.make(config.env_id, render_mode='human') else: - env = gymnasium.make(config.env_name, render_mode='rgb_array') + env = gymnasium.make(config.env_id, render_mode='rgb_array') assert 'NoFrameskip' in env.spec.id - if config.save_replay: + if hasattr(config, 'save_replay') and config.save_replay \ + and hasattr(config, 'replay_path') and config.replay_path is not None: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") video_name = f'{env.spec.id}-video-{timestamp}' env = RecordVideo( diff --git a/zoo/atari/tests/test_atari_lightzero_env.py b/zoo/atari/tests/test_atari_lightzero_env.py index c3252e829..258f0d7b8 100644 --- a/zoo/atari/tests/test_atari_lightzero_env.py +++ b/zoo/atari/tests/test_atari_lightzero_env.py @@ -6,7 +6,7 @@ collector_env_num=8, evaluator_env_num=3, n_evaluator_episode=3, - env_name='PongNoFrameskip-v4', + env_id='PongNoFrameskip-v4', env_type='Atari', obs_shape=(4, 96, 96), collect_max_episode_steps=int(1.08e5), diff --git a/zoo/atari/tests/test_atari_sampled_efficientzero_config.py b/zoo/atari/tests/test_atari_sampled_efficientzero_config.py index 122a041f9..2c76943cb 100644 --- a/zoo/atari/tests/test_atari_sampled_efficientzero_config.py +++ b/zoo/atari/tests/test_atari_sampled_efficientzero_config.py @@ -1,16 +1,16 @@ from easydict import EasyDict -env_name = 'BreakoutNoFrameskip-v4' +env_id = 'BreakoutNoFrameskip-v4' -if env_name == 'PongNoFrameskip-v4': +if env_id == 'PongNoFrameskip-v4': action_space_size = 6 -elif env_name == 'QbertNoFrameskip-v4': +elif env_id == 'QbertNoFrameskip-v4': action_space_size = 6 -elif env_name == 'MsPacmanNoFrameskip-v4': +elif env_id == 'MsPacmanNoFrameskip-v4': action_space_size = 9 -elif env_name == 'SpaceInvadersNoFrameskip-v4': +elif env_id == 'SpaceInvadersNoFrameskip-v4': action_space_size = 6 -elif env_name == 'BreakoutNoFrameskip-v4': +elif env_id == 'BreakoutNoFrameskip-v4': action_space_size = 4 # ============================================================== @@ -32,9 +32,9 @@ atari_sampled_efficientzero_config = dict( exp_name= - f'data_sez_ctree/{env_name[:-14]}_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', + f'data_sez_ctree/{env_id[:-14]}_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name=env_name, + env_id=env_id, obs_shape=(4, 96, 96), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, diff --git a/zoo/board_games/connect4/config/connect4_alphazero_bot_mode_config.py b/zoo/board_games/connect4/config/connect4_alphazero_bot_mode_config.py index e0eae45a7..806fe07bc 100644 --- a/zoo/board_games/connect4/config/connect4_alphazero_bot_mode_config.py +++ b/zoo/board_games/connect4/config/connect4_alphazero_bot_mode_config.py @@ -42,7 +42,7 @@ mcts_ctree=mcts_ctree, # ============================================================== # for the creation of simulation env - simulation_env_name='connect4', + simulation_env_id='connect4', simulation_env_config_type='play_with_bot', # ============================================================== model=dict( diff --git a/zoo/board_games/connect4/config/connect4_alphazero_sp_mode_config.py b/zoo/board_games/connect4/config/connect4_alphazero_sp_mode_config.py index 6ef7a8f3a..31ab963d1 100644 --- a/zoo/board_games/connect4/config/connect4_alphazero_sp_mode_config.py +++ 
b/zoo/board_games/connect4/config/connect4_alphazero_sp_mode_config.py @@ -43,7 +43,7 @@ mcts_ctree=mcts_ctree, # ============================================================== # for the creation of simulation env - simulation_env_name='connect4', + simulation_env_id='connect4', simulation_env_config_type='self_play', # ============================================================== model=dict( diff --git a/zoo/board_games/connect4/envs/connect4_env.py b/zoo/board_games/connect4/envs/connect4_env.py index 259d1346b..f8a46431b 100644 --- a/zoo/board_games/connect4/envs/connect4_env.py +++ b/zoo/board_games/connect4/envs/connect4_env.py @@ -51,7 +51,7 @@ class Connect4Env(BaseEnv): config = dict( # (str) The name of the environment registered in the environment registry. - env_name="Connect4", + env_id="Connect4", # (str) The mode of the environment when take a step. battle_mode='self_play_mode', # (str) The mode of the environment when doing the MCTS. diff --git a/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py index 90c1a9d7c..c0606cfdb 100644 --- a/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py @@ -45,7 +45,7 @@ mcts_ctree=mcts_ctree, # ============================================================== # for the creation of simulation env - simulation_env_name='gomoku', + simulation_env_id='gomoku', simulation_env_config_type='play_with_bot', # ============================================================== torch_compile=False, diff --git a/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py b/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py index 7a8e086e8..f93831567 100644 --- a/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py @@ -45,7 +45,7 @@ mcts_ctree=mcts_ctree, # ============================================================== # for the creation of simulation env - simulation_env_name='gomoku', + simulation_env_id='gomoku', simulation_env_config_type='self_play', # ============================================================== torch_compile=False, diff --git a/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_bot_mode_config.py index e0f5e72c3..b2ceb9a5f 100644 --- a/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_bot_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_bot_mode_config.py @@ -42,7 +42,7 @@ prob_expert_agent=0, scale=True, check_action_to_connect4_in_bot_v0=False, - simulation_env_name="gomoku", + simulation_env_id="gomoku", screen_scaling=9, render_mode=None, replay_path=None, @@ -52,7 +52,7 @@ policy=dict( # ============================================================== # for the creation of simulation env - simulation_env_name='gomoku', + simulation_env_id='gomoku', simulation_env_config_type='sampled_play_with_bot', # ============================================================== torch_compile=False, diff --git a/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_sp_mode_config.py b/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_sp_mode_config.py index ccf1df378..4e1820f7f 100644 --- a/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_sp_mode_config.py +++ b/zoo/board_games/gomoku/config/gomoku_sampled_alphazero_sp_mode_config.py @@ -38,7 +38,7 @@ 
prob_expert_agent=0, scale=True, check_action_to_connect4_in_bot_v0=False, - simulation_env_name="gomoku", + simulation_env_id="gomoku", screen_scaling=9, render_mode=None, replay_path=None, @@ -48,7 +48,7 @@ policy=dict( # ============================================================== # for the creation of simulation env - simulation_env_name='gomoku', + simulation_env_id='gomoku', simulation_env_config_type='sampled_self_play', # ============================================================== torch_compile=False, diff --git a/zoo/board_games/gomoku/envs/gomoku_env.py b/zoo/board_games/gomoku/envs/gomoku_env.py index d31da95e3..e1347c337 100644 --- a/zoo/board_games/gomoku/envs/gomoku_env.py +++ b/zoo/board_games/gomoku/envs/gomoku_env.py @@ -1,11 +1,14 @@ import copy import os import sys +from datetime import datetime from functools import lru_cache from typing import List, Any import gymnasium as gym import imageio +import matplotlib +matplotlib.use('Agg') import matplotlib.patches as patches import matplotlib.pyplot as plt import numpy as np @@ -54,7 +57,7 @@ class GomokuEnv(BaseEnv): config = dict( # (str) The name of the environment registered in the environment registry. - env_name="Gomoku", + env_id="Gomoku", # (int) The size of the board. board_size=6, # (str) The mode of the environment when take a step. @@ -136,9 +139,11 @@ def __init__(self, cfg: dict = None): self.screen_scaling = cfg.screen_scaling # options = {None, 'state_realtime_mode', 'image_realtime_mode', 'image_savefile_mode'} self.render_mode = cfg.render_mode - self.replay_name_suffix = "test" + assert self.render_mode in [None, 'state_realtime_mode', 'image_realtime_mode', 'image_savefile_mode'] + self.replay_name_suffix = "" if hasattr(cfg, 'replay_name_suffix') is False else cfg.replay_name_suffix self.replay_path = cfg.replay_path - self.replay_format = 'gif' # 'mp4' # + self.replay_format = 'gif' if hasattr(cfg, 'replay_format') is False else cfg.replay_format + assert self.replay_format in ['gif', 'mp4'] self.screen = None self.frames = [] @@ -158,6 +163,7 @@ def __init__(self, cfg: dict = None): # plt is not work in mcts_ctree mode self.fig, self.ax = plt.subplots(figsize=(self.board_size, self.board_size)) plt.ion() + self._save_replay_count = 0 def reset(self, start_player_index=0, init_state=None, katago_policy_init=False, katago_game_state=None): """ @@ -338,7 +344,7 @@ def _player_step(self, action): if done: info['eval_episode_return'] = reward self._env.render(self.render_mode) - if self.render_mode == 'image_savefile_mode': + if self.render_mode == 'image_savefile_mode' and self.replay_path is not None: self.save_render_output(replay_name_suffix=self.replay_name_suffix, replay_path=self.replay_path, format=self.replay_format) @@ -598,7 +604,7 @@ def render(self, mode="state_realtime_mode"): print(np.array(self.board).reshape(self.board_size, self.board_size)) return # Render the game as an image - elif mode == "image_realtime_mode" or mode == "image_savefile_mode": + elif mode == "image_realtime_mode" or (mode == "image_savefile_mode" and self.replay_path is not None): self.draw_board() # Draw the pieces on the board for x in range(self.board_size): @@ -723,13 +729,28 @@ def save_render_output(self, replay_name_suffix: str = '', replay_path: str = No - replay_path (:obj:`str`): The path to save the replay file. If None, the default filename will be used. - format (:obj:`str`): The format of the output file. Options are 'gif' or 'mp4'. 
""" + timestamp = datetime.now().strftime("%Y%m%d%H%M%S") + # At the end of the episode, save the frames. - if replay_path is None: - filename = f'gomoku_{self.board_size}_{replay_name_suffix}.{format}' + if replay_name_suffix == '': + if replay_path is None: + filename = f'gomoku_{self.board_size}_{os.getpid()}_{timestamp}.{format}' + else: + if not os.path.exists(replay_path): + os.makedirs(replay_path) + filename = os.path.join( + replay_path, + f'gomoku_{self.board_size}_{os.getpid()}_{timestamp}.{format}' + ) else: - if not os.path.exists(replay_path): - os.makedirs(replay_path) - filename = replay_path+f'/gomoku_{self.board_size}_{replay_name_suffix}.{format}' + if replay_path is None: + filename = f'gomoku_{self.board_size}_{replay_name_suffix}.{format}' + else: + if not os.path.exists(replay_path): + os.makedirs(replay_path) + filename = replay_path+f'/gomoku_{self.board_size}_{replay_name_suffix}.{format}' + + self._save_replay_count += 1 if format == 'gif': # Save frames as a GIF with a duration of 0.1 seconds per frame. @@ -737,7 +758,8 @@ def save_render_output(self, replay_name_suffix: str = '', replay_path: str = No imageio.mimsave(filename, self.frames, 'GIF', fps=30, subrectangles=True) elif format == 'mp4': # Save frames as an MP4 video with a frame rate of 30 frames per second. - imageio.mimsave(filename, self.frames, fps=30, codec='mpeg4') + # imageio.mimsave(filename, self.frames, fps=30, codec='mpeg4') + imageio.mimwrite(filename, self.frames, fps=30) else: raise ValueError("Unsupported format: {}".format(format)) diff --git a/zoo/board_games/gomoku/test/test_gomoku_env_legal_actions.py b/zoo/board_games/gomoku/test/test_gomoku_env_legal_actions.py index bb9ea4b22..37cc14428 100644 --- a/zoo/board_games/gomoku/test/test_gomoku_env_legal_actions.py +++ b/zoo/board_games/gomoku/test/test_gomoku_env_legal_actions.py @@ -21,7 +21,7 @@ def test_self_play_mode(self): prob_random_action_in_bot=0., check_action_to_connect4_in_bot_v0=False, prob_expert_agent=0, - simulation_env_name="gomoku", + simulation_env_id="gomoku", screen_scaling=9, render_mode=None, replay_path=None, diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py index 4c4c1e2a1..808e5cfb2 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py @@ -42,7 +42,7 @@ mcts_ctree=mcts_ctree, # ============================================================== # for the creation of simulation env - simulation_env_name='tictactoe', + simulation_env_id='tictactoe', simulation_env_config_type='play_with_bot', # ============================================================== model=dict( diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py index 6f2f9f7be..c94394474 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py @@ -37,7 +37,7 @@ policy=dict( # ============================================================== # for the creation of simulation env - simulation_env_name='tictactoe', + simulation_env_id='tictactoe', simulation_env_config_type='play_with_bot', # ============================================================== model=dict( diff --git 
a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py index f74b10492..9290debee 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py @@ -40,7 +40,7 @@ mcts_ctree=mcts_ctree, # ============================================================== # for the creation of simulation env - simulation_env_name='tictactoe', + simulation_env_id='tictactoe', simulation_env_config_type='self_play', # ============================================================== model=dict( diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py index dd324cb65..d6edd3274 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py @@ -36,7 +36,7 @@ policy=dict( # ============================================================== # for the creation of simulation env - simulation_env_name='tictactoe', + simulation_env_id='tictactoe', simulation_env_config_type='self_play', # ============================================================== model=dict( diff --git a/zoo/board_games/tictactoe/config/tictactoe_sampled_alphazero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_sampled_alphazero_bot_mode_config.py index baecaaeb8..3bf11d4ef 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_sampled_alphazero_bot_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_sampled_alphazero_bot_mode_config.py @@ -41,7 +41,7 @@ policy=dict( # ============================================================== # for the creation of simulation env - simulation_env_name='tictactoe', + simulation_env_id='tictactoe', simulation_env_config_type='play_with_bot', # ============================================================== model=dict( diff --git a/zoo/board_games/tictactoe/config/tictactoe_sampled_alphazero_sp_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_sampled_alphazero_sp_mode_config.py index 2491330fb..69f99cd93 100644 --- a/zoo/board_games/tictactoe/config/tictactoe_sampled_alphazero_sp_mode_config.py +++ b/zoo/board_games/tictactoe/config/tictactoe_sampled_alphazero_sp_mode_config.py @@ -40,7 +40,7 @@ policy=dict( # ============================================================== # for the creation of simulation env - simulation_env_name='tictactoe', + simulation_env_id='tictactoe', simulation_env_config_type='self_play', # ============================================================== model=dict( diff --git a/zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval.py b/zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval.py index 1c6de7850..b9f401527 100644 --- a/zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval.py +++ b/zoo/board_games/tictactoe/entry/tictactoe_alphazero_eval.py @@ -23,8 +23,7 @@ num_episodes_each_seed = 1 # Enable saving of replay as a gif, specify the path to save the replay gif - main_config.env.save_replay_gif = True - main_config.env.replay_path_gif = './video' + main_config.env.replay_path = './video' main_config.policy.mcts_ctree = False # If True, you can play with the agent. 
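For readers skimming the patch: the eval entry above (and the MuZero one that follows) replaces the removed `save_replay_gif` / `replay_path_gif` switches with the single `replay_path` field. The short sketch below is not part of the diff; `main_config` is a placeholder standing in for any board-game eval config, shown only to illustrate the intended usage after this change.

```python
from easydict import EasyDict

# Placeholder config: in the real entries this comes from a zoo config such as
# tictactoe_alphazero_bot_mode_config; nested dicts become EasyDicts recursively.
main_config = EasyDict(dict(env=dict(replay_path=None), policy=dict(mcts_ctree=True)))

# One field now controls replay saving: a folder path enables it,
# None (the new default in the env configs) disables it.
main_config.env.replay_path = './video'

# The eval entries also switch to the Python MCTS before rendering,
# as in the hunk above.
main_config.policy.mcts_ctree = False
```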
diff --git a/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py b/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py index f2bebb0bd..d99bf86d4 100644 --- a/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py +++ b/zoo/board_games/tictactoe/entry/tictactoe_muzero_eval.py @@ -31,8 +31,7 @@ total_test_episodes = num_episodes_each_seed * len(seeds) # Enable saving of replay as a gif, specify the path to save the replay gif - main_config.env.save_replay_gif = True - main_config.env.replay_path_gif = './video' + main_config.env.replay_path = './video' returns_mean_seeds = [] returns_seeds = [] diff --git a/zoo/board_games/tictactoe/envs/tictactoe_env.py b/zoo/board_games/tictactoe/envs/tictactoe_env.py index 5de3e5915..b01ade517 100644 --- a/zoo/board_games/tictactoe/envs/tictactoe_env.py +++ b/zoo/board_games/tictactoe/envs/tictactoe_env.py @@ -40,18 +40,16 @@ def _get_done_winner_func_lru(board_tuple): class TicTacToeEnv(BaseEnv): config = dict( - # env_name (str): The name of the environment. - env_name="TicTacToe", + # env_id (str): The name of the environment. + env_id="TicTacToe", # battle_mode (str): The mode of the battle. Choices are 'self_play_mode' or 'alpha_beta_pruning'. battle_mode='self_play_mode', # battle_mode_in_simulation_env (str): The mode of Monte Carlo Tree Search. This is only used in AlphaZero. battle_mode_in_simulation_env='self_play_mode', # bot_action_type (str): The type of action the bot should take. Choices are 'v0' or 'alpha_beta_pruning'. bot_action_type='v0', - # save_replay_gif (bool): If True, the replay will be saved as a gif file. - save_replay_gif=False, - # replay_path_gif (str): The path to save the replay gif. - replay_path_gif='./replay_gif', + # replay_path (str): The folder path where replay video saved, if None, will not save replay video. + replay_path=None, # agent_vs_human (bool): If True, the agent will play against a human. agent_vs_human=False, # prob_random_agent (int): The probability of the random agent. @@ -97,8 +95,7 @@ def __init__(self, cfg=None): if 'alpha_beta_pruning' in self.bot_action_type: self.alpha_beta_pruning_player = AlphaBetaPruningBot(self, cfg, 'alpha_beta_pruning_player') self.alphazero_mcts_ctree = cfg.alphazero_mcts_ctree - self._replay_path_gif = cfg.replay_path_gif - self._save_replay_gif = cfg.save_replay_gif + self._replay_path = cfg.replay_path if hasattr(cfg, "replay_path") and cfg.replay_path is not None else None self._save_replay_count = 0 @property @@ -192,7 +189,7 @@ def reset(self, start_player_index=0, init_state=None, katago_policy_init=False, 'current_player_index': self.start_player_index, 'to_play': self.current_player } - if self._save_replay_gif: + if self._replay_path is not None: self._frames = [] return obs @@ -254,7 +251,7 @@ def step(self, action): # player 1 battle with expert player 2 # player 1's turn - if self._save_replay_gif: + if self._replay_path is not None: self._frames.append(self._env.render(mode='rgb_array')) timestep_player1 = self._player_step(action) # self.env.render() @@ -263,16 +260,16 @@ def step(self, action): # And the to_play is used in MCTS. 
timestep_player1.obs['to_play'] = -1 - if self._save_replay_gif: - if not os.path.exists(self._replay_path_gif): - os.makedirs(self._replay_path_gif) + if self._replay_path is not None: + if not os.path.exists(self._replay_path): + os.makedirs(self._replay_path) timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( - self._replay_path_gif, - 'tictactoe_episode_{}_{}.gif'.format(self._save_replay_count, timestamp) + self._replay_path, + 'tictactoe_{}_{}_{}.mp4'.format(os.getpid(), timestamp, self._save_replay_count) ) - self.display_frames_as_gif(self._frames, path) - print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') + self.display_frames_as_mp4(self._frames, path) + print(f'replay {path} saved!') self._save_replay_count += 1 return timestep_player1 @@ -283,10 +280,10 @@ def step(self, action): else: bot_action = self.bot_action() # print('player 2 (computer player): ' + self.action_to_string(bot_action)) - if self._save_replay_gif: + if self._replay_path is not None: self._frames.append(self._env.render(mode='rgb_array')) timestep_player2 = self._player_step(bot_action) - if self._save_replay_gif: + if self._replay_path is not None: self._frames.append(self._env.render(mode='rgb_array')) # the eval_episode_return is calculated from Player 1's perspective timestep_player2.info['eval_episode_return'] = -timestep_player2.reward @@ -298,16 +295,16 @@ def step(self, action): timestep.obs['to_play'] = -1 if timestep_player2.done: - if self._save_replay_gif: - if not os.path.exists(self._replay_path_gif): - os.makedirs(self._replay_path_gif) + if self._replay_path is not None: + if not os.path.exists(self._replay_path): + os.makedirs(self._replay_path) timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( - self._replay_path_gif, - 'tictactoe_episode_{}_{}.gif'.format(self._save_replay_count, timestamp) + self._replay_path, + 'tictactoe_{}_{}_{}.mp4'.format(os.getpid(), timestamp, self._save_replay_count) ) - self.display_frames_as_gif(self._frames, path) - print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') + self.display_frames_as_mp4(self._frames, path) + print(f'replay {path} saved!') self._save_replay_count += 1 return timestep @@ -703,6 +700,12 @@ def display_frames_as_gif(frames: list, path: str) -> None: import imageio imageio.mimsave(path, frames, fps=20) + @staticmethod + def display_frames_as_mp4(frames: list, path: str, fps=5) -> None: + assert path.endswith('.mp4'), f'path must end with .mp4, but got {path}' + import imageio + imageio.mimwrite(path, frames, fps=fps) + def clone(self): return copy.deepcopy(self) diff --git a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py index 6d5fbdfa6..929d5594f 100644 --- a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py +++ b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_efficientzero_config.py @@ -23,7 +23,7 @@ f'data_sez_ctree/bipedalwalker_cont_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_seed0', env=dict( stop_value=int(1e6), - env_name='BipedalWalker-v3', + env_id='BipedalWalker-v3', env_type='normal', manually_discretization=True, continuous=False, diff --git a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py index 
8aeea2c7b..f20558b3a 100644 --- a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py +++ b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_disc_sampled_efficientzero_config.py @@ -24,7 +24,7 @@ f'data_sez_ctree/bipedalwalker_cont_disc_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_seed0', env=dict( stop_value=int(1e6), - env_name='BipedalWalker-v3', + env_id='BipedalWalker-v3', env_type='normal', continuous=True, manually_discretization=True, diff --git a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py index f61ccfed3..a61856c44 100644 --- a/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py +++ b/zoo/box2d/bipedalwalker/config/bipedalwalker_cont_sampled_efficientzero_config.py @@ -22,7 +22,7 @@ exp_name= f'data_sez_ctree/bipedalwalker_cont_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='BipedalWalker-v3', + env_id='BipedalWalker-v3', env_type='normal', continuous=True, manually_discretization=False, diff --git a/zoo/box2d/bipedalwalker/envs/bipedalwalker_cont_disc_env.py b/zoo/box2d/bipedalwalker/envs/bipedalwalker_cont_disc_env.py index eb0e994f9..acf349119 100644 --- a/zoo/box2d/bipedalwalker/envs/bipedalwalker_cont_disc_env.py +++ b/zoo/box2d/bipedalwalker/envs/bipedalwalker_cont_disc_env.py @@ -37,7 +37,7 @@ def default_config(cls: type) -> EasyDict: config = dict( # (str) The gym environment name. - env_name="BipedalWalker-v3", + env_id="BipedalWalker-v3", # (int) The number of bins for each dimension of the action space. each_dim_disc_size=4, # (bool) If True, save the replay as a gif file. @@ -65,7 +65,7 @@ def __init__(self, cfg: dict) -> None: """ self._cfg = cfg self._init_flag = False - self._env_name = cfg.env_name + self._env_id = cfg.env_id self._act_scale = cfg.act_scale self._rew_clip = cfg.rew_clip self._replay_path = cfg.replay_path @@ -159,7 +159,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( self._replay_path_gif, - '{}_episode_{}_seed{}_{}.gif'.format(self._env_name, self._save_replay_count, self._seed, timestamp) + '{}_episode_{}_seed{}_{}.gif'.format(self._env_id, self._save_replay_count, self._seed, timestamp) ) self.display_frames_as_gif(self._frames, path) print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') diff --git a/zoo/box2d/bipedalwalker/envs/bipedalwalker_env.py b/zoo/box2d/bipedalwalker/envs/bipedalwalker_env.py index b7b8621af..225bf975f 100644 --- a/zoo/box2d/bipedalwalker/envs/bipedalwalker_env.py +++ b/zoo/box2d/bipedalwalker/envs/bipedalwalker_env.py @@ -24,7 +24,7 @@ class BipedalWalkerEnv(CartPoleEnv): config = dict( # (str) The gym environment name. - env_name="BipedalWalker-v3", + env_id="BipedalWalker-v3", # (str) The type of the environment. Options: {'normal', 'hardcore'} env_type='normal', # (bool) If True, save the replay as a gif file. @@ -61,11 +61,11 @@ def __init__(self, cfg: dict) -> None: Overview: Initialize the BipedalWalker environment. Arguments: - - cfg (:obj:`dict`): Configuration dict. The dict should include keys like 'env_name', 'replay_path', etc. + - cfg (:obj:`dict`): Configuration dict. The dict should include keys like 'env_id', 'replay_path', etc. 
""" self._cfg = cfg self._init_flag = False - self._env_name = cfg.env_name + self._env_id = cfg.env_id self._act_scale = cfg.act_scale self._rew_clip = cfg.rew_clip self._replay_path = cfg.replay_path @@ -155,7 +155,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( self._replay_path_gif, - '{}_episode_{}_seed{}_{}.gif'.format(self._env_name, self._save_replay_count, self._seed, timestamp) + '{}_episode_{}_seed{}_{}.gif'.format(self._env_id, self._save_replay_count, self._seed, timestamp) ) self.display_frames_as_gif(self._frames, path) print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') diff --git a/zoo/box2d/lunarlander/config/lunarlander_cont_disc_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_cont_disc_efficientzero_config.py index 28835f4ab..7b9876bf3 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_cont_disc_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_cont_disc_efficientzero_config.py @@ -24,7 +24,7 @@ exp_name= f'data_ez_ctree/lunarlander_cont_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='LunarLanderContinuous-v2', + env_id='LunarLanderContinuous-v2', continuous=False, manually_discretization=True, each_dim_disc_size=each_dim_disc_size, diff --git a/zoo/box2d/lunarlander/config/lunarlander_cont_disc_sampled_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_cont_disc_sampled_efficientzero_config.py index a7e896c76..1a49707dd 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_cont_disc_sampled_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_cont_disc_sampled_efficientzero_config.py @@ -22,7 +22,7 @@ exp_name= f'data_sez_ctree/lunarlander_cont_disc_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='LunarLanderContinuous-v2', + env_id='LunarLanderContinuous-v2', continuous=False, manually_discretization=True, each_dim_disc_size=each_dim_disc_size, diff --git a/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py index 0dc7de58a..035fa1c7c 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_cont_sampled_efficientzero_config.py @@ -21,7 +21,7 @@ exp_name= f'data_sez_ctree/lunarlander_cont_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='LunarLanderContinuous-v2', + env_id='LunarLanderContinuous-v2', continuous=True, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py index 5f71321e4..38e268c3e 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_efficientzero_config.py @@ -19,7 +19,7 @@ exp_name= f'data_ez_ctree/lunarlander_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='LunarLander-v2', + env_id='LunarLander-v2', continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py 
b/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py index 1929757cd..3f4241685 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_gumbel_muzero_config.py @@ -18,7 +18,7 @@ lunarlander_gumbel_muzero_config = dict( exp_name=f'data_mz_ctree/lunarlander_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='LunarLander-v2', + env_id='LunarLander-v2', continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py index b3229b63b..19b7bcd86 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_muzero_config.py @@ -18,7 +18,7 @@ lunarlander_muzero_config = dict( exp_name=f'data_mz_ctree/lunarlander_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='LunarLander-v2', + env_id='LunarLander-v2', continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py b/zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py index 21f40584f..11643b1c0 100644 --- a/zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py +++ b/zoo/box2d/lunarlander/config/lunarlander_disc_stochastic_muzero_config.py @@ -18,7 +18,7 @@ lunarlander_muzero_config = dict( exp_name=f'data_stochastic_mz_ctree/lunarlander_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='LunarLander-v2', + env_id='LunarLander-v2', continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/box2d/lunarlander/envs/lunarlander_cont_disc_env.py b/zoo/box2d/lunarlander/envs/lunarlander_cont_disc_env.py index cfd386fee..2b36300fc 100755 --- a/zoo/box2d/lunarlander/envs/lunarlander_cont_disc_env.py +++ b/zoo/box2d/lunarlander/envs/lunarlander_cont_disc_env.py @@ -39,7 +39,7 @@ def default_config(cls: type) -> EasyDict: config = dict( # (str) The gym environment name. - env_name="LunarLander-v2", + env_id="LunarLander-v2", # (int) The number of bins for each dimension of the action space. each_dim_disc_size=4, # (bool) If True, save the replay as a gif file. @@ -65,13 +65,13 @@ def __init__(self, cfg: dict) -> None: """ self._cfg = cfg self._init_flag = False - # env_name: LunarLander-v2, LunarLanderContinuous-v2 - self._env_name = cfg.env_name + # env_id: LunarLander-v2, LunarLanderContinuous-v2 + self._env_id = cfg.env_id self._replay_path = cfg.replay_path self._replay_path_gif = cfg.replay_path_gif self._save_replay_gif = cfg.save_replay_gif self._save_replay_count = 0 - if 'Continuous' in self._env_name: + if 'Continuous' in self._env_id: self._act_scale = cfg.act_scale # act_scale only works in continuous env else: self._act_scale = False @@ -85,7 +85,7 @@ def reset(self) -> np.ndarray: - info_dict (:obj:`Dict[str, Any]`): Including observation, action_mask, and to_play label. 
""" if not self._init_flag: - self._env = gym.make(self._cfg.env_name, render_mode="rgb_array") + self._env = gym.make(self._cfg.env_id, render_mode="rgb_array") if self._replay_path is not None: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") video_name = f'{self._env.spec.id}-video-{timestamp}' @@ -163,7 +163,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( self._replay_path_gif, - '{}_episode_{}_seed{}_{}.gif'.format(self._env_name, self._save_replay_count, self._seed, timestamp) + '{}_episode_{}_seed{}_{}.gif'.format(self._env_id, self._save_replay_count, self._seed, timestamp) ) self.display_frames_as_gif(self._frames, path) print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') diff --git a/zoo/box2d/lunarlander/envs/lunarlander_env.py b/zoo/box2d/lunarlander/envs/lunarlander_env.py index 1c3751a86..e1617cb82 100755 --- a/zoo/box2d/lunarlander/envs/lunarlander_env.py +++ b/zoo/box2d/lunarlander/envs/lunarlander_env.py @@ -25,7 +25,7 @@ class LunarLanderEnv(CartPoleEnv): config = dict( # (str) The gym environment name. - env_name="LunarLander-v2", + env_id="LunarLander-v2", # (bool) If True, save the replay as a gif file. save_replay_gif=False, # (str or None) The path to save the replay gif. If None, the replay gif will not be saved. @@ -58,17 +58,17 @@ def __init__(self, cfg: dict) -> None: Overview: Initialize the LunarLander environment. Arguments: - - cfg (:obj:`dict`): Configuration dict. The dict should include keys like 'env_name', 'replay_path', etc. + - cfg (:obj:`dict`): Configuration dict. The dict should include keys like 'env_id', 'replay_path', etc. """ self._cfg = cfg self._init_flag = False - # env_name options = {'LunarLander-v2', 'LunarLanderContinuous-v2'} - self._env_name = cfg.env_name + # env_id options = {'LunarLander-v2', 'LunarLanderContinuous-v2'} + self._env_id = cfg.env_id self._replay_path = cfg.replay_path self._replay_path_gif = cfg.replay_path_gif self._save_replay_gif = cfg.save_replay_gif self._save_replay_count = 0 - if 'Continuous' in self._env_name: + if 'Continuous' in self._env_id: self._act_scale = cfg.act_scale # act_scale only works in continuous env else: self._act_scale = False @@ -81,7 +81,7 @@ def reset(self) -> Dict[str, np.ndarray]: - obs (:obj:`np.ndarray`): The initial observation after resetting. """ if not self._init_flag: - self._env = gym.make(self._cfg.env_name, render_mode="rgb_array") + self._env = gym.make(self._cfg.env_id, render_mode="rgb_array") if self._replay_path is not None: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") video_name = f'{self._env.spec.id}-video-{timestamp}' @@ -111,7 +111,7 @@ def reset(self) -> Dict[str, np.ndarray]: self._eval_episode_return = 0. if self._save_replay_gif: self._frames = [] - if 'Continuous' not in self._env_name: + if 'Continuous' not in self._env_id: action_mask = np.ones(4, 'int8') obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1} else: @@ -137,7 +137,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: obs, rew, terminated, truncated, info = self._env.step(action) done = terminated or truncated - if 'Continuous' not in self._env_name: + if 'Continuous' not in self._env_id: action_mask = np.ones(4, 'int8') # TODO: test the performance of varied_action_space. 
# action_mask[0] = 0 @@ -154,7 +154,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( self._replay_path_gif, - '{}_episode_{}_seed{}_{}.gif'.format(self._env_name, self._save_replay_count, self._seed, timestamp) + '{}_episode_{}_seed{}_{}.gif'.format(self._env_id, self._save_replay_count, self._seed, timestamp) ) self.display_frames_as_gif(self._frames, path) print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') diff --git a/zoo/box2d/lunarlander/envs/test_lunarlander_env.py b/zoo/box2d/lunarlander/envs/test_lunarlander_env.py index f932f1de2..ffab8fdd0 100755 --- a/zoo/box2d/lunarlander/envs/test_lunarlander_env.py +++ b/zoo/box2d/lunarlander/envs/test_lunarlander_env.py @@ -8,14 +8,14 @@ @pytest.mark.parametrize( 'cfg', [ EasyDict({ - 'env_name': 'LunarLander-v2', + 'env_id': 'LunarLander-v2', 'act_scale': False, 'replay_path': None, 'replay_path_gif': None, 'save_replay_gif': False, }), EasyDict({ - 'env_name': 'LunarLanderContinuous-v2', + 'env_id': 'LunarLanderContinuous-v2', 'act_scale': True, 'replay_path': None, 'replay_path_gif': None, diff --git a/zoo/bsuite/config/bsuite_efficientzero_config.py b/zoo/bsuite/config/bsuite_efficientzero_config.py index a69593d71..d35b6d442 100644 --- a/zoo/bsuite/config/bsuite_efficientzero_config.py +++ b/zoo/bsuite/config/bsuite_efficientzero_config.py @@ -1,20 +1,20 @@ from easydict import EasyDict # options={'memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22', 'memory_size/0', 'bsuite_swingup/0', 'bandit_noise/0'} -env_name = 'memory_len/9' +env_id = 'memory_len/9' -if env_name in ['memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22']: +if env_id in ['memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22']: # the memory_length of above envs is 1, 10, 50, 80, 100, respectively. action_space_size = 2 observation_shape = 3 -elif env_name in ['bsuite_swingup/0']: +elif env_id in ['bsuite_swingup/0']: action_space_size = 3 observation_shape = 8 -elif env_name == 'bandit_noise/0': +elif env_id == 'bandit_noise/0': action_space_size = 11 observation_shape = 1 -elif env_name in ['memory_size/0']: +elif env_id in ['memory_size/0']: action_space_size = 2 observation_shape = 3 else: @@ -38,9 +38,9 @@ bsuite_efficientzero_config = dict( exp_name= - f'data_ez_ctree/bsuite_{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + f'data_ez_ctree/bsuite_{env_id}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( - env_name=env_name, + env_id=env_id, stop_value=int(1e6), continuous=False, manually_discretization=False, diff --git a/zoo/bsuite/config/bsuite_muzero_config.py b/zoo/bsuite/config/bsuite_muzero_config.py index 42e12328b..5c40a204b 100644 --- a/zoo/bsuite/config/bsuite_muzero_config.py +++ b/zoo/bsuite/config/bsuite_muzero_config.py @@ -1,20 +1,20 @@ from easydict import EasyDict # options={'memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22', 'memory_size/0', 'bsuite_swingup/0', 'bandit_noise/0'} -env_name = 'memory_len/9' +env_id = 'memory_len/9' -if env_name in ['memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22']: +if env_id in ['memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22']: # the memory_length of above envs is 1, 10, 50, 80, 100, respectively. 
action_space_size = 2 observation_shape = 3 -elif env_name in ['bsuite_swingup/0']: +elif env_id in ['bsuite_swingup/0']: action_space_size = 3 observation_shape = 8 -elif env_name == 'bandit_noise/0': +elif env_id == 'bandit_noise/0': action_space_size = 11 observation_shape = 1 -elif env_name in ['memory_size/0']: +elif env_id in ['memory_size/0']: action_space_size = 2 observation_shape = 3 else: @@ -38,9 +38,9 @@ # ============================================================== bsuite_muzero_config = dict( - exp_name=f'data_mz_ctree/bsuite_{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + exp_name=f'data_mz_ctree/bsuite_{env_id}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( - env_name=env_name, + env_id=env_id, stop_value=int(1e6), continuous=False, manually_discretization=False, diff --git a/zoo/bsuite/config/bsuite_sampled_efficientzero_config.py b/zoo/bsuite/config/bsuite_sampled_efficientzero_config.py index 38ffe28bd..a1988f43d 100644 --- a/zoo/bsuite/config/bsuite_sampled_efficientzero_config.py +++ b/zoo/bsuite/config/bsuite_sampled_efficientzero_config.py @@ -1,20 +1,20 @@ from easydict import EasyDict # options={'memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22', 'memory_size/0', 'bsuite_swingup/0', 'bandit_noise/0'} -env_name = 'memory_len/9' +env_id = 'memory_len/9' -if env_name in ['memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22']: +if env_id in ['memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22']: # the memory_length of above envs is 1, 10, 50, 80, 100, respectively. action_space_size = 2 observation_shape = 3 -elif env_name in ['bsuite_swingup/0']: +elif env_id in ['bsuite_swingup/0']: action_space_size = 3 observation_shape = 8 -elif env_name == 'bandit_noise/0': +elif env_id == 'bandit_noise/0': action_space_size = 11 observation_shape = 1 -elif env_name in ['memory_size/0']: +elif env_id in ['memory_size/0']: action_space_size = 2 observation_shape = 3 else: @@ -42,7 +42,7 @@ exp_name= f'data_sez_ctree/bsuite_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( - env_name=env_name, + env_id=env_id, stop_value=int(1e6), continuous=False, manually_discretization=False, diff --git a/zoo/bsuite/envs/bsuite_lightzero_env.py b/zoo/bsuite/envs/bsuite_lightzero_env.py index e21a6fc65..4586d3789 100644 --- a/zoo/bsuite/envs/bsuite_lightzero_env.py +++ b/zoo/bsuite/envs/bsuite_lightzero_env.py @@ -26,7 +26,7 @@ class BSuiteEnv(BaseEnv): """ config = dict( # (str) The gym environment name. - env_name='memory_len/9', + env_id='memory_len/9', # (bool) If True, save the replay as a gif file. # Due to the definition of the environment, rendering images of certain sub-environments are meaningless. save_replay_gif=False, @@ -55,7 +55,7 @@ def __init__(self, cfg: dict = {}) -> None: """ self._cfg = cfg self._init_flag = False - self._env_name = cfg.env_name + self._env_id = cfg.env_id self._replay_path = cfg.replay_path self._replay_path_gif = cfg.replay_path_gif self._save_replay_gif = cfg.save_replay_gif @@ -67,7 +67,7 @@ def reset(self) -> Dict[str, np.ndarray]: if necessary. Returns the first observation. 
""" if not self._init_flag: - raw_env = bsuite.load_from_id(bsuite_id=self._env_name) + raw_env = bsuite.load_from_id(bsuite_id=self._env_id) self._env = gym_wrapper.GymFromDMEnv(raw_env) self._observation_space = self._env.observation_space self._action_space = self._env.action_space @@ -151,7 +151,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: return BaseEnvTimestep(obs, rew, done, info) def config_info(self) -> dict: - config_info = sweep.SETTINGS[self._env_name] # additional info that are specific to each env configuration + config_info = sweep.SETTINGS[self._env_id] # additional info that are specific to each env configuration config_info['num_episodes'] = self._env.bsuite_num_episodes return config_info @@ -237,4 +237,4 @@ def __repr__(self) -> str: """ String representation of the environment. """ - return "LightZero BSuite Env({})".format(self._env_name) + return "LightZero BSuite Env({})".format(self._env_id) diff --git a/zoo/classic_control/cartpole/config/cartpole_efficientzero_config.py b/zoo/classic_control/cartpole/config/cartpole_efficientzero_config.py index 705d4d73e..091cc065b 100644 --- a/zoo/classic_control/cartpole/config/cartpole_efficientzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_efficientzero_config.py @@ -19,7 +19,7 @@ exp_name= f'data_ez_ctree/cartpole_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='CartPole-v0', + env_id='CartPole-v0', continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py index 4a84a861b..9f292575a 100644 --- a/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_gumbel_muzero_config.py @@ -18,7 +18,7 @@ cartpole_gumbel_muzero_config = dict( exp_name=f'data_mz_ctree/cartpole_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='CartPole-v0', + env_id='CartPole-v0', continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py index ab47c2299..56ee32acf 100644 --- a/zoo/classic_control/cartpole/config/cartpole_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_muzero_config.py @@ -18,7 +18,7 @@ cartpole_muzero_config = dict( exp_name=f'data_mz_ctree/cartpole_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='CartPole-v0', + env_id='CartPole-v0', continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/classic_control/cartpole/config/cartpole_sampled_efficientzero_config.py b/zoo/classic_control/cartpole/config/cartpole_sampled_efficientzero_config.py index 83f3c9769..6f4cc86a3 100644 --- a/zoo/classic_control/cartpole/config/cartpole_sampled_efficientzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_sampled_efficientzero_config.py @@ -21,7 +21,7 @@ exp_name= f'data_sez_ctree/cartpole_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='CartPole-v0', + env_id='CartPole-v0', continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git 
a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py index 65dbf6418..ba994807f 100644 --- a/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py +++ b/zoo/classic_control/cartpole/config/cartpole_stochastic_muzero_config.py @@ -18,7 +18,7 @@ cartpole_stochastic_muzero_config = dict( exp_name=f'data_stochastic_mz_ctree/cartpole_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='CartPole-v0', + env_id='CartPole-v0', continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py b/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py index 29f386164..62bceb6b6 100644 --- a/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py +++ b/zoo/classic_control/cartpole/envs/cartpole_lightzero_env.py @@ -21,8 +21,8 @@ class CartPoleEnv(BaseEnv): """ config = dict( - # env_name (str): The name of the environment. - env_name="CartPole-v0", + # env_id (str): The name of the environment. + env_id="CartPole-v0", # replay_path (str): The path to save the replay video. If None, the replay will not be saved. # Only effective when env_manager.type is 'base'. replay_path=None, diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_disc_efficientzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_disc_efficientzero_config.py index 83608c83a..2ec71e411 100644 --- a/zoo/classic_control/pendulum/config/pendulum_cont_disc_efficientzero_config.py +++ b/zoo/classic_control/pendulum/config/pendulum_cont_disc_efficientzero_config.py @@ -19,7 +19,7 @@ exp_name= f'data_ez_ctree/pendulum_disc_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='Pendulum-v1', + env_id='Pendulum-v1', continuous=False, manually_discretization=True, each_dim_disc_size=11, diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py index 987f472b6..f2c0749b6 100644 --- a/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py +++ b/zoo/classic_control/pendulum/config/pendulum_cont_disc_gumbel_muzero_config.py @@ -22,7 +22,7 @@ pendulum_disc_gumbel_muzero_config = dict( exp_name=f'data_mz_ctree/pendulum_disc_gumbel_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='Pendulum-v1', + env_id='Pendulum-v1', continuous=False, manually_discretization=True, each_dim_disc_size=11, diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_disc_muzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_disc_muzero_config.py index 5e7e2ade4..b75866a0c 100644 --- a/zoo/classic_control/pendulum/config/pendulum_cont_disc_muzero_config.py +++ b/zoo/classic_control/pendulum/config/pendulum_cont_disc_muzero_config.py @@ -22,7 +22,7 @@ pendulum_disc_muzero_config = dict( exp_name=f'data_ez_ctree/pendulum_disc_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='Pendulum-v1', + env_id='Pendulum-v1', continuous=False, manually_discretization=True, each_dim_disc_size=11, diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_disc_sampled_efficientzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_disc_sampled_efficientzero_config.py index 
cba4d2f30..04c1a9b11 100644 --- a/zoo/classic_control/pendulum/config/pendulum_cont_disc_sampled_efficientzero_config.py +++ b/zoo/classic_control/pendulum/config/pendulum_cont_disc_sampled_efficientzero_config.py @@ -21,7 +21,7 @@ exp_name= f'data_sez_ctree/pendulum_disc_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='Pendulum-v1', + env_id='Pendulum-v1', continuous=False, manually_discretization=True, each_dim_disc_size=11, diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_disc_stochastic_muzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_disc_stochastic_muzero_config.py index bcc807362..332b4435e 100644 --- a/zoo/classic_control/pendulum/config/pendulum_cont_disc_stochastic_muzero_config.py +++ b/zoo/classic_control/pendulum/config/pendulum_cont_disc_stochastic_muzero_config.py @@ -22,7 +22,7 @@ pendulum_disc_stochastic_muzero_config = dict( exp_name=f'data_ez_ctree/pendulum_disc_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='Pendulum-v1', + env_id='Pendulum-v1', continuous=False, manually_discretization=True, each_dim_disc_size=11, diff --git a/zoo/classic_control/pendulum/config/pendulum_cont_sampled_efficientzero_config.py b/zoo/classic_control/pendulum/config/pendulum_cont_sampled_efficientzero_config.py index 5712af84c..a6644a52e 100644 --- a/zoo/classic_control/pendulum/config/pendulum_cont_sampled_efficientzero_config.py +++ b/zoo/classic_control/pendulum/config/pendulum_cont_sampled_efficientzero_config.py @@ -21,7 +21,7 @@ exp_name= f'data_sez_ctree/pendulum_sampled_efficientzero_k{K}_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed0', env=dict( - env_name='Pendulum-v1', + env_id='Pendulum-v1', continuous=True, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/game_2048/config/muzero_2048_config.py b/zoo/game_2048/config/muzero_2048_config.py index 45eb66271..45f1c3e05 100644 --- a/zoo/game_2048/config/muzero_2048_config.py +++ b/zoo/game_2048/config/muzero_2048_config.py @@ -4,7 +4,7 @@ # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -env_name = 'game_2048' +env_id = 'game_2048' action_space_size = 4 collector_env_num = 8 n_episode = 8 @@ -24,7 +24,7 @@ exp_name=f'data_mz_ctree/game_2048_npct-{num_of_possible_chance_tile}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_sslw2_seed0', env=dict( stop_value=int(1e6), - env_name=env_name, + env_id=env_id, obs_shape=(16, 4, 4), obs_type='dict_encoded_board', raw_reward_type='raw', # 'merged_tiles_plus_log_max_tile_num' diff --git a/zoo/game_2048/config/stochastic_muzero_2048_config.py b/zoo/game_2048/config/stochastic_muzero_2048_config.py index 367124478..9c5204c28 100644 --- a/zoo/game_2048/config/stochastic_muzero_2048_config.py +++ b/zoo/game_2048/config/stochastic_muzero_2048_config.py @@ -4,7 +4,7 @@ # ============================================================== # begin of the most frequently changed config specified by the user # ============================================================== -env_name = 'game_2048' +env_id = 'game_2048' action_space_size = 4 use_ture_chance_label_in_chance_encoder = True collector_env_num = 8 @@ -25,7 +25,7 @@ 
exp_name=f'data_stochastic_mz_ctree/game_2048_npct-{num_of_possible_chance_tile}_stochastic_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs{batch_size}_chance-{use_ture_chance_label_in_chance_encoder}_sslw2_seed0', env=dict( stop_value=int(1e6), - env_name=env_name, + env_id=env_id, obs_shape=(16, 4, 4), obs_type='dict_encoded_board', num_of_possible_chance_tile=num_of_possible_chance_tile, diff --git a/zoo/game_2048/entry/2048_bot_eval.py b/zoo/game_2048/entry/2048_bot_eval.py index 680f34b97..1b76cf1c1 100644 --- a/zoo/game_2048/entry/2048_bot_eval.py +++ b/zoo/game_2048/entry/2048_bot_eval.py @@ -7,7 +7,7 @@ # Define game configuration config = EasyDict(dict( - env_name="game_2048", + env_id="game_2048", # (str) The render mode. Options are 'None', 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'. # If None, then the game will not be rendered. render_mode='image_realtime_mode', diff --git a/zoo/game_2048/envs/game_2048_env.py b/zoo/game_2048/envs/game_2048_env.py index 9f3d0254c..11c6755ab 100644 --- a/zoo/game_2048/envs/game_2048_env.py +++ b/zoo/game_2048/envs/game_2048_env.py @@ -89,7 +89,7 @@ class Game2048Env(gym.Env): # The default_config for game 2048 env. config = dict( # (str) The name of the environment registered in the environment registry. - env_name="game_2048", + env_id="game_2048", # (str) The render mode. Options are 'None', 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'. # If None, then the game will not be rendered. render_mode=None, @@ -143,7 +143,7 @@ def default_config(cls: type) -> EasyDict: def __init__(self, cfg: dict) -> None: self._cfg = cfg self._init_flag = False - self._env_name = cfg.env_name + self._env_id = cfg.env_id self.replay_format = cfg.replay_format self.replay_name_suffix = cfg.replay_name_suffix self.replay_path = cfg.replay_path diff --git a/zoo/game_2048/envs/test_game_2048_env.py b/zoo/game_2048/envs/test_game_2048_env.py index 2bcd1ab8e..8d87ee971 100644 --- a/zoo/game_2048/envs/test_game_2048_env.py +++ b/zoo/game_2048/envs/test_game_2048_env.py @@ -10,7 +10,7 @@ class TestGame2048(): def setup(self) -> None: # Configuration for the Game2048 environment cfg = EasyDict(dict( - env_name="game_2048", + env_id="game_2048", # (str) The render mode. Options are 'None', 'state_realtime_mode', 'image_realtime_mode' or 'image_savefile_mode'. If None, then the game will not be rendered. render_mode=None, replay_format='gif', diff --git a/zoo/minigrid/config/minigrd_sampled_efficientzero_config.py b/zoo/minigrid/config/minigrd_sampled_efficientzero_config.py index f9312afd0..2b253fff9 100644 --- a/zoo/minigrid/config/minigrd_sampled_efficientzero_config.py +++ b/zoo/minigrid/config/minigrd_sampled_efficientzero_config.py @@ -2,7 +2,7 @@ # The typical MiniGrid env id: {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'}, # please refer to https://github.com/Farama-Foundation/MiniGrid for details. 
-env_name = 'MiniGrid-Empty-8x8-v0' +env_id = 'MiniGrid-Empty-8x8-v0' max_env_step = int(1e6) # ============================================================== # begin of the most frequently changed config specified by the user @@ -29,9 +29,9 @@ # ============================================================== minigrid_sampled_efficientzero_config = dict( - exp_name=f'data_sez_ctree/{env_name}_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + exp_name=f'data_sez_ctree/{env_id}_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( - env_name=env_name, + env_id=env_id, continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/minigrid/config/minigrid_efficientzero_config.py b/zoo/minigrid/config/minigrid_efficientzero_config.py index 4f48550e8..fa61a3b1d 100644 --- a/zoo/minigrid/config/minigrid_efficientzero_config.py +++ b/zoo/minigrid/config/minigrid_efficientzero_config.py @@ -2,7 +2,7 @@ # The typical MiniGrid env id: {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'}, # please refer to https://github.com/Farama-Foundation/MiniGrid for details. -env_name = 'MiniGrid-Empty-8x8-v0' +env_id = 'MiniGrid-Empty-8x8-v0' max_env_step = int(1e6) # ============================================================== @@ -25,10 +25,10 @@ # ============================================================== minigrid_efficientzero_config = dict( - exp_name=f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', + exp_name=f'data_ez_ctree/{env_id}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}', env=dict( stop_value=int(1e6), - env_name=env_name, + env_id=env_id, continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/minigrid/config/minigrid_muzero_config.py b/zoo/minigrid/config/minigrid_muzero_config.py index 3a1a7ec28..304d0860c 100644 --- a/zoo/minigrid/config/minigrid_muzero_config.py +++ b/zoo/minigrid/config/minigrid_muzero_config.py @@ -2,7 +2,7 @@ # The typical MiniGrid env id: {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'}, # please refer to https://github.com/Farama-Foundation/MiniGrid for details. 
-env_name = 'MiniGrid-Empty-8x8-v0' +env_id = 'MiniGrid-Empty-8x8-v0' max_env_step = int(1e6) # ============================================================== @@ -25,11 +25,11 @@ # ============================================================== minigrid_muzero_config = dict( - exp_name=f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_' + exp_name=f'data_mz_ctree/{env_id}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_' f'collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}_pelw{policy_entropy_loss_weight}_seed{seed}', env=dict( stop_value=int(1e6), - env_name=env_name, + env_id=env_id, continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/minigrid/config/minigrid_muzero_rnd_config.py b/zoo/minigrid/config/minigrid_muzero_rnd_config.py index eac8abe16..31a5f50a5 100644 --- a/zoo/minigrid/config/minigrid_muzero_rnd_config.py +++ b/zoo/minigrid/config/minigrid_muzero_rnd_config.py @@ -2,7 +2,7 @@ # The typical MiniGrid env id: {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'}, # please refer to https://github.com/Farama-Foundation/MiniGrid for details. -env_name = 'MiniGrid-Empty-8x8-v0' +env_id = 'MiniGrid-Empty-8x8-v0' max_env_step = int(1e6) # ============================================================== @@ -30,12 +30,12 @@ # ============================================================== minigrid_muzero_rnd_config = dict( - exp_name=f'data_mz_rnd_ctree/{env_name}_muzero-rnd_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}' + exp_name=f'data_mz_rnd_ctree/{env_id}_muzero-rnd_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}' f'_collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}_pelw{policy_entropy_loss_weight}' f'_rnd-rew-{input_type}-{target_model_for_intrinsic_reward_update_type}_seed{seed}', env=dict( stop_value=int(1e6), - env_name=env_name, + env_id=env_id, continuous=False, manually_discretization=False, collector_env_num=collector_env_num, diff --git a/zoo/minigrid/envs/minigrid_lightzero_env.py b/zoo/minigrid/envs/minigrid_lightzero_env.py index e1eebed0f..676a50e6e 100644 --- a/zoo/minigrid/envs/minigrid_lightzero_env.py +++ b/zoo/minigrid/envs/minigrid_lightzero_env.py @@ -26,14 +26,14 @@ class MiniGridEnvLightZero(MiniGridEnv): config (dict): Configuration dict. Default configurations can be updated using this. _cfg (dict): Internal configuration dict that stores runtime configurations. _init_flag (bool): Flag to check if the environment is initialized. - _env_name (str): The name of the MiniGrid environment. + _env_id (str): The name of the MiniGrid environment. _flat_obs (bool): Flag to check if flat observations are returned. _save_replay (bool): Flag to check if replays are saved. _max_step (int): Maximum number of steps for the environment. """ config = dict( # (str) The gym environment name. - env_name='MiniGrid-Empty-8x8-v0', + env_id='MiniGrid-Empty-8x8-v0', # (bool) If True, save the replay as a gif file. save_replay_gif=False, # (str or None) The path to save the replay gif. If None, the replay gif will not be saved. 
@@ -65,7 +65,7 @@ def __init__(self, cfg: dict) -> None: """ self._cfg = cfg self._init_flag = False - self._env_name = cfg.env_name + self._env_id = cfg.env_id self._flat_obs = cfg.flat_obs self._save_replay_gif = cfg.save_replay_gif self._replay_path_gif = cfg.replay_path_gif @@ -81,17 +81,17 @@ def reset(self) -> np.ndarray: """ if not self._init_flag: if self._save_replay_gif: - self._env = gym.make(self._env_name, render_mode="rgb_array") + self._env = gym.make(self._env_id, render_mode="rgb_array") else: - self._env = gym.make(self._env_name) + self._env = gym.make(self._env_id) # NOTE: customize the max step of the env self._env.max_steps = self._max_step - if self._env_name in ['MiniGrid-AKTDT-13x13-v0' or 'MiniGrid-AKTDT-13x13-1-v0']: + if self._env_id in ['MiniGrid-AKTDT-13x13-v0', 'MiniGrid-AKTDT-13x13-1-v0']: # customize the agent field of view size, note this must be an odd number # This also related to the observation space, see gym_minigrid.wrappers for more details self._env = ViewSizeWrapper(self._env, agent_view_size=5) - if self._env_name == 'MiniGrid-AKTDT-7x7-1-v0': + if self._env_id == 'MiniGrid-AKTDT-7x7-1-v0': self._env = ViewSizeWrapper(self._env, agent_view_size=3) if self._flat_obs: self._env = FlatObsWrapper(self._env) @@ -188,7 +188,7 @@ def step(self, action: np.ndarray) -> BaseEnvTimestep: timestamp = datetime.now().strftime("%Y%m%d%H%M%S") path = os.path.join( self._replay_path_gif, - '{}_episode_{}_seed{}_{}.gif'.format(self._env_name, self._save_replay_count, self._seed, timestamp) + '{}_episode_{}_seed{}_{}.gif'.format(self._env_id, self._save_replay_count, self._seed, timestamp) ) self.display_frames_as_gif(self._frames, path) print(f'save episode {self._save_replay_count} in {self._replay_path_gif}!') @@ -269,4 +269,4 @@ def __repr__(self) -> str: """ String representation of the environment. """ - return "LightZero MiniGrid Env({})".format(self._cfg.env_name) \ No newline at end of file + return "LightZero MiniGrid Env({})".format(self._cfg.env_id) \ No newline at end of file diff --git a/zoo/mujoco/config/mujoco_disc_sampled_efficientzero_config.py b/zoo/mujoco/config/mujoco_disc_sampled_efficientzero_config.py index 1a0565057..b3fdb08cb 100644 --- a/zoo/mujoco/config/mujoco_disc_sampled_efficientzero_config.py +++ b/zoo/mujoco/config/mujoco_disc_sampled_efficientzero_config.py @@ -1,23 +1,23 @@ from easydict import EasyDict # options={'Hopper-v3', 'HalfCheetah-v3', 'Walker2d-v3', 'Ant-v3', 'Humanoid-v3'} -env_name = 'Hopper-v3' +env_id = 'Hopper-v3' -if env_name == 'Hopper-v3': +if env_id == 'Hopper-v3': action_space_size = 3 observation_shape = 11 -elif env_name in ['HalfCheetah-v3', 'Walker2d-v3']: +elif env_id in ['HalfCheetah-v3', 'Walker2d-v3']: action_space_size = 6 observation_shape = 17 -elif env_name == 'Ant-v3': +elif env_id == 'Ant-v3': action_space_size = 8 observation_shape = 111 -elif env_name == 'Humanoid-v3': +elif env_id == 'Humanoid-v3': action_space_size = 17 observation_shape = 376 ignore_done = False -if env_name == 'HalfCheetah-v3': +if env_id == 'HalfCheetah-v3': # for halfcheetah, we ignore done signal to predict the Q value of the last step correctly.
ignore_done = True @@ -44,9 +44,9 @@ mujoco_disc_sampled_efficientzero_config = dict( exp_name= - f'data_sez_ctree/{env_name[:-3]}_bin-{each_dim_disc_size}_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_pelw{policy_entropy_loss_weight}_seed0', + f'data_sez_ctree/{env_id[:-3]}_bin-{each_dim_disc_size}_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_pelw{policy_entropy_loss_weight}_seed0', env=dict( - env_name=env_name, + env_id=env_id, action_clip=True, continuous=False, manually_discretization=False, diff --git a/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py b/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py index c7cc30c0b..3737c4db3 100644 --- a/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py +++ b/zoo/mujoco/config/mujoco_sampled_efficientzero_config.py @@ -1,23 +1,23 @@ from easydict import EasyDict # options={'Hopper-v3', 'HalfCheetah-v3', 'Walker2d-v3', 'Ant-v3', 'Humanoid-v3'} -env_name = 'Hopper-v3' +env_id = 'Hopper-v3' -if env_name == 'Hopper-v3': +if env_id == 'Hopper-v3': action_space_size = 3 observation_shape = 11 -elif env_name in ['HalfCheetah-v3', 'Walker2d-v3']: +elif env_id in ['HalfCheetah-v3', 'Walker2d-v3']: action_space_size = 6 observation_shape = 17 -elif env_name == 'Ant-v3': +elif env_id == 'Ant-v3': action_space_size = 8 observation_shape = 111 -elif env_name == 'Humanoid-v3': +elif env_id == 'Humanoid-v3': action_space_size = 17 observation_shape = 376 ignore_done = False -if env_name == 'HalfCheetah-v3': +if env_id == 'HalfCheetah-v3': # for halfcheetah, we ignore done signal to predict the Q value of the last step correctly. ignore_done = True @@ -44,9 +44,9 @@ mujoco_sampled_efficientzero_config = dict( exp_name= - f'data_sez_ctree/{env_name[:-3]}_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs-{batch_size}_pelw{policy_entropy_loss_weight}_seed{seed}', + f'data_sez_ctree/{env_id[:-3]}_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_bs-{batch_size}_pelw{policy_entropy_loss_weight}_seed{seed}', env=dict( - env_name=env_name, + env_id=env_id, action_clip=True, continuous=True, manually_discretization=False, diff --git a/zoo/mujoco/envs/mujoco_disc_lightzero_env.py b/zoo/mujoco/envs/mujoco_disc_lightzero_env.py index 83b7b9784..e741191b7 100644 --- a/zoo/mujoco/envs/mujoco_disc_lightzero_env.py +++ b/zoo/mujoco/envs/mujoco_disc_lightzero_env.py @@ -40,8 +40,8 @@ def __init__(self, cfg: dict) -> None: """ super().__init__(cfg) self._cfg = cfg - # We use env_name to indicate the env_id in LightZero. - self._cfg.env_id = self._cfg.env_name + # The config now provides 'env_id' directly, + # so the old aliasing from 'env_name' to 'env_id' is no longer needed. self._action_clip = cfg.action_clip self._delay_reward_step = cfg.delay_reward_step self._init_flag = False @@ -122,7 +122,7 @@ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep: if done: if self._save_replay_gif: path = os.path.join( - self._replay_path_gif, '{}_episode_{}.gif'.format(self._cfg.env_name, self._save_replay_count) + self._replay_path_gif, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count) ) save_frames_as_gif(self._frames, path) self._save_replay_count += 1 @@ -143,4 +143,4 @@ def __repr__(self) -> str: Returns: - repr_str (:obj:`str`): Representation string of the environment instance.
""" - return "LightZero modified Mujoco Env({}) with manually discretized action space".format(self._cfg.env_name) + return "LightZero modified Mujoco Env({}) with manually discretized action space".format(self._cfg.env_id) diff --git a/zoo/mujoco/envs/mujoco_lightzero_env.py b/zoo/mujoco/envs/mujoco_lightzero_env.py index b3a62330f..a5a9036f0 100644 --- a/zoo/mujoco/envs/mujoco_lightzero_env.py +++ b/zoo/mujoco/envs/mujoco_lightzero_env.py @@ -38,12 +38,12 @@ def __init__(self, cfg: dict) -> None: Overview: Initialize the MuJoCo environment. Arguments: - - cfg (:obj:`dict`): Configuration dict. The dict should include keys like 'env_name', 'replay_path', etc. + - cfg (:obj:`dict`): Configuration dict. The dict should include keys like 'env_id', 'replay_path', etc. """ super().__init__(cfg) self._cfg = cfg - # We use env_name to indicate the env_id in LightZero. - self._cfg.env_id = self._cfg.env_name + # We use env_id to indicate the env_id in LightZero. + self._cfg.env_id = self._cfg.env_id self._action_clip = cfg.action_clip self._delay_reward_step = cfg.delay_reward_step self._init_flag = False @@ -120,7 +120,7 @@ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep: if done: if self._save_replay_gif: path = os.path.join( - self._replay_path_gif, '{}_episode_{}.gif'.format(self._cfg.env_name, self._save_replay_count) + self._replay_path_gif, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count) ) save_frames_as_gif(self._frames, path) self._save_replay_count += 1 @@ -138,5 +138,5 @@ def __repr__(self) -> str: """ String representation of the environment. """ - return "LightZero Mujoco Env({})".format(self._cfg.env_name) + return "LightZero Mujoco Env({})".format(self._cfg.env_id)