diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
new file mode 100644
index 000000000..6e8823425
--- /dev/null
+++ b/CONTRIBUTING.md
@@ -0,0 +1,63 @@
+# Welcome to LightZero!
+
+We're thrilled that you want to contribute to LightZero. Your help is invaluable, and we appreciate your efforts to make this project even better.
+
+## How to Contribute
+
+1. **Fork the Repository**
+ - Click on the "Fork" button at the top right of the [LightZero repository](https://github.com/opendilab/LightZero).
+
+2. **Clone your Fork**
+ - `git clone https://github.com/your-username/LightZero.git`
+
+3. **Create a New Branch**
+ - `git checkout -b your-new-feature`
+
+4. **Make Your Awesome Changes**
+ - Add some cool features.
+ - Fix a bug.
+ - Improve the documentation.
+ - Anything that adds value!
+
+5. **Commit Your Changes**
+ - `git commit -m "Your descriptive commit message"`
+
+6. **Push to Your Fork**
+ - `git push origin your-new-feature`
+
+7. **Create a Pull Request**
+ - Go to the [LightZero repository](https://github.com/opendilab/LightZero).
+ - Click on "New Pull Request."
+ - Fill in the details and submit your PR.
+ - Please make sure your PR has a clear title and description.
+
+8. **Review & Collaborate**
+ - Be prepared to answer questions or make changes to your PR as requested by the maintainers.
+
+9. **Celebrate!** Your contribution has been added to LightZero.
+
+## Reporting Issues
+
+If you encounter a bug or have an idea for an improvement, please create an issue in the [Issues](https://github.com/opendilab/LightZero/issues) section. Make sure to include details about the problem and how to reproduce it.
+
+## Code Style and Guidelines
+
+We follow a few simple guidelines:
+- Keep your code clean and readable.
+- Use meaningful variable and function names.
+- Comment your code when necessary.
+- Ensure your code adheres to existing coding styles and standards.
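+
+For example, a small, hypothetical helper (not taken from the codebase) written in this spirit might look like:
+
+```python
+def count_legal_actions(board: list) -> int:
+    """Count the number of empty (still playable) positions on a flattened board."""
+    # By convention here, a cell equal to 0 is empty; non-zero cells are occupied.
+    return sum(1 for cell in board if cell == 0)
+```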
+
+For detailed information on code style, unit testing, and code review, please refer to our documentation:
+
+- [Code Style](https://di-engine-docs.readthedocs.io/en/latest/21_code_style/index.html)
+- [Unit Test](https://di-engine-docs.readthedocs.io/en/latest/22_test/index.html)
+- [Code Review](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/git_guide.html)
+
+## Code of Conduct
+
+Please be kind and respectful when interacting with other contributors. We have a [Code of Conduct](LICENSE) to ensure a positive and welcoming environment for everyone.
+
+## Thank You!
+
+Your contribution helps make LightZero even better. We appreciate your dedication to the project. Keep coding and stay awesome!
diff --git a/README.md b/README.md
index 83d3d032c..bd5fd847a 100644
--- a/README.md
+++ b/README.md
@@ -31,11 +31,14 @@ Updated on 2023.09.21 LightZero-v0.0.2
> LightZero is a lightweight, efficient, and easy-to-understand open-source algorithm toolkit that combines Monte Carlo Tree Search (MCTS) and Deep Reinforcement Learning (RL).
-English | [简体中文](https://github.com/opendilab/LightZero/blob/main/README.zh.md) | [Paper](https://arxiv.org/pdf/2310.08348.pdf)
+English | [简体中文(Simplified Chinese)](https://github.com/opendilab/LightZero/blob/main/README.zh.md) | [Paper](https://arxiv.org/pdf/2310.08348.pdf)
## Background
-The method of combining Monte Carlo Tree Search and Deep Reinforcement Learning represented by AlphaZero and MuZero has achieved superhuman level in various games such as Go and Atari,and has also made gratifying progress in scientific fields such as protein structure prediction, matrix multiplication algorithm search, etc.
+The integration of Monte Carlo Tree Search and Deep Reinforcement Learning,
+exemplified by AlphaZero and MuZero,
+has achieved unprecedented performance levels in various games, including Go and Atari.
+This advanced methodology has also made significant strides in scientific domains like protein structure prediction and the search for matrix multiplication algorithms.
The following is an overview of the historical evolution of the Monte Carlo Tree Search algorithm series:
![pipeline](assets/mcts_rl_evolution_overview.png)
@@ -484,4 +487,4 @@ Special thanks to [@PaParaZz1](https://github.com/PaParaZz1), [@karroyan](https:
## License
All code within this repository is under [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
-
@@ -477,3 +479,6 @@ python3 -u zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py
## 许可证
本仓库中的所有代码都符合 [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0)。
+
+
+
(回到顶部)
diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py
index f68d876a6..9b50ed6ec 100644
--- a/lzero/entry/__init__.py
+++ b/lzero/entry/__init__.py
@@ -1,6 +1,7 @@
from .train_alphazero import train_alphazero
from .eval_alphazero import eval_alphazero
from .train_muzero import train_muzero
+from .train_muzero_with_reward_model import train_muzero_with_reward_model
from .eval_muzero import eval_muzero
from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
\ No newline at end of file
diff --git a/lzero/entry/train_muzero.py b/lzero/entry/train_muzero.py
index d9556dc98..5e3976063 100644
--- a/lzero/entry/train_muzero.py
+++ b/lzero/entry/train_muzero.py
@@ -122,8 +122,8 @@ def train_muzero(
update_per_collect = cfg.policy.update_per_collect
# The purpose of collecting random data before training:
- # Exploration: The collection of random data aids the agent in exploring the environment and prevents premature convergence to a suboptimal policy.
- # Comparation: The agent's performance during random action-taking can be used as a reference point to evaluate the efficacy of reinforcement learning algorithms.
+ # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
+ # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
diff --git a/lzero/entry/train_muzero_with_reward_model.py b/lzero/entry/train_muzero_with_reward_model.py
new file mode 100644
index 000000000..2ae409601
--- /dev/null
+++ b/lzero/entry/train_muzero_with_reward_model.py
@@ -0,0 +1,210 @@
+import logging
+import os
+from functools import partial
+from typing import Optional, Tuple
+
+import torch
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import set_pkg_seed
+from ding.worker import BaseLearner
+from tensorboardX import SummaryWriter
+
+from lzero.entry.utils import log_buffer_memory_usage, random_collect
+from lzero.policy import visit_count_temperature
+from lzero.policy.random_policy import LightZeroRandomPolicy
+from lzero.reward_model.rnd_reward_model import RNDRewardModel
+from lzero.worker import MuZeroCollector, MuZeroEvaluator
+
+
+def train_muzero_with_reward_model(
+ input_cfg: Tuple[dict, dict],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ model_path: Optional[str] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+ The train entry for MCTS+RL algorithms augmented with reward_model.
+ Arguments:
+ - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - model_path (:obj:`Optional[str]`): The pretrained model path, which should
+ point to the ckpt file of the pretrained model, and an absolute path is recommended.
+ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
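+ Examples:
+ >>> # A minimal usage sketch; the config module below is hypothetical and only
+ >>> # illustrates the expected ``[user_config, create_cfg]`` pair for a MuZero+RND experiment.
+ >>> from lzero.entry import train_muzero_with_reward_model
+ >>> from zoo.classic_control.cartpole.config.cartpole_muzero_rnd_config import main_config, create_config
+ >>> train_muzero_with_reward_model([main_config, create_config], seed=0, max_env_step=int(1e5))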
+ """
+
+ cfg, create_cfg = input_cfg
+ assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_rnd', 'sampled_efficientzero'], \
+ "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero'"
+
+ if create_cfg.policy.type in ['muzero', 'muzero_rnd']:
+ from lzero.mcts import MuZeroGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'efficientzero':
+ from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'sampled_efficientzero':
+ from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
+
+ if cfg.policy.cuda and torch.cuda.is_available():
+ cfg.policy.device = 'cuda'
+ else:
+ cfg.policy.device = 'cpu'
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+
+ # load pretrained model
+ if model_path is not None:
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+
+ # ==============================================================
+ # MCTS+RL algorithms related core code
+ # ==============================================================
+ policy_config = cfg.policy
+ batch_size = policy_config.batch_size
+ # specific game buffer for MCTS+RL algorithms
+ replay_buffer = GameBuffer(policy_config)
+ collector = MuZeroCollector(
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=policy_config
+ )
+ evaluator = MuZeroEvaluator(
+ eval_freq=cfg.policy.eval_freq,
+ n_evaluator_episode=cfg.env.n_evaluator_episode,
+ stop_value=cfg.env.stop_value,
+ env=evaluator_env,
+ policy=policy.eval_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=policy_config
+ )
+ # create reward_model
+ reward_model = RNDRewardModel(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger,
+ policy._learn_model.representation_network,
+ policy._target_model_for_intrinsic_reward.representation_network,
+ cfg.policy.use_momentum_representation_network
+ )
+
+ # ==============================================================
+ # Main loop
+ # ==============================================================
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+ if cfg.policy.update_per_collect is not None:
+ update_per_collect = cfg.policy.update_per_collect
+
+ # The purpose of collecting random data before training:
+ # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
+ # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
+ if cfg.policy.random_collect_episode_num > 0:
+ random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+
+ while True:
+ log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+ collect_kwargs = {}
+ # set temperature for visit count distributions according to the train_iter,
+ # please refer to Appendix D in MuZero paper for details.
+ collect_kwargs['temperature'] = visit_count_temperature(
+ policy_config.manual_temperature_decay,
+ policy_config.fixed_temperature_value,
+ policy_config.threshold_training_steps_for_final_temperature,
+ trained_steps=learner.train_iter,
+ )
+
+ if policy_config.eps.eps_greedy_exploration_in_collect:
+ epsilon_greedy_fn = get_epsilon_greedy_fn(start=policy_config.eps.start, end=policy_config.eps.end,
+ decay=policy_config.eps.decay, type_=policy_config.eps.type)
+ collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+ else:
+ collect_kwargs['epsilon'] = 0.0
+
+ # Evaluate policy performance.
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+
+ # Collect data by default config n_sample/n_episode.
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+
+ # ****** reward_model related code ******
+ # collect data for reward_model training
+ reward_model.collect_data(new_data)
+ # update reward_model
+ if reward_model.cfg.input_type == 'latent_state':
+ # train reward_model with latent_state
+ if len(reward_model.train_latent_state) > reward_model.cfg.batch_size:
+ reward_model.train_with_data()
+ elif reward_model.cfg.input_type in ['obs', 'obs_latent_state']:
+ # train reward_model with obs
+ if len(reward_model.train_obs) > reward_model.cfg.batch_size:
+ reward_model.train_with_data()
+ # clear old data in reward_model
+ reward_model.clear_old_data()
+
+ if cfg.policy.update_per_collect is None:
+ # If update_per_collect is None, it is set to the number of collected transitions multiplied by the model_update_ratio.
+ collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
+ update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+ # save returned new_data collected by the collector
+ replay_buffer.push_game_segments(new_data)
+ # remove the oldest data if the replay buffer is full.
+ replay_buffer.remove_oldest_data_to_fit()
+
+ # Learn policy from collected data.
+ for i in range(update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ if replay_buffer.get_num_of_transitions() > batch_size:
+ train_data = replay_buffer.sample(batch_size, policy)
+ else:
+ logging.warning(
+ f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+ f'batch_size: {batch_size}, '
+ f'{replay_buffer} '
+ f'continue to collect now ....'
+ )
+ break
+
+ # update train_data reward using the augmented reward
+ train_data_augmented = reward_model.estimate(train_data)
+
+ # The core train steps for MCTS+RL algorithms.
+ log_vars = learner.train(train_data_augmented, collector.envstep)
+
+ if cfg.policy.use_priority:
+ replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
diff --git a/lzero/mcts/ptree/ptree_az.py b/lzero/mcts/ptree/ptree_az.py
index 27978764d..486b7eb0a 100644
--- a/lzero/mcts/ptree/ptree_az.py
+++ b/lzero/mcts/ptree/ptree_az.py
@@ -12,7 +12,7 @@
import copy
import math
-from typing import List, Tuple, Union, Callable, Type
+from typing import List, Tuple, Union, Callable, Type, Dict, Any
import numpy as np
import torch
@@ -169,7 +169,7 @@ class MCTS(object):
Finally, by repeatedly calling ``_simulate`` through ``get_next_action``, the optimal action is obtained.
"""
- def __init__(self, cfg: EasyDict) -> None:
+ def __init__(self, cfg: EasyDict, simulate_env: Type[BaseEnv]) -> None:
"""
Overview:
Initializes the MCTS process.
@@ -193,11 +193,12 @@ def __init__(self, cfg: EasyDict) -> None:
'root_dirichlet_alpha', 0.3
) # 0.3 # for chess, 0.03 for Go and 0.15 for shogi.
self._root_noise_weight = self._cfg.get('root_noise_weight', 0.25)
- self.simulate_cnt = 0
+
+ self.simulate_env = simulate_env
def get_next_action(
self,
- simulate_env: Type[BaseEnv],
+ state_config_for_simulate_env_reset: Dict[str, Any],
policy_forward_fn: Callable,
temperature: int = 1.0,
sample: bool = True
@@ -206,7 +207,7 @@ def get_next_action(
Overview:
Get the next action to take based on the current state of the game.
Arguments:
- - simulate_env (:obj:`Class BaseGameEnv`): The class of simulate env.
+ - state_config_for_simulate_env_reset (:obj:`Dict`): The state configuration (e.g. ``start_player_index`` and ``init_state``) used to reset the simulation env.
- policy_forward_fn (:obj:`Function`): The Callable to compute the action probs and state value.
- temperature (:obj:`Float`): The exploration temperature.
- sample (:obj:`Bool`): Whether to sample an action from the probabilities or choose the most probable action.
@@ -217,41 +218,38 @@ def get_next_action(
# Create a new root node for the MCTS search.
root = Node()
+
+ self.simulate_env.reset(
+ start_player_index=state_config_for_simulate_env_reset.start_player_index,
+ init_state=state_config_for_simulate_env_reset.init_state,
+ )
# Expand the root node by adding children to it.
- self._expand_leaf_node(root, simulate_env, policy_forward_fn)
+ self._expand_leaf_node(root, self.simulate_env, policy_forward_fn)
# Add Dirichlet noise to the root node's prior probabilities to encourage exploration.
if sample:
self._add_exploration_noise(root)
- # for debugging
- # print(simulate_env.board)
- # print('value= {}'.format([(k, v.value) for k,v in root.children.items()]))
- # print('visit_count= {}'.format([(k, v.visit_count) for k,v in root.children.items()]))
- # print('legal_action= {}',format(simulate_env.legal_actions))
-
# Perform MCTS search for a fixed number of iterations.
for n in range(self._num_simulations):
# Initialize the simulated environment and reset it to the root node.
- simulate_env_copy = copy.deepcopy(simulate_env)
+ self.simulate_env.reset(
+ start_player_index=state_config_for_simulate_env_reset.start_player_index,
+ init_state=state_config_for_simulate_env_reset.init_state,
+ )
# Set the battle mode adopted by the environment during the MCTS process.
# In ``self_play_mode``, when the environment calls the step function once, it will play one move based on the incoming action.
# In ``play_with_bot_mode``, when the step function is called, it will play one move based on the incoming action,
# and then it will play another move based on the action generated by the built-in bot in the environment, which means two moves in total.
# Therefore, in the MCTS process, except for the terminal nodes, the player corresponding to each node is the same player as the root node.
- simulate_env_copy.battle_mode = simulate_env_copy.mcts_mode
- simulate_env_copy.render_mode = None
+ self.simulate_env.battle_mode = self.simulate_env.mcts_mode
+ self.simulate_env.render_mode = None
# Run the simulation from the root to a leaf node and update the node values along the way.
- self._simulate(root, simulate_env_copy, policy_forward_fn)
-
- # for debugging
- # print('after simulation')
- # print('value= {}'.format([(k, v.value) for k,v in root.children.items()]))
- # print('visit_count= {}'.format([(k, v.visit_count) for k,v in root.children.items()]))
+ self._simulate(root, self.simulate_env, policy_forward_fn)
# Get the visit count for each possible action at the root node.
action_visits = []
- for action in range(simulate_env.action_space.n):
+ for action in range(self.simulate_env.action_space.n):
if action in root.children:
action_visits.append((action, root.children[action].visit_count))
else:
@@ -273,6 +271,7 @@ def get_next_action(
action = np.random.choice(actions, p=action_probs)
else:
action = actions[np.argmax(action_probs)]
+
# Return the selected action and the output probability of each action.
return action, action_probs
@@ -288,11 +287,6 @@ def _simulate(self, node: Node, simulate_env: Type[BaseEnv], policy_forward_fn:
"""
while not node.is_leaf():
# Traverse the tree until the leaf node.
-
- # only for debug
- # self.simulate_cnt += 1
- # print('simulate_cnt: {}'.format(self.simulate_cnt))
- # print(f'node:{node}, list(node.children.keys()) is: {list(node.children.keys())}. simulate_env.legal_actions is: {simulate_env.legal_actions}')
action, node = self._select_child(node, simulate_env)
# When there are no common elements in ``node.children`` and ``simulate_env.legal_actions``, action would be None, and we set the node to be a leaf node.
if action is None:
diff --git a/lzero/policy/alphazero.py b/lzero/policy/alphazero.py
index ec8ff6a02..bb9ee2637 100644
--- a/lzero/policy/alphazero.py
+++ b/lzero/policy/alphazero.py
@@ -9,6 +9,7 @@
from ding.torch_utils import to_device
from ding.utils import POLICY_REGISTRY
from ding.utils.data import default_collate
+from easydict import EasyDict
from lzero.mcts.ptree.ptree_az import MCTS
from lzero.policy import configure_optimizers
@@ -210,17 +211,17 @@ def _init_collect(self) -> None:
Overview:
Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
"""
- self._collect_mcts = MCTS(self._cfg.mcts)
+ self._get_simulation_env()
self._collect_model = self._model
self._collect_mcts_temperature = 1
+ self._collect_mcts = MCTS(self._cfg.mcts, self.simulate_env)
@torch.no_grad()
- def _forward_collect(self, envs: Dict, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]:
+ def _forward_collect(self, obs: Dict, temperature: float = 1) -> Dict[str, torch.Tensor]:
"""
Overview:
The forward function for collecting data in collect mode. Use real env to execute MCTS search.
Arguments:
- - envs (:obj:`Dict`): The dict of colletor envs, the key is env_id and the value is the env instance.
- obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \
corresponding obs in this timestep.
- temperature (:obj:`float`): The temperature for MCTS search.
@@ -229,20 +230,16 @@ def _forward_collect(self, envs: Dict, obs: Dict, temperature: float = 1) -> Dic
the corresponding policy output in this timestep, including action, probs and so on.
"""
self._collect_mcts_temperature = temperature
- ready_env_id = list(envs.keys())
+ ready_env_id = list(obs.keys())
init_state = {env_id: obs[env_id]['board'] for env_id in ready_env_id}
start_player_index = {env_id: obs[env_id]['current_player_index'] for env_id in ready_env_id}
output = {}
self._policy_model = self._collect_model
for env_id in ready_env_id:
- # print('[collect] start_player_index={}'.format(start_player_index[env_id]))
- # print('[collect] init_state=\n{}'.format(init_state[env_id]))
- envs[env_id].reset(
- start_player_index=start_player_index[env_id],
- init_state=init_state[env_id],
- )
+ state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
+ init_state=init_state[env_id], ))
action, mcts_probs = self._collect_mcts.get_next_action(
- envs[env_id],
+ state_config_for_simulation_env_reset,
policy_forward_fn=self._policy_value_fn,
temperature=self._collect_mcts_temperature,
sample=True
@@ -258,15 +255,18 @@ def _init_eval(self) -> None:
Overview:
Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils.
"""
- self._eval_mcts = MCTS(self._cfg.mcts)
+ self._get_simulation_env()
+ import copy
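+ # NOTE: for evaluation, use twice as many MCTS simulations as during data collection.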
+ mcts_eval_config = copy.deepcopy(self._cfg.mcts)
+ mcts_eval_config.num_simulations = mcts_eval_config.num_simulations * 2
+ self._eval_mcts = MCTS(mcts_eval_config, self.simulate_env)
self._eval_model = self._model
- def _forward_eval(self, envs: Dict, obs: Dict) -> Dict[str, torch.Tensor]:
+ def _forward_eval(self, obs: Dict) -> Dict[str, torch.Tensor]:
"""
Overview:
The forward function for evaluating the current policy in eval mode, similar to ``self._forward_collect``.
Arguments:
- - envs (:obj:`Dict`): The dict of colletor envs, the key is env_id and the value is the env instance.
- obs (:obj:`Dict`): The dict of obs, the key is env_id and the value is the \
corresponding obs in this timestep.
Returns:
@@ -279,14 +279,10 @@ def _forward_eval(self, envs: Dict, obs: Dict) -> Dict[str, torch.Tensor]:
output = {}
self._policy_model = self._eval_model
for env_id in ready_env_id:
- # print('[eval] start_player_index={}'.format(start_player_index[env_id]))
- # print('[eval] init_state=\n {}'.format(init_state[env_id]))
- envs[env_id].reset(
- start_player_index=start_player_index[env_id],
- init_state=init_state[env_id],
- )
+ state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=start_player_index[env_id],
+ init_state=init_state[env_id],))
action, mcts_probs = self._eval_mcts.get_next_action(
- envs[env_id], policy_forward_fn=self._policy_value_fn, temperature=1.0, sample=False
+ state_config_for_simulation_env_reset, policy_forward_fn=self._policy_value_fn, temperature=1.0, sample=False
)
output[env_id] = {
'action': action,
@@ -294,6 +290,31 @@ def _forward_eval(self, envs: Dict, obs: Dict) -> Dict[str, torch.Tensor]:
}
return output
+ def _get_simulation_env(self):
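+ """
+ Overview:
+ Create the simulation environment used by MCTS according to ``simulation_env_name`` and
+ ``simulation_env_config_type`` in the policy config, and set it as ``self.simulate_env``.
+ """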
+ if self._cfg.simulation_env_name == 'tictactoe':
+ from zoo.board_games.tictactoe.envs.tictactoe_env import TicTacToeEnv
+ if self._cfg.simulation_env_config_type == 'play_with_bot':
+ from zoo.board_games.tictactoe.config.tictactoe_alphazero_bot_mode_config import \
+ tictactoe_alphazero_config
+ elif self._cfg.simulation_env_config_type == 'self_play':
+ from zoo.board_games.tictactoe.config.tictactoe_alphazero_sp_mode_config import \
+ tictactoe_alphazero_config
+ else:
+ raise NotImplementedError
+ self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env)
+
+ elif self._cfg.simulation_env_name == 'gomoku':
+ from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
+ if self._cfg.simulation_env_config_type == 'play_with_bot':
+ from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config
+ elif self._cfg.simulation_env_config_type == 'self_play':
+ from zoo.board_games.gomoku.config.gomoku_alphazero_sp_mode_config import gomoku_alphazero_config
+ else:
+ raise NotImplementedError
+ self.simulate_env = GomokuEnv(gomoku_alphazero_config.env)
+ else:
+ raise NotImplementedError
+
@torch.no_grad()
def _policy_value_fn(self, env: 'Env') -> Tuple[Dict[int, np.ndarray], float]: # noqa
legal_actions = env.legal_actions
diff --git a/lzero/policy/muzero.py b/lzero/policy/muzero.py
index 80bbc53ba..b4094e0a0 100644
--- a/lzero/policy/muzero.py
+++ b/lzero/policy/muzero.py
@@ -8,6 +8,7 @@
from ding.policy.base_policy import Policy
from ding.torch_utils import to_tensor
from ding.utils import POLICY_REGISTRY
+from torch.distributions import Categorical
from torch.nn import L1Loss
from lzero.mcts import MuZeroMCTSCtree as MCTSCtree
@@ -60,6 +61,8 @@ class MuZeroPolicy(Policy):
norm_type='BN',
),
# ****** common ******
+ # (bool) Whether to use the RND model.
+ use_rnd_model=False,
# (bool) Whether to use multi-gpu training.
multi_gpu=False,
# (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero)
@@ -116,6 +119,8 @@ class MuZeroPolicy(Policy):
learning_rate=0.2,
# (int) Frequency of target network update.
target_update_freq=100,
+ # (int) Frequency of the target network update used for the RND intrinsic reward.
+ target_update_freq_for_intrinsic_reward=1000,
# (float) Weight decay for training policy network.
weight_decay=1e-4,
# (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction).
@@ -138,6 +143,8 @@ class MuZeroPolicy(Policy):
value_loss_weight=0.25,
# (float) The weight of policy loss.
policy_loss_weight=1,
+ # (float) The weight of policy entropy loss.
+ policy_entropy_loss_weight=0,
# (float) The weight of ssl (self-supervised learning) loss.
ssl_loss_weight=0,
# (bool) Whether to use piecewise constant learning rate decay.
@@ -256,6 +263,22 @@ def _init_learn(self) -> None:
self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution
)
+ if self._cfg.use_rnd_model:
+ if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign':
+ self._target_model_for_intrinsic_reward = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward}
+ )
+ elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum':
+ self._target_model_for_intrinsic_reward = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward}
+ )
+
def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
"""
Overview:
@@ -271,6 +294,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
"""
self._learn_model.train()
self._target_model.train()
+ if self._cfg.use_rnd_model:
+ self._target_model_for_intrinsic_reward.train()
current_batch, target_batch = data
obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch
@@ -340,6 +365,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0])
value_loss = cross_entropy_loss(value, target_value_categorical[:, 0])
+ prob = torch.softmax(policy_logits, dim=-1)
+ dist = Categorical(prob)
+ policy_entropy_loss = -dist.entropy()
+
reward_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device)
consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device)
@@ -383,6 +413,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
# ==============================================================
policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1])
+ prob = torch.softmax(policy_logits, dim=-1)
+ dist = Categorical(prob)
+ policy_entropy_loss += -dist.entropy()
+
value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1])
reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k])
@@ -406,7 +440,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
# weighted loss with masks (some invalid states which are out of trajectory.)
loss = (
self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss +
- self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss
+ self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss +
+ self._cfg.policy_entropy_loss_weight * policy_entropy_loss
)
weighted_total_loss = (weights * loss).mean()
@@ -427,6 +462,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
# the core target model update step.
# ==============================================================
self._target_model.update(self._learn_model.state_dict())
+ if self._cfg.use_rnd_model:
+ self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict())
if self._cfg.monitor_extra_statistics:
predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1)
@@ -439,6 +476,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
'weighted_total_loss': weighted_total_loss.item(),
'total_loss': loss.mean().item(),
'policy_loss': policy_loss.mean().item(),
+ 'policy_entropy': - policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1),
'reward_loss': reward_loss.mean().item(),
'value_loss': value_loss.mean().item(),
'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps,
@@ -467,7 +505,7 @@ def _init_collect(self) -> None:
self._mcts_collect = MCTSCtree(self._cfg)
else:
self._mcts_collect = MCTSPtree(self._cfg)
- self._collect_mcts_temperature = 1
+ self._collect_mcts_temperature = 1.
self.collect_epsilon = 0.0
def _forward_collect(
@@ -697,6 +735,7 @@ def _monitor_vars_learn(self) -> List[str]:
'weighted_total_loss',
'total_loss',
'policy_loss',
+ 'policy_entropy',
'reward_loss',
'value_loss',
'consistency_loss',
diff --git a/lzero/reward_model/rnd_reward_model.py b/lzero/reward_model/rnd_reward_model.py
new file mode 100644
index 000000000..453e63759
--- /dev/null
+++ b/lzero/reward_model/rnd_reward_model.py
@@ -0,0 +1,337 @@
+import copy
+import random
+from typing import Union, Tuple, List, Dict
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ding.model import FCEncoder, ConvEncoder
+from ding.reward_model.base_reward_model import BaseRewardModel
+from ding.torch_utils.data_helper import to_tensor
+from ding.utils import RunningMeanStd
+from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
+from easydict import EasyDict
+
+
+class RNDNetwork(nn.Module):
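+ """
+ Overview:
+ The basic RND network (https://arxiv.org/abs/1810.12894v1): a randomly initialized, fixed
+ target encoder and a trainable predictor encoder that both take the same input
+ (raw observations or latent states, depending on the reward model config).
+ """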
+
+ def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType) -> None:
+ super(RNDNetwork, self).__init__()
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.target = FCEncoder(obs_shape, hidden_size_list)
+ self.predictor = FCEncoder(obs_shape, hidden_size_list)
+ elif len(obs_shape) == 3:
+ self.target = ConvEncoder(obs_shape, hidden_size_list)
+ self.predictor = ConvEncoder(obs_shape, hidden_size_list)
+ else:
+ raise KeyError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own RND model".
+ format(obs_shape)
+ )
+ for param in self.target.parameters():
+ param.requires_grad = False
+
+ def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ predict_feature = self.predictor(obs)
+ with torch.no_grad():
+ target_feature = self.target(obs)
+ return predict_feature, target_feature
+
+
+class RNDNetworkRepr(nn.Module):
+ """
+ Overview:
+ The RND network (https://arxiv.org/abs/1810.12894v1) combined with a representation network: the predictor encodes the latent state produced by the representation network, while the fixed target encodes the raw observation.
+ """
+
+ def __init__(self, obs_shape: Union[int, SequenceType], latent_shape: Union[int, SequenceType], hidden_size_list: SequenceType,
+ representation_network) -> None:
+ super(RNDNetworkRepr, self).__init__()
+ self.representation_network = representation_network
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.target = FCEncoder(obs_shape, hidden_size_list)
+ self.predictor = FCEncoder(latent_shape, hidden_size_list)
+ elif len(obs_shape) == 3:
+ self.target = ConvEncoder(obs_shape, hidden_size_list)
+ self.predictor = ConvEncoder(latent_shape, hidden_size_list)
+ else:
+ raise KeyError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own RND model".
+ format(obs_shape)
+ )
+ for param in self.target.parameters():
+ param.requires_grad = False
+
+ def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ predict_feature = self.predictor(self.representation_network(obs))
+ with torch.no_grad():
+ target_feature = self.target(obs)
+
+ return predict_feature, target_feature
+
+
+@REWARD_MODEL_REGISTRY.register('rnd_muzero')
+class RNDRewardModel(BaseRewardModel):
+ """
+ Overview:
+ The RND reward model class (https://arxiv.org/abs/1810.12894v1) modified for MuZero.
+ Interface:
+ ``estimate``, ``train``, ``collect_data``, ``clear_data``, \
+ ``__init__``, ``_train``, ``load_state_dict``, ``state_dict``
+ Config:
+ == ==================== ===== ============= ======================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ===== ============= ======================================= =======================
+ 1 ``type`` str rnd | Reward model register name, refer |
+ | to registry ``REWARD_MODEL_REGISTRY`` |
+ 2 | ``intrinsic_`` str add | the intrinsic reward type | including add, new
+ | ``reward_type`` | | , or assign
+ 3 | ``learning_rate`` float 0.001 | The step size of gradient descent |
+ 4 | ``batch_size`` int 64 | Training batch size |
+ 5 | ``hidden`` list [64, 64, | the MLP layer shape |
+ | ``_size_list`` (int) 128] | |
+ 6 | ``update_per_`` int 100 | Number of updates per collect |
+ | ``collect`` | |
+ 7 | ``input_norm`` bool True | Observation normalization |
+ 8 | ``input_norm_`` int 0 | min clip value for obs normalization |
+ | ``clamp_min``
+ 9 | ``input_norm_`` int 1 | max clip value for obs normalization |
+ | ``clamp_max``
+ 10 | ``intrinsic_`` float 0.01 | the weight of intrinsic reward | r = w*r_i + r_e
+ ``reward_weight``
+ 11 | ``extrinsic_`` bool True | Whether to normalize extrinsic reward
+ ``reward_norm``
+ 12 | ``extrinsic_`` int 1 | the upper bound of the reward
+ ``reward_norm_max`` | normalization
+ == ==================== ===== ============= ======================================= =======================
+ """
+ config = dict(
+ # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
+ type='rnd',
+ # (str) The intrinsic reward type, including add, new, or assign.
+ intrinsic_reward_type='add',
+ # (float) The step size of gradient descent.
+ learning_rate=1e-3,
+ # (float) Batch size.
+ batch_size=64,
+ # (list(int)) Sequence of ``hidden_size`` of reward network.
+ # If obs.shape == 1, use MLP layers.
+ # If obs.shape == 3, use conv layer and final dense layer.
+ hidden_size_list=[64, 64, 128],
+ # (int) How many updates (iterations) to train after one collection by the collector.
+ # A bigger "update_per_collect" means a more off-policy style of training.
+ # collect data -> update policy -> collect data -> ...
+ update_per_collect=100,
+ # (bool) Observation normalization: transform obs to mean 0, std 1.
+ input_norm=True,
+ # (int) Min clip value for observation normalization.
+ input_norm_clamp_min=-1,
+ # (int) Max clip value for observation normalization.
+ input_norm_clamp_max=1,
+ # (float) The relative weight of the RND intrinsic reward:
+ # r = intrinsic_reward_weight * r_i + r_e.
+ intrinsic_reward_weight=0.01,
+ # (bool) Whether to normalize extrinsic reward.
+ # Normalize the reward to [0, extrinsic_reward_norm_max].
+ extrinsic_reward_norm=True,
+ # (int) The upper bound of the reward normalization.
+ extrinsic_reward_norm_max=1,
+ )
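+ # NOTE: besides the defaults above, ``__init__`` also reads ``input_type`` ('obs',
+ # 'latent_state' or 'obs_latent_state'), ``obs_shape``, ``latent_state_dim``,
+ # ``rnd_buffer_size`` and ``weight_decay`` from the passed-in config, so the entry
+ # config is expected to provide these fields.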
+
+ def __init__(self, config: EasyDict, device: str = 'cpu', tb_logger: 'SummaryWriter' = None,
+ representation_network: nn.Module = None, target_representation_network: nn.Module = None,
+ use_momentum_representation_network: bool = True) -> None: # noqa
+ super(RNDRewardModel, self).__init__()
+ self.cfg = config
+ self.representation_network = representation_network
+ self.target_representation_network = target_representation_network
+ self.use_momentum_representation_network = use_momentum_representation_network
+ self.input_type = self.cfg.input_type
+ assert self.input_type in ['obs', 'latent_state', 'obs_latent_state'], self.input_type
+ self.device = device
+ assert self.device == "cpu" or self.device.startswith("cuda")
+ self.rnd_buffer_size = config.rnd_buffer_size
+ self.intrinsic_reward_type = self.cfg.intrinsic_reward_type
+ if tb_logger is None:
+ from tensorboardX import SummaryWriter
+ tb_logger = SummaryWriter('rnd_reward_model')
+ self.tb_logger = tb_logger
+ if self.input_type == 'obs':
+ self.input_shape = self.cfg.obs_shape
+ self.reward_model = RNDNetwork(self.input_shape, self.cfg.hidden_size_list).to(self.device)
+ elif self.input_type == 'latent_state':
+ self.input_shape = self.cfg.latent_state_dim
+ self.reward_model = RNDNetwork(self.input_shape, self.cfg.hidden_size_list).to(self.device)
+ elif self.input_type == 'obs_latent_state':
+ if self.use_momentum_representation_network:
+ self.reward_model = RNDNetworkRepr(self.cfg.obs_shape, self.cfg.latent_state_dim, self.cfg.hidden_size_list[0:-1],
+ self.target_representation_network).to(self.device)
+ else:
+ self.reward_model = RNDNetworkRepr(self.cfg.obs_shape, self.cfg.latent_state_dim, self.cfg.hidden_size_list[0:-1],
+ self.representation_network).to(self.device)
+
+ assert self.intrinsic_reward_type in ['add', 'new', 'assign']
+ if self.input_type in ['obs', 'obs_latent_state']:
+ self.train_obs = []
+ if self.input_type == 'latent_state':
+ self.train_latent_state = []
+
+ self._optimizer_rnd = torch.optim.Adam(
+ self.reward_model.predictor.parameters(), lr=self.cfg.learning_rate, weight_decay=self.cfg.weight_decay
+ )
+
+ self._running_mean_std_rnd_reward = RunningMeanStd(epsilon=1e-4)
+ self._running_mean_std_rnd_obs = RunningMeanStd(epsilon=1e-4)
+ self.estimate_cnt_rnd = 0
+ self.train_cnt_rnd = 0
+
+ def _train_with_data_one_step(self) -> None:
+ if self.input_type in ['obs', 'obs_latent_state']:
+ train_data = random.sample(self.train_obs, self.cfg.batch_size)
+ elif self.input_type == 'latent_state':
+ train_data = random.sample(self.train_latent_state, self.cfg.batch_size)
+
+ train_data = torch.stack(train_data).to(self.device)
+
+ if self.cfg.input_norm:
+ # Note: observation normalization: transform obs to mean 0, std 1
+ self._running_mean_std_rnd_obs.update(train_data.detach().cpu().numpy())
+ normalized_train_data = (train_data - to_tensor(self._running_mean_std_rnd_obs.mean).to(
+ self.device)) / to_tensor(
+ self._running_mean_std_rnd_obs.std
+ ).to(self.device)
+ train_data = torch.clamp(normalized_train_data, min=self.cfg.input_norm_clamp_min,
+ max=self.cfg.input_norm_clamp_max)
+
+ predict_feature, target_feature = self.reward_model(train_data)
+ loss = F.mse_loss(predict_feature, target_feature)
+
+ self.tb_logger.add_scalar('rnd_reward_model/rnd_mse_loss', loss, self.train_cnt_rnd)
+ self._optimizer_rnd.zero_grad()
+ loss.backward()
+ self._optimizer_rnd.step()
+
+ def train_with_data(self) -> None:
+ for _ in range(self.cfg.update_per_collect):
+ # for name, param in self.reward_model.named_parameters():
+ # if param.grad is not None:
+ # print(f"{name}: {torch.isnan(param.grad).any()}, {torch.isinf(param.grad).any()}")
+ # print(f"{name}: grad min: {param.grad.min()}, grad max: {param.grad.max()}")
+ # # enable the following line to check whether there is nan or inf in the gradient.
+ # torch.autograd.set_detect_anomaly(True)
+ self._train_with_data_one_step()
+ self.train_cnt_rnd += 1
+
+ def estimate(self, data: list) -> List[Dict]:
+ """
+ Augment the target reward in ``data`` with the RND intrinsic reward and return the augmented data.
+ """
+ # current_batch, target_batch = data
+ # obs_batch_orig, action_batch, mask_batch, indices, weights, make_time = current_batch
+ # target_reward, target_value, target_policy = target_batch
+ obs_batch_orig = data[0][0]
+ target_reward = data[1][0]
+ batch_size = obs_batch_orig.shape[0]
+ # e.g. reshape to (4, 2835, 6) when batch_size=4 and obs_shape=2835
+ obs_batch_tmp = np.reshape(obs_batch_orig, (batch_size, self.cfg.obs_shape, 6))
+ # e.g. reshape to (24, 2835), i.e. (batch_size * 6, obs_shape)
+ obs_batch_tmp = np.reshape(obs_batch_tmp, (batch_size * 6, self.cfg.obs_shape))
+
+ if self.input_type == 'latent_state':
+ with torch.no_grad():
+ latent_state = self.representation_network(torch.from_numpy(obs_batch_tmp).to(self.device))
+ input_data = latent_state
+ elif self.input_type in ['obs', 'obs_latent_state']:
+ input_data = to_tensor(obs_batch_tmp).to(self.device)
+
+ # NOTE: deepcopying the reward part of the data is very important;
+ # otherwise, the rewards stored in the replay buffer would be modified in place.
+ target_reward_augmented = copy.deepcopy(target_reward)
+ target_reward_augmented = np.reshape(target_reward_augmented, (batch_size * 6, 1))
+
+ if self.cfg.input_norm:
+ # add this line to avoid inplace operation on the original tensor.
+ input_data = input_data.clone()
+ # Note: observation normalization: transform obs to mean 0, std 1
+ input_data = (input_data - to_tensor(self._running_mean_std_rnd_obs.mean
+ ).to(self.device)) / to_tensor(self._running_mean_std_rnd_obs.std).to(
+ self.device)
+ input_data = torch.clamp(input_data, min=self.cfg.input_norm_clamp_min, max=self.cfg.input_norm_clamp_max)
+ else:
+ input_data = input_data
+ with torch.no_grad():
+ predict_feature, target_feature = self.reward_model(input_data)
+ mse = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1)
+ self._running_mean_std_rnd_reward.update(mse.detach().cpu().numpy())
+
+ # Note: transform the rnd reward to [0, 1] via min-max normalization.
+ rnd_reward = (mse - mse.min()) / (mse.max() - mse.min() + 1e-6)
+
+ # save the rnd_reward statistics into tb_logger
+ self.estimate_cnt_rnd += 1
+ self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_max', rnd_reward.max(), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_mean', rnd_reward.mean(), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_min', rnd_reward.min(), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('rnd_reward_model/rnd_reward_std', rnd_reward.std(), self.estimate_cnt_rnd)
+
+ rnd_reward = rnd_reward.to(self.device).unsqueeze(1).cpu().numpy()
+ if self.intrinsic_reward_type == 'add':
+ if self.cfg.extrinsic_reward_norm:
+ target_reward_augmented = target_reward_augmented / self.cfg.extrinsic_reward_norm_max + rnd_reward * self.cfg.intrinsic_reward_weight
+ else:
+ target_reward_augmented = target_reward_augmented + rnd_reward * self.cfg.intrinsic_reward_weight
+ elif self.intrinsic_reward_type == 'new':
+ if self.cfg.extrinsic_reward_norm:
+ target_reward_augmented = target_reward_augmented / self.cfg.extrinsic_reward_norm_max
+ elif self.intrinsic_reward_type == 'assign':
+ target_reward_augmented = rnd_reward
+
+ self.tb_logger.add_scalar('augmented_reward/reward_max', np.max(target_reward_augmented), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('augmented_reward/reward_mean', np.mean(target_reward_augmented),
+ self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('augmented_reward/reward_min', np.min(target_reward_augmented), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('augmented_reward/reward_std', np.std(target_reward_augmented), self.estimate_cnt_rnd)
+
+ # reshape back to (batch_size, 6, 1)
+ target_reward_augmented = np.reshape(target_reward_augmented, (batch_size, 6, 1))
+ data[1][0] = target_reward_augmented
+ train_data_augmented = data
+
+ return train_data_augmented
+
+ def collect_data(self, data: list) -> None:
+ # TODO(pu): now we only collect the first 300 steps of each game segment.
+ collected_transitions = np.concatenate([game_segment.obs_segment[:300] for game_segment in data[0]], axis=0)
+ if self.input_type == 'latent_state':
+ with torch.no_grad():
+ self.train_latent_state.extend(
+ self.representation_network(torch.from_numpy(collected_transitions).to(self.device)))
+ elif self.input_type == 'obs':
+ self.train_obs.extend(to_tensor(collected_transitions).to(self.device))
+ elif self.input_type == 'obs_latent_state':
+ self.train_obs.extend(to_tensor(collected_transitions).to(self.device))
+
+ def clear_old_data(self) -> None:
+ if self.input_type == 'latent_state':
+ if len(self.train_latent_state) >= self.cfg.rnd_buffer_size:
+ self.train_latent_state = self.train_latent_state[-self.cfg.rnd_buffer_size:]
+ elif self.input_type == 'obs':
+ if len(self.train_obs) >= self.cfg.rnd_buffer_size:
+ self.train_obs = self.train_obs[-self.cfg.rnd_buffer_size:]
+ elif self.input_type == 'obs_latent_state':
+ if len(self.train_obs) >= self.cfg.rnd_buffer_size:
+ self.train_obs = self.train_obs[-self.cfg.rnd_buffer_size:]
+
+ def state_dict(self) -> Dict:
+ return self.reward_model.state_dict()
+
+ def load_state_dict(self, _state_dict: Dict) -> None:
+ self.reward_model.load_state_dict(_state_dict)
+
+ def clear_data(self):
+ pass
+
+ def train(self):
+ pass
diff --git a/lzero/worker/alphazero_collector.py b/lzero/worker/alphazero_collector.py
index 5ec2e6043..ced7727d0 100644
--- a/lzero/worker/alphazero_collector.py
+++ b/lzero/worker/alphazero_collector.py
@@ -215,14 +215,11 @@ def collect(self,
obs_ = {env_id: obs[env_id] for env_id in ready_env_id}
# Policy forward.
self._obs_pool.update(obs_)
- simulation_envs = {}
- for env_id in ready_env_id:
- # create the new simulation env instances from the current collect env using the same env_config.
- simulation_envs[env_id] = self._env._env_fn[env_id]()
+
# ==============================================================
# policy forward
# ==============================================================
- policy_output = self._policy.forward(simulation_envs, obs_, temperature)
+ policy_output = self._policy.forward(obs_, temperature)
self._policy_output_pool.update(policy_output)
# Interact with env.
actions = {env_id: output['action'] for env_id, output in policy_output.items()}
diff --git a/lzero/worker/alphazero_evaluator.py b/lzero/worker/alphazero_evaluator.py
index 47528e9dd..c9eb8f650 100644
--- a/lzero/worker/alphazero_evaluator.py
+++ b/lzero/worker/alphazero_evaluator.py
@@ -202,15 +202,11 @@ def eval(
with self._timer:
while not eval_monitor.is_finished():
obs = self._env.ready_obs
- simulation_envs = {}
- for env_id in list(obs.keys()):
- # create the new simulation env instances from the current evaluate env using the same env_config.
- simulation_envs[env_id] = self._env._env_fn[env_id]()
# ==============================================================
# policy forward
# ==============================================================
- policy_output = self._policy.forward(simulation_envs, obs)
+ policy_output = self._policy.forward(obs)
actions = {env_id: output['action'] for env_id, output in policy_output.items()}
# ==============================================================
# Interact with env.
diff --git a/zoo/board_games/alphabeta_pruning_bot.py b/zoo/board_games/alphabeta_pruning_bot.py
index 1f456b682..83a36665d 100644
--- a/zoo/board_games/alphabeta_pruning_bot.py
+++ b/zoo/board_games/alphabeta_pruning_bot.py
@@ -17,7 +17,7 @@ def __init__(self, board, legal_actions, start_player_index=0, parent=None, prev
super().__init__()
self.env = env
self.board = board
- self.legal_actions = legal_actions
+ self.legal_actions = copy.deepcopy(legal_actions)
self.children = []
self.parent = parent
self.prev_action = prev_action
diff --git a/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py b/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py
index 8a29c363b..9321037f4 100644
--- a/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py
+++ b/zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py
@@ -28,8 +28,21 @@
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
+ # ==============================================================
+ # for the creation of simulation env
+ agent_vs_human=False,
+ prob_random_agent=0,
+ prob_expert_agent=0,
+ scale=True,
+ check_action_to_connect4_in_bot_v0=False,
+ # ==============================================================
),
policy=dict(
+ # ==============================================================
+ # for the creation of simulation env
+ simulation_env_name='gomoku',
+ simulation_env_config_type='play_with_bot',
+ # ==============================================================
torch_compile=False,
tensor_float_32=False,
model=dict(
diff --git a/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py b/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py
index 1a139a405..22eb87564 100644
--- a/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py
+++ b/zoo/board_games/gomoku/config/gomoku_alphazero_sp_mode_config.py
@@ -28,8 +28,21 @@
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
+ # ==============================================================
+ # for the creation of simulation env
+ agent_vs_human=False,
+ prob_random_agent=0,
+ prob_expert_agent=0,
+ scale=True,
+ check_action_to_connect4_in_bot_v0=False,
+ # ==============================================================
),
policy=dict(
+ # ==============================================================
+ # for the creation of simulation env
+ simulation_env_name='gomoku',
+ simulation_env_config_type='self_play',
+ # ==============================================================
torch_compile=False,
tensor_float_32=False,
model=dict(
diff --git a/zoo/board_games/mcts_bot.py b/zoo/board_games/mcts_bot.py
index 981d4e257..32609de54 100644
--- a/zoo/board_games/mcts_bot.py
+++ b/zoo/board_games/mcts_bot.py
@@ -14,7 +14,7 @@
from collections import defaultdict
import numpy as np
-
+import copy
class MCTSNode(ABC):
"""
@@ -151,7 +151,7 @@ def __init__(self, env, parent=None):
@property
def legal_actions(self):
if self._legal_actions is None:
- self._legal_actions = self.env.legal_actions
+ self._legal_actions = copy.deepcopy(self.env.legal_actions)
return self._legal_actions
@property
@@ -195,7 +195,6 @@ def expand(self):
Returns:
- node(:obj:`TwoPlayersMCTSNode`): The child node object that has been created.
"""
-
# Choose an untried action from the list of legal actions and pop it out. Only untried actions are left in the list.
action = self.legal_actions.pop()
diff --git a/zoo/board_games/test_speed_win-rate_between_bots.py b/zoo/board_games/test_speed_win-rate_between_bots.py
index 2d43ecf53..bfe7f2455 100644
--- a/zoo/board_games/test_speed_win-rate_between_bots.py
+++ b/zoo/board_games/test_speed_win-rate_between_bots.py
@@ -1,9 +1,9 @@
"""
Overview:
- Implemrnt games between different bots to test the win rates and the speed.
+ Implement games between different bots to test the win rates and the speed.
Example:
test_tictactoe_mcts_bot_vs_alphabeta_bot means a game between mcts_bot and alphabeta_bot where
- mcts_bot makes the first move(bots on the left make the first move).
+ mcts_bot makes the first move (i.e. bots on the left make the first move).
"""
import time
@@ -17,7 +17,7 @@
cfg_tictactoe = dict(
battle_mode='self_play_mode',
agent_vs_human=False,
- bot_action_type=['v0', 'alpha_beta_pruning'], # {'v0', 'alpha_beta_pruning'}
+ bot_action_type='v0', # {'v0', 'alpha_beta_pruning'}
prob_random_agent=0,
prob_expert_agent=0,
channel_last=True,
@@ -33,13 +33,14 @@ def test_tictactoe_mcts_bot_vs_rule_bot_v0_bot(num_simulations=50):
Arguments:
- num_simulations (:obj:`int`): The number of the simulations required to find the best move.
"""
+ cfg_tictactoe['bot_action_type'] = 'v0'
# List to record the time required for each decision round and the winner.
mcts_bot_time_list = []
bot_action_time_list = []
winner = []
# Repeat the game for 10 rounds.
- for i in range(100):
+ for i in range(10):
print('-' * 10 + str(i) + '-' * 10)
# Initialize the game, where there are two players: player 1 and player 2.
env = TicTacToeEnv(EasyDict(cfg_tictactoe))
@@ -108,6 +109,8 @@ def test_tictactoe_alphabeta_bot_vs_rule_bot_v0_bot(num_simulations=50):
Arguments:
- num_simulations (:obj:`int`): The number of the simulations required to find the best move.
"""
+ cfg_tictactoe['bot_action_type'] = 'alpha_beta_pruning'
+
# List to record the time required for each decision round and the winner.
alphabeta_pruning_time_list = []
rule_bot_v0_time_list = []
@@ -186,6 +189,8 @@ def test_tictactoe_alphabeta_bot_vs_mcts_bot(num_simulations=50):
Arguments:
- num_simulations (:obj:`int`): The number of the simulations required to find the best move.
"""
+ cfg_tictactoe['bot_action_type'] = 'alpha_beta_pruning'
+
# List to record the time required for each decision round and the winner.
alphabeta_pruning_time_list = []
mcts_bot_time_list = []
@@ -266,6 +271,8 @@ def test_tictactoe_rule_bot_v0_bot_vs_alphabeta_bot(num_simulations=50):
Arguments:
- num_simulations (:obj:`int`): The number of the simulations required to find the best move.
"""
+ cfg_tictactoe['bot_action_type'] = 'alpha_beta_pruning'
+
# List to record the time required for each decision round and the winner.
alphabeta_pruning_time_list = []
rule_bot_v0_time_list = []
@@ -344,6 +351,8 @@ def test_tictactoe_mcts_bot_vs_alphabeta_bot(num_simulations=50):
Arguments:
- num_simulations (:obj:`int`): The number of the simulations required to find the best move.
"""
+ cfg_tictactoe['bot_action_type'] = 'alpha_beta_pruning'
+
# List to record the time required for each decision round and the winner.
alphabeta_pruning_time_list = []
mcts_bot_time_list = []
@@ -509,18 +518,21 @@ def test_gomoku_mcts_bot_vs_rule_bot_v0_bot(num_simulations=50):
if __name__ == '__main__':
# ==============================================================
- # test win rate between alphabeta_bot and rule_bot_v0/mcts_bot
+ # test win rate between alphabeta_bot and rule_bot_v0
# ==============================================================
# test_tictactoe_alphabeta_bot_vs_rule_bot_v0_bot()
# test_tictactoe_rule_bot_v0_bot_vs_alphabeta_bot()
+ # ==============================================================
+ # test win rate between alphabeta_bot and mcts_bot
+ # ==============================================================
# test_tictactoe_alphabeta_bot_vs_mcts_bot(num_simulations=2000)
- test_tictactoe_mcts_bot_vs_alphabeta_bot(num_simulations=2000)
+ # test_tictactoe_mcts_bot_vs_alphabeta_bot(num_simulations=2000)
# ==============================================================
# test win rate between mcts_bot and rule_bot_v0
# ==============================================================
- # test_tictactoe_mcts_bot_vs_rule_bot_v0_bot(num_simulations=50)
- # test_tictactoe_mcts_bot_vs_rule_bot_v0_bot(num_simulations=100)
+ test_tictactoe_mcts_bot_vs_rule_bot_v0_bot(num_simulations=50)
# test_tictactoe_mcts_bot_vs_rule_bot_v0_bot(num_simulations=500)
# test_tictactoe_mcts_bot_vs_rule_bot_v0_bot(num_simulations=1000)
+
# test_gomoku_mcts_bot_vs_rule_bot_v0_bot(num_simulations=1000)
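
The tests in this file all share one bookkeeping pattern: time each bot's decision separately, accumulate the per-bot latencies, and record the winner of every round before averaging. A minimal, self-contained sketch of that pattern follows; the `choose_move` callables are hypothetical stand-ins for the actual bot calls and are not part of this diff.

import time
from typing import Callable, List, Tuple

def timed_move(choose_move: Callable[[], int]) -> Tuple[int, float]:
    # Run a single bot decision and measure how long it takes.
    start = time.time()
    action = choose_move()
    return action, time.time() - start

# Hypothetical usage: accumulate per-bot decision times over several rounds.
bot_a_times: List[float] = []
bot_b_times: List[float] = []
for _ in range(10):
    _, elapsed = timed_move(lambda: 4)  # placeholder "bot" that always returns action 4
    bot_a_times.append(elapsed)
    _, elapsed = timed_move(lambda: 0)  # placeholder opponent
    bot_b_times.append(elapsed)
print('mean decision time of bot_a:', sum(bot_a_times) / len(bot_a_times))
print('mean decision time of bot_b:', sum(bot_b_times) / len(bot_b_times))
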
diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py
index 36d2e1ea7..143cdfb8d 100644
--- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py
+++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_config.py
@@ -19,13 +19,26 @@
env=dict(
board_size=3,
battle_mode='play_with_bot_mode',
+ bot_action_type='v0', # {'v0', 'alpha_beta_pruning'}
channel_last=False, # NOTE
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
+ # ==============================================================
+ # for the creation of simulation env
+ agent_vs_human=False,
+ prob_random_agent=0,
+ prob_expert_agent=0,
+ scale=True,
+ # ==============================================================
),
policy=dict(
+ # ==============================================================
+ # for the creation of simulation env
+ simulation_env_name='tictactoe',
+ simulation_env_config_type='play_with_bot',
+ # ==============================================================
model=dict(
observation_shape=(3, 3, 3),
action_space_size=int(1 * 3 * 3),
diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py
index 68e77335c..6f2f9f7be 100644
--- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py
+++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_bot_mode_multigpu_ddp_config.py
@@ -20,13 +20,26 @@
env=dict(
board_size=3,
battle_mode='play_with_bot_mode',
+ bot_action_type='v0', # {'v0', 'alpha_beta_pruning'}
channel_last=False, # NOTE
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
+ # ==============================================================
+ # for the creation of simulation env
+ agent_vs_human=False,
+ prob_random_agent=0,
+ prob_expert_agent=0,
+ scale=True,
+ # ==============================================================
),
policy=dict(
+ # ==============================================================
+ # for the creation of simulation env
+ simulation_env_name='tictactoe',
+ simulation_env_config_type='play_with_bot',
+ # ==============================================================
model=dict(
observation_shape=(3, 3, 3),
action_space_size=int(1 * 3 * 3),
diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py
index defbca893..07a04943d 100644
--- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py
+++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_config.py
@@ -18,13 +18,26 @@
env=dict(
board_size=3,
battle_mode='self_play_mode',
+ bot_action_type='v0', # {'v0', 'alpha_beta_pruning'}
channel_last=False, # NOTE
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
+ # ==============================================================
+ # for the creation of simulation env
+ agent_vs_human=False,
+ prob_random_agent=0,
+ prob_expert_agent=0,
+ scale=True,
+ # ==============================================================
),
policy=dict(
+ # ==============================================================
+ # for the creation of simulation env
+ simulation_env_name='tictactoe',
+ simulation_env_config_type='self_play',
+ # ==============================================================
model=dict(
observation_shape=(3, 3, 3),
action_space_size=int(1 * 3 * 3),
diff --git a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py
index 2916910a2..dd324cb65 100644
--- a/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py
+++ b/zoo/board_games/tictactoe/config/tictactoe_alphazero_sp_mode_multigpu_ddp_config.py
@@ -19,13 +19,26 @@
env=dict(
board_size=3,
battle_mode='self_play_mode',
- channel_last=False, # NOTE
+ bot_action_type='v0', # {'v0', 'alpha_beta_pruning'}
+ channel_last=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, ),
+ # ==============================================================
+ # for the creation of simulation env
+ agent_vs_human=False,
+ prob_random_agent=0,
+ prob_expert_agent=0,
+ scale=True,
+ # ==============================================================
),
policy=dict(
+ # ==============================================================
+ # for the creation of simulation env
+ simulation_env_name='tictactoe',
+ simulation_env_config_type='self_play',
+ # ==============================================================
model=dict(
observation_shape=(3, 3, 3),
action_space_size=int(1 * 3 * 3),
diff --git a/zoo/board_games/tictactoe/envs/tictactoe_env.py b/zoo/board_games/tictactoe/envs/tictactoe_env.py
index 674b58310..9c0ec3633 100644
--- a/zoo/board_games/tictactoe/envs/tictactoe_env.py
+++ b/zoo/board_games/tictactoe/envs/tictactoe_env.py
@@ -329,6 +329,8 @@ def bot_action(self):
return self.rule_bot_v0()
elif self.bot_action_type == 'alpha_beta_pruning':
return self.bot_action_alpha_beta_pruning()
+ else:
+ raise NotImplementedError
def bot_action_alpha_beta_pruning(self):
action = self.alpha_beta_pruning_player.get_best_action(self.board, player_index=self.current_player_index)
diff --git a/zoo/bsuite/config/bsuite_efficientzero_config.py b/zoo/bsuite/config/bsuite_efficientzero_config.py
new file mode 100644
index 000000000..a69593d71
--- /dev/null
+++ b/zoo/bsuite/config/bsuite_efficientzero_config.py
@@ -0,0 +1,103 @@
+from easydict import EasyDict
+
+# options={'memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22', 'memory_size/0', 'bsuite_swingup/0', 'bandit_noise/0'}
+env_name = 'memory_len/9'
+
+
+if env_name in ['memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22']:
+    # the memory_length of the above envs is 1, 10, 50, 80, and 100, respectively.
+ action_space_size = 2
+ observation_shape = 3
+elif env_name in ['bsuite_swingup/0']:
+ action_space_size = 3
+ observation_shape = 8
+elif env_name == 'bandit_noise/0':
+ action_space_size = 11
+ observation_shape = 1
+elif env_name in ['memory_size/0']:
+ action_space_size = 2
+ observation_shape = 3
+else:
+ raise NotImplementedError
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+seed = 0
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 50
+update_per_collect = 100
+batch_size = 256
+max_env_step = int(5e5)
+reanalyze_ratio = 0.
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+bsuite_efficientzero_config = dict(
+ exp_name=
+ f'data_ez_ctree/bsuite_{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}',
+ env=dict(
+ env_name=env_name,
+ stop_value=int(1e6),
+ continuous=False,
+ manually_discretization=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=observation_shape,
+ action_space_size=action_space_size,
+ model_type='mlp',
+ lstm_hidden_size=128,
+ latent_state_dim=128,
+ discrete_action_encoding_type='one_hot',
+ norm_type='BN',
+ ),
+ cuda=True,
+ env_type='not_board_games',
+ game_segment_length=50,
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ num_simulations=num_simulations,
+ reanalyze_ratio=reanalyze_ratio,
+ n_episode=n_episode,
+ eval_freq=int(2e2),
+        replay_buffer_size=int(1e6),  # the size/capacity of the replay buffer, in terms of transitions.
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ ),
+)
+
+bsuite_efficientzero_config = EasyDict(bsuite_efficientzero_config)
+main_config = bsuite_efficientzero_config
+
+bsuite_efficientzero_create_config = dict(
+ env=dict(
+ type='bsuite_lightzero',
+ import_names=['zoo.bsuite.envs.bsuite_lightzero_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='efficientzero',
+ import_names=['lzero.policy.efficientzero'],
+ ),
+ collector=dict(
+ type='episode_muzero',
+ import_names=['lzero.worker.muzero_collector'],
+ )
+)
+bsuite_efficientzero_create_config = EasyDict(bsuite_efficientzero_create_config)
+create_config = bsuite_efficientzero_create_config
+
+if __name__ == "__main__":
+ from lzero.entry import train_muzero
+ train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step)
diff --git a/zoo/bsuite/config/bsuite_muzero_config.py b/zoo/bsuite/config/bsuite_muzero_config.py
new file mode 100644
index 000000000..e598a647f
--- /dev/null
+++ b/zoo/bsuite/config/bsuite_muzero_config.py
@@ -0,0 +1,105 @@
+from easydict import EasyDict
+
+# options={'memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22', 'memory_size/0', 'bsuite_swingup/0', 'bandit_noise/0'}
+env_name = 'memory_len/9'
+
+
+if env_name in ['memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22']:
+    # the memory_length of the above envs is 1, 10, 50, 80, and 100, respectively.
+ action_space_size = 2
+ observation_shape = 3
+elif env_name in ['bsuite_swingup/0']:
+ action_space_size = 3
+ observation_shape = 8
+elif env_name == 'bandit_noise/0':
+ action_space_size = 11
+ observation_shape = 1
+elif env_name in ['memory_size/0']:
+ action_space_size = 2
+ observation_shape = 3
+else:
+ raise NotImplementedError
+
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+seed = 0
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 50
+update_per_collect = 100
+batch_size = 256
+max_env_step = int(5e5)
+reanalyze_ratio = 0
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+bsuite_muzero_config = dict(
+ exp_name=f'data_mz_ctree/bsuite_{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}',
+ env=dict(
+ env_name=env_name,
+ stop_value=int(1e6),
+ continuous=False,
+ manually_discretization=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=observation_shape,
+ action_space_size=action_space_size,
+ model_type='mlp',
+ lstm_hidden_size=128,
+ latent_state_dim=128,
+ self_supervised_learning_loss=True, # NOTE: default is False.
+ discrete_action_encoding_type='one_hot',
+ norm_type='BN',
+ ),
+ cuda=True,
+ env_type='not_board_games',
+ game_segment_length=50,
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='AdamW',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ ssl_loss_weight=2, # NOTE: default is 0.
+ num_simulations=num_simulations,
+ reanalyze_ratio=reanalyze_ratio,
+ n_episode=n_episode,
+ eval_freq=int(2e2),
+        replay_buffer_size=int(1e6),  # the size/capacity of the replay buffer, in terms of transitions.
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ ),
+)
+
+bsuite_muzero_config = EasyDict(bsuite_muzero_config)
+main_config = bsuite_muzero_config
+
+bsuite_muzero_create_config = dict(
+ env=dict(
+ type='bsuite_lightzero',
+ import_names=['zoo.bsuite.envs.bsuite_lightzero_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='muzero',
+ import_names=['lzero.policy.muzero'],
+ ),
+ collector=dict(
+ type='episode_muzero',
+ import_names=['lzero.worker.muzero_collector'],
+ )
+)
+bsuite_muzero_create_config = EasyDict(bsuite_muzero_create_config)
+create_config = bsuite_muzero_create_config
+
+if __name__ == "__main__":
+ from lzero.entry import train_muzero
+ train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step)
\ No newline at end of file
diff --git a/zoo/bsuite/config/bsuite_sampled_efficientzero_config.py b/zoo/bsuite/config/bsuite_sampled_efficientzero_config.py
new file mode 100644
index 000000000..38ffe28bd
--- /dev/null
+++ b/zoo/bsuite/config/bsuite_sampled_efficientzero_config.py
@@ -0,0 +1,107 @@
+from easydict import EasyDict
+
+# options={'memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22', 'memory_size/0', 'bsuite_swingup/0', 'bandit_noise/0'}
+env_name = 'memory_len/9'
+
+
+if env_name in ['memory_len/0', 'memory_len/9', 'memory_len/17', 'memory_len/20', 'memory_len/22']:
+    # the memory_length of the above envs is 1, 10, 50, 80, and 100, respectively.
+ action_space_size = 2
+ observation_shape = 3
+elif env_name in ['bsuite_swingup/0']:
+ action_space_size = 3
+ observation_shape = 8
+elif env_name == 'bandit_noise/0':
+ action_space_size = 11
+ observation_shape = 1
+elif env_name in ['memory_size/0']:
+ action_space_size = 2
+ observation_shape = 3
+else:
+ raise NotImplementedError
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+seed = 0
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+continuous_action_space = False
+K = 2 # num_of_sampled_actions
+num_simulations = 50
+update_per_collect = 100
+batch_size = 256
+max_env_step = int(5e5)
+reanalyze_ratio = 0.
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+bsuite_sampled_efficientzero_config = dict(
+ exp_name=
+    f'data_sez_ctree/bsuite_{env_name}_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}',
+ env=dict(
+ env_name=env_name,
+ stop_value=int(1e6),
+ continuous=False,
+ manually_discretization=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=observation_shape,
+ action_space_size=action_space_size,
+ continuous_action_space=continuous_action_space,
+ num_of_sampled_actions=K,
+ model_type='mlp',
+ lstm_hidden_size=128,
+ latent_state_dim=128,
+ discrete_action_encoding_type='one_hot',
+ norm_type='BN',
+ ),
+ cuda=True,
+ env_type='not_board_games',
+ game_segment_length=50,
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ num_simulations=num_simulations,
+ reanalyze_ratio=reanalyze_ratio,
+ n_episode=n_episode,
+ eval_freq=int(2e2),
+        replay_buffer_size=int(1e6),  # the size/capacity of the replay buffer, in terms of transitions.
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ ),
+)
+
+bsuite_sampled_efficientzero_config = EasyDict(bsuite_sampled_efficientzero_config)
+main_config = bsuite_sampled_efficientzero_config
+
+bsuite_sampled_efficientzero_create_config = dict(
+ env=dict(
+ type='bsuite_lightzero',
+ import_names=['zoo.bsuite.envs.bsuite_lightzero_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sampled_efficientzero',
+ import_names=['lzero.policy.sampled_efficientzero'],
+ ),
+ collector=dict(
+ type='episode_muzero',
+ import_names=['lzero.worker.muzero_collector'],
+ )
+)
+bsuite_sampled_efficientzero_create_config = EasyDict(bsuite_sampled_efficientzero_create_config)
+create_config = bsuite_sampled_efficientzero_create_config
+
+if __name__ == "__main__":
+ from lzero.entry import train_muzero
+ train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step)
diff --git a/zoo/bsuite/entry/bsuite_eval_config.py b/zoo/bsuite/entry/bsuite_eval_config.py
new file mode 100644
index 000000000..fdbb7d285
--- /dev/null
+++ b/zoo/bsuite/entry/bsuite_eval_config.py
@@ -0,0 +1,37 @@
+from bsuite_muzero_config import main_config, create_config
+from lzero.entry import eval_muzero
+import numpy as np
+
+if __name__ == "__main__":
+ """
+ model_path (:obj:`Optional[str]`): The pretrained model path, which should
+ point to the ckpt file of the pretrained model, and an absolute path is recommended.
+ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+ """
+ model_path = "./ckpt/ckpt_best.pth.tar"
+ seeds = [0]
+ num_episodes_each_seed = 5
+ main_config.env.evaluator_env_num = 1
+ main_config.env.n_evaluator_episode = 1
+ total_test_episodes = num_episodes_each_seed * len(seeds)
+ returns_mean_seeds = []
+ returns_seeds = []
+ for seed in seeds:
+ returns_mean, returns = eval_muzero(
+ [main_config, create_config],
+ seed=seed,
+ num_episodes_each_seed=num_episodes_each_seed,
+ print_seed_details=False,
+ model_path=model_path
+ )
+ returns_mean_seeds.append(returns_mean)
+ returns_seeds.append(returns)
+
+ returns_mean_seeds = np.array(returns_mean_seeds)
+ returns_seeds = np.array(returns_seeds)
+
+ print("=" * 20)
+    print(f'We evaluate a total of {len(seeds)} seeds. For each seed, we evaluate {num_episodes_each_seed} episodes.')
+    print(f'Across seeds {seeds}, returns_mean_seeds is {returns_mean_seeds} and returns is {returns_seeds}')
+    print('Across all seeds, the mean return is:', returns_mean_seeds.mean())
+ print("=" * 20)
diff --git a/zoo/bsuite/envs/bsuite_lightzero_env.py b/zoo/bsuite/envs/bsuite_lightzero_env.py
new file mode 100644
index 000000000..754c7a009
--- /dev/null
+++ b/zoo/bsuite/envs/bsuite_lightzero_env.py
@@ -0,0 +1,110 @@
+import copy
+from typing import List
+
+import bsuite
+import gym
+import numpy as np
+from bsuite import sweep
+from bsuite.utils import gym_wrapper
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+
+
+@ENV_REGISTRY.register('bsuite_lightzero')
+class BSuiteEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self.env_name = cfg.env_name
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ raw_env = bsuite.load_from_id(bsuite_id=self.env_name)
+ self._env = gym_wrapper.GymFromDMEnv(raw_env)
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1,), dtype=np.float64
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ obs = self._env.reset()
+ if obs.shape[0] == 1:
+ obs = obs[0]
+ obs = to_ndarray(obs).astype(np.float32)
+ self._eval_episode_return = 0
+
+ action_mask = np.ones(self.action_space.n, 'int8')
+ obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ if obs.shape[0] == 1:
+ obs = obs[0]
+ obs = to_ndarray(obs)
+        rew = to_ndarray([rew])  # wrapped to be transferred to an array with shape (1,)
+
+ action_mask = np.ones(self.action_space.n, 'int8')
+ obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def config_info(self) -> dict:
+        config_info = sweep.SETTINGS[self.env_name]  # additional info that is specific to each env configuration
+ config_info['num_episodes'] = self._env.bsuite_num_episodes
+ return config_info
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_env_num = cfg.pop('collector_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = True
+ return [cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_env_num = cfg.pop('evaluator_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = False
+ return [cfg for _ in range(evaluator_env_num)]
+
+ def __repr__(self) -> str:
+ return "LightZero BSuite Env({})".format(self.env_name)
diff --git a/zoo/bsuite/envs/check_bsuite_config.py b/zoo/bsuite/envs/check_bsuite_config.py
new file mode 100644
index 000000000..7109b3984
--- /dev/null
+++ b/zoo/bsuite/envs/check_bsuite_config.py
@@ -0,0 +1,23 @@
+import bsuite
+from bsuite import sweep
+
+# List the configurations for the given experiment
+for bsuite_id in sweep.BANDIT_NOISE:
+ env = bsuite.load_from_id(bsuite_id)
+ print('bsuite_id={}, settings={}, num_episodes={}'
+ .format(bsuite_id, sweep.SETTINGS[bsuite_id], env.bsuite_num_episodes))
+
+for bsuite_id in sweep.CARTPOLE_SWINGUP:
+ env = bsuite.load_from_id(bsuite_id)
+ print('bsuite_id={}, settings={}, num_episodes={}'
+ .format(bsuite_id, sweep.SETTINGS[bsuite_id], env.bsuite_num_episodes))
+
+for bsuite_id in sweep.MEMORY_LEN:
+ env = bsuite.load_from_id(bsuite_id)
+ print('bsuite_id={}, settings={}, num_episodes={}'
+ .format(bsuite_id, sweep.SETTINGS[bsuite_id], env.bsuite_num_episodes))
+
+for bsuite_id in sweep.MEMORY_SIZE:
+ env = bsuite.load_from_id(bsuite_id)
+ print('bsuite_id={}, settings={}, num_episodes={}'
+ .format(bsuite_id, sweep.SETTINGS[bsuite_id], env.bsuite_num_episodes))
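
The hard-coded action_space_size and observation_shape values in the bsuite configs above can be sanity-checked against the wrapped environment itself. The sketch below is an illustration (not part of the diff) that reuses the same bsuite/gym_wrapper calls as bsuite_lightzero_env.py; note the raw observation keeps a leading singleton dimension (e.g. (1, 3)), which the LightZero env squeezes away, hence observation_shape=3 in the configs.

import bsuite
from bsuite.utils import gym_wrapper

# Print the raw spaces for a few of the bsuite ids used in the configs.
for bsuite_id in ['memory_len/9', 'bandit_noise/0', 'memory_size/0']:
    env = gym_wrapper.GymFromDMEnv(bsuite.load_from_id(bsuite_id))
    print(bsuite_id, env.observation_space.shape, env.action_space.n)
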
diff --git a/zoo/minigrid/__init__.py b/zoo/minigrid/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/zoo/minigrid/config/__init__.py b/zoo/minigrid/config/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/zoo/minigrid/config/minigrd_sampled_efficientzero_config.py b/zoo/minigrid/config/minigrd_sampled_efficientzero_config.py
new file mode 100644
index 000000000..f9312afd0
--- /dev/null
+++ b/zoo/minigrid/config/minigrd_sampled_efficientzero_config.py
@@ -0,0 +1,114 @@
+from easydict import EasyDict
+
+# Typical MiniGrid env ids: {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0', 'MiniGrid-DoorKey-16x16-v0'};
+# please refer to https://github.com/Farama-Foundation/MiniGrid for details.
+env_name = 'MiniGrid-Empty-8x8-v0'
+max_env_step = int(1e6)
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+seed = 0
+
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+continuous_action_space = False
+K = 5 # num_of_sampled_actions
+num_simulations = 50
+update_per_collect = 200
+batch_size = 256
+
+reanalyze_ratio = 0
+random_collect_episode_num = 0
+td_steps = 5
+policy_entropy_loss_weight = 0.
+threshold_training_steps_for_final_temperature = int(5e5)
+eps_greedy_exploration_in_collect = False
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+minigrid_sampled_efficientzero_config = dict(
+ exp_name=f'data_sez_ctree/{env_name}_sampled_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}',
+ env=dict(
+ env_name=env_name,
+ continuous=False,
+ manually_discretization=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=2835,
+ action_space_size=7,
+ continuous_action_space=continuous_action_space,
+ num_of_sampled_actions=K,
+ model_type='mlp',
+ lstm_hidden_size=256,
+ latent_state_dim=256,
+ discrete_action_encoding_type='one_hot',
+ norm_type='BN',
+ ),
+ policy_entropy_loss_weight=policy_entropy_loss_weight,
+ eps=dict(
+ eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect,
+ decay=int(2e5),
+ ),
+ td_steps=td_steps,
+ manual_temperature_decay=True,
+ threshold_training_steps_for_final_temperature=threshold_training_steps_for_final_temperature,
+ cuda=True,
+ env_type='not_board_games',
+ game_segment_length=50,
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ num_simulations=num_simulations,
+ reanalyze_ratio=reanalyze_ratio,
+ n_episode=n_episode,
+ eval_freq=int(2e2),
+        replay_buffer_size=int(1e6),  # the size/capacity of the replay buffer, in terms of transitions.
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ ),
+)
+
+minigrid_sampled_efficientzero_config = EasyDict(minigrid_sampled_efficientzero_config)
+main_config = minigrid_sampled_efficientzero_config
+
+minigrid_sampled_efficientzero_create_config = dict(
+ env=dict(
+ type='minigrid_lightzero',
+ import_names=['zoo.minigrid.envs.minigrid_lightzero_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sampled_efficientzero',
+ import_names=['lzero.policy.sampled_efficientzero'],
+ ),
+ collector=dict(
+ type='episode_muzero',
+ import_names=['lzero.worker.muzero_collector'],
+ )
+)
+minigrid_sampled_efficientzero_create_config = EasyDict(minigrid_sampled_efficientzero_create_config)
+create_config = minigrid_sampled_efficientzero_create_config
+
+if __name__ == "__main__":
+ # Users can use different train entry by specifying the entry_type.
+ entry_type = "train_muzero" # options={"train_muzero", "train_muzero_with_gym_env"}
+
+ if entry_type == "train_muzero":
+ from lzero.entry import train_muzero
+ elif entry_type == "train_muzero_with_gym_env":
+ """
+ The ``train_muzero_with_gym_env`` entry means that the environment used in the training process is generated by wrapping the original gym environment with LightZeroEnvWrapper.
+ Users can refer to lzero/envs/wrappers for more details.
+ """
+ from lzero.entry import train_muzero_with_gym_env as train_muzero
+
+ train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step)
diff --git a/zoo/minigrid/config/minigrid_efficientzero_config.py b/zoo/minigrid/config/minigrid_efficientzero_config.py
new file mode 100644
index 000000000..4f48550e8
--- /dev/null
+++ b/zoo/minigrid/config/minigrid_efficientzero_config.py
@@ -0,0 +1,98 @@
+from easydict import EasyDict
+
+# Typical MiniGrid env ids: {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0', 'MiniGrid-DoorKey-16x16-v0'};
+# please refer to https://github.com/Farama-Foundation/MiniGrid for details.
+env_name = 'MiniGrid-Empty-8x8-v0'
+max_env_step = int(1e6)
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+seed = 0
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 50
+update_per_collect = 200
+batch_size = 256
+reanalyze_ratio = 0
+td_steps = 5
+policy_entropy_loss_weight = 0.
+threshold_training_steps_for_final_temperature = int(5e5)
+eps_greedy_exploration_in_collect = False
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+minigrid_efficientzero_config = dict(
+ exp_name=f'data_ez_ctree/{env_name}_efficientzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_seed{seed}',
+ env=dict(
+ stop_value=int(1e6),
+ env_name=env_name,
+ continuous=False,
+ manually_discretization=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=2835,
+ action_space_size=7,
+ model_type='mlp',
+ lstm_hidden_size=256,
+ latent_state_dim=256,
+ discrete_action_encoding_type='one_hot',
+ norm_type='BN',
+ ),
+ policy_entropy_loss_weight=policy_entropy_loss_weight,
+ eps=dict(
+ eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect,
+ decay=int(2e5),
+ ),
+ td_steps=td_steps,
+ manual_temperature_decay=True,
+ threshold_training_steps_for_final_temperature=threshold_training_steps_for_final_temperature,
+ cuda=True,
+ env_type='not_board_games',
+ game_segment_length=50,
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ num_simulations=num_simulations,
+ reanalyze_ratio=reanalyze_ratio,
+ n_episode=n_episode,
+ eval_freq=int(2e2),
+        replay_buffer_size=int(1e6),  # the size/capacity of the replay buffer, in terms of transitions.
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ ),
+)
+
+minigrid_efficientzero_config = EasyDict(minigrid_efficientzero_config)
+main_config = minigrid_efficientzero_config
+
+minigrid_efficientzero_create_config = dict(
+ env=dict(
+ type='minigrid_lightzero',
+ import_names=['zoo.minigrid.envs.minigrid_lightzero_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='efficientzero',
+ import_names=['lzero.policy.efficientzero'],
+ ),
+ collector=dict(
+ type='episode_muzero',
+ import_names=['lzero.worker.muzero_collector'],
+ )
+)
+minigrid_efficientzero_create_config = EasyDict(minigrid_efficientzero_create_config)
+create_config = minigrid_efficientzero_create_config
+
+if __name__ == "__main__":
+ from lzero.entry import train_muzero
+ train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step)
diff --git a/zoo/minigrid/config/minigrid_muzero_config.py b/zoo/minigrid/config/minigrid_muzero_config.py
new file mode 100644
index 000000000..3a1a7ec28
--- /dev/null
+++ b/zoo/minigrid/config/minigrid_muzero_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+# Typical MiniGrid env ids: {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0', 'MiniGrid-DoorKey-16x16-v0'};
+# please refer to https://github.com/Farama-Foundation/MiniGrid for details.
+env_name = 'MiniGrid-Empty-8x8-v0'
+max_env_step = int(1e6)
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+seed = 0
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 50
+update_per_collect = 200
+batch_size = 256
+reanalyze_ratio = 0
+td_steps = 5
+policy_entropy_loss_weight = 0. # 0.005
+threshold_training_steps_for_final_temperature = int(5e5)
+eps_greedy_exploration_in_collect = False
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+minigrid_muzero_config = dict(
+ exp_name=f'data_mz_ctree/{env_name}_muzero_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}_'
+ f'collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}_pelw{policy_entropy_loss_weight}_seed{seed}',
+ env=dict(
+ stop_value=int(1e6),
+ env_name=env_name,
+ continuous=False,
+ manually_discretization=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=2835,
+ action_space_size=7,
+ model_type='mlp',
+ lstm_hidden_size=256,
+ latent_state_dim=512,
+ discrete_action_encoding_type='one_hot',
+ norm_type='BN',
+ self_supervised_learning_loss=True, # NOTE: default is False.
+ ),
+ eps=dict(
+ eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect,
+ decay=int(2e5),
+ ),
+ policy_entropy_loss_weight=policy_entropy_loss_weight,
+ td_steps=td_steps,
+ manual_temperature_decay=True,
+ threshold_training_steps_for_final_temperature=threshold_training_steps_for_final_temperature,
+ cuda=True,
+ env_type='not_board_games',
+ game_segment_length=50,
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ ssl_loss_weight=2, # NOTE: default is 0.
+ num_simulations=num_simulations,
+ reanalyze_ratio=reanalyze_ratio,
+ n_episode=n_episode,
+ eval_freq=int(2e2),
+        replay_buffer_size=int(1e6),  # the size/capacity of the replay buffer, in terms of transitions.
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ ),
+)
+
+minigrid_muzero_config = EasyDict(minigrid_muzero_config)
+main_config = minigrid_muzero_config
+
+minigrid_muzero_create_config = dict(
+ env=dict(
+ type='minigrid_lightzero',
+ import_names=['zoo.minigrid.envs.minigrid_lightzero_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='muzero',
+ import_names=['lzero.policy.muzero'],
+ ),
+ collector=dict(
+ type='episode_muzero',
+ import_names=['lzero.worker.muzero_collector'],
+ )
+)
+minigrid_muzero_create_config = EasyDict(minigrid_muzero_create_config)
+create_config = minigrid_muzero_create_config
+
+if __name__ == "__main__":
+ from lzero.entry import train_muzero
+ train_muzero([main_config, create_config], seed=seed, max_env_step=max_env_step)
diff --git a/zoo/minigrid/config/minigrid_muzero_rnd_config.py b/zoo/minigrid/config/minigrid_muzero_rnd_config.py
new file mode 100644
index 000000000..eac8abe16
--- /dev/null
+++ b/zoo/minigrid/config/minigrid_muzero_rnd_config.py
@@ -0,0 +1,137 @@
+from easydict import EasyDict
+
+# Typical MiniGrid env ids: {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0', 'MiniGrid-DoorKey-16x16-v0'};
+# please refer to https://github.com/Farama-Foundation/MiniGrid for details.
+env_name = 'MiniGrid-Empty-8x8-v0'
+max_env_step = int(1e6)
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+seed = 0
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 50
+update_per_collect = 200
+batch_size = 256
+reanalyze_ratio = 0
+td_steps = 5
+
+# key exploration related config
+policy_entropy_loss_weight = 0.
+threshold_training_steps_for_final_temperature = int(5e5)
+eps_greedy_exploration_in_collect = True
+input_type = 'obs' # options=['obs', 'latent_state', 'obs_latent_state']
+target_model_for_intrinsic_reward_update_type = 'assign' # 'assign' or 'momentum'
+
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+minigrid_muzero_rnd_config = dict(
+ exp_name=f'data_mz_rnd_ctree/{env_name}_muzero-rnd_ns{num_simulations}_upc{update_per_collect}_rr{reanalyze_ratio}'
+ f'_collect-eps-{eps_greedy_exploration_in_collect}_temp-final-steps-{threshold_training_steps_for_final_temperature}_pelw{policy_entropy_loss_weight}'
+ f'_rnd-rew-{input_type}-{target_model_for_intrinsic_reward_update_type}_seed{seed}',
+ env=dict(
+ stop_value=int(1e6),
+ env_name=env_name,
+ continuous=False,
+ manually_discretization=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, ),
+ ),
+ reward_model=dict(
+ type='rnd_muzero',
+ intrinsic_reward_type='add',
+ input_type=input_type, # options=['obs', 'latent_state', 'obs_latent_state']
+ # intrinsic_reward_weight means the relative weight of RND intrinsic_reward.
+        # Specifically, MiniGrid is a sparse-reward env: the agent gets a reward of ~1 only when it reaches the goal, and 0 otherwise.
+        # We could set intrinsic_reward_weight approximately equal to the inverse of max_episode_steps. Please refer to rnd_reward_model for details.
+ intrinsic_reward_weight=0.003, # 1/300
+ obs_shape=2835,
+ latent_state_dim=512,
+ hidden_size_list=[256, 256],
+ learning_rate=3e-3,
+ weight_decay=1e-4,
+ batch_size=batch_size,
+ update_per_collect=200,
+ rnd_buffer_size=int(1e6),
+ input_norm=True,
+ input_norm_clamp_max=5,
+ input_norm_clamp_min=-5,
+ extrinsic_reward_norm=True,
+ extrinsic_reward_norm_max=1,
+ ),
+ policy=dict(
+ model=dict(
+ observation_shape=2835,
+ action_space_size=7,
+ model_type='mlp',
+ lstm_hidden_size=256,
+ latent_state_dim=512,
+ discrete_action_encoding_type='one_hot',
+ norm_type='BN',
+ self_supervised_learning_loss=True, # NOTE: default is False.
+ ),
+ use_rnd_model=True,
+ # RND related config
+ use_momentum_representation_network=True,
+ target_model_for_intrinsic_reward_update_type=target_model_for_intrinsic_reward_update_type,
+ target_update_freq_for_intrinsic_reward=1000,
+ target_update_theta_for_intrinsic_reward=0.005,
+ # key exploration related config
+ policy_entropy_loss_weight=policy_entropy_loss_weight,
+ eps=dict(
+ eps_greedy_exploration_in_collect=eps_greedy_exploration_in_collect,
+ decay=int(2e5),
+ ),
+ manual_temperature_decay=True,
+ threshold_training_steps_for_final_temperature=threshold_training_steps_for_final_temperature,
+
+ cuda=True,
+ env_type='not_board_games',
+ game_segment_length=300,
+ update_per_collect=update_per_collect,
+ batch_size=batch_size,
+ optim_type='Adam',
+ lr_piecewise_constant_decay=False,
+ learning_rate=0.003,
+ ssl_loss_weight=2, # NOTE: default is 0.
+ td_steps=td_steps,
+ num_simulations=num_simulations,
+ reanalyze_ratio=reanalyze_ratio,
+ n_episode=n_episode,
+ eval_freq=int(2e2),
+        replay_buffer_size=int(1e6),  # the size/capacity of the replay buffer, in terms of transitions.
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ ),
+)
+
+minigrid_muzero_rnd_config = EasyDict(minigrid_muzero_rnd_config)
+main_config = minigrid_muzero_rnd_config
+
+minigrid_muzero_create_config = dict(
+ env=dict(
+ type='minigrid_lightzero',
+ import_names=['zoo.minigrid.envs.minigrid_lightzero_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='muzero',
+ import_names=['lzero.policy.muzero'],
+ ),
+ collector=dict(
+ type='episode_muzero',
+ import_names=['lzero.worker.muzero_collector'],
+ )
+)
+minigrid_muzero_create_config = EasyDict(minigrid_muzero_create_config)
+create_config = minigrid_muzero_create_config
+
+if __name__ == "__main__":
+ from lzero.entry import train_muzero_with_reward_model
+ train_muzero_with_reward_model([main_config, create_config], seed=seed, max_env_step=max_env_step)
\ No newline at end of file
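
The comment on intrinsic_reward_weight above proposes scaling the RND bonus by roughly 1 / max_episode_steps so that, summed over a full episode, it stays comparable to the ~1 sparse goal reward. A back-of-the-envelope check of that heuristic follows; the combination rule r_total = r_extrinsic + weight * r_intrinsic is only an assumption about what intrinsic_reward_type='add' does, not something stated in this diff.

# Rough check of the 1 / max_episode_steps heuristic for the RND bonus weight.
max_episode_steps = 300
intrinsic_reward_weight = 1 / max_episode_steps  # ~0.003, matching the config above
per_step_intrinsic_reward = 1.0  # assume a normalized RND bonus of order 1 per step
episode_bonus = max_episode_steps * intrinsic_reward_weight * per_step_intrinsic_reward
print(episode_bonus)  # ~1.0, i.e. on the same scale as the sparse extrinsic goal reward
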
diff --git a/zoo/minigrid/entry/__init__.py b/zoo/minigrid/entry/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/zoo/minigrid/entry/minigrid_eval.py b/zoo/minigrid/entry/minigrid_eval.py
new file mode 100644
index 000000000..4f1b13141
--- /dev/null
+++ b/zoo/minigrid/entry/minigrid_eval.py
@@ -0,0 +1,37 @@
+from zoo.minigrid.config.minigrid_muzero_config import main_config, create_config
+from lzero.entry import eval_muzero
+import numpy as np
+
+if __name__ == "__main__":
+ """
+ model_path (:obj:`Optional[str]`): The pretrained model path, which should
+ point to the ckpt file of the pretrained model, and an absolute path is recommended.
+ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+ """
+ model_path = "./ckpt/ckpt_best.pth.tar"
+ seeds = [0]
+ num_episodes_each_seed = 5
+ main_config.env.evaluator_env_num = 1
+ main_config.env.n_evaluator_episode = 1
+ total_test_episodes = num_episodes_each_seed * len(seeds)
+ returns_mean_seeds = []
+ returns_seeds = []
+ for seed in seeds:
+ returns_mean, returns = eval_muzero(
+ [main_config, create_config],
+ seed=seed,
+ num_episodes_each_seed=num_episodes_each_seed,
+ print_seed_details=False,
+ model_path=model_path
+ )
+ returns_mean_seeds.append(returns_mean)
+ returns_seeds.append(returns)
+
+ returns_mean_seeds = np.array(returns_mean_seeds)
+ returns_seeds = np.array(returns_seeds)
+
+ print("=" * 20)
+    print(f'We evaluate a total of {len(seeds)} seeds. For each seed, we evaluate {num_episodes_each_seed} episodes.')
+    print(f'Across seeds {seeds}, returns_mean_seeds is {returns_mean_seeds} and returns is {returns_seeds}')
+    print('Across all seeds, the mean return is:', returns_mean_seeds.mean())
+ print("=" * 20)
diff --git a/zoo/minigrid/envs/minigrid_lightzero_env.py b/zoo/minigrid/envs/minigrid_lightzero_env.py
new file mode 100644
index 000000000..24ad5e686
--- /dev/null
+++ b/zoo/minigrid/envs/minigrid_lightzero_env.py
@@ -0,0 +1,216 @@
+import copy
+import os
+from typing import List, Optional
+
+import gymnasium as gym
+import matplotlib.pyplot as plt
+import numpy as np
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs import ObsPlusPrevActRewWrapper
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+from dizoo.minigrid.envs.minigrid_wrapper import ViewSizeWrapper
+from dizoo.minigrid.envs.minigrid_env import MiniGridEnv
+from easydict import EasyDict
+from matplotlib import animation
+from minigrid.wrappers import FlatObsWrapper
+
+
+@ENV_REGISTRY.register('minigrid_lightzero')
+class MiniGridEnvLightZero(MiniGridEnv):
+ """
+ Overview:
+        A MiniGrid environment for LightZero, based on Gymnasium.
+ Attributes:
+ config (dict): Configuration dict. Default configurations can be updated using this.
+ _cfg (dict): Internal configuration dict that stores runtime configurations.
+ _init_flag (bool): Flag to check if the environment is initialized.
+ _env_name (str): The name of the MiniGrid environment.
+ _flat_obs (bool): Flag to check if flat observations are returned.
+ _save_replay (bool): Flag to check if replays are saved.
+ _max_step (int): Maximum number of steps for the environment.
+ """
+ config = dict(
+ env_name='MiniGrid-Empty-8x8-v0',
+ flat_obs=True,
+ max_step=300,
+ )
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ """
+ Overview:
+ Returns the default configuration with the current environment class name.
+ Returns:
+ - cfg (:obj:`dict`): Configuration dict.
+ """
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ def __init__(self, cfg: dict) -> None:
+ """
+ Overview:
+ Initialize the environment.
+ Arguments:
+ - cfg (:obj:`dict`): Configuration dict. The configuration should include the environment name,
+ whether to use flat observations, and the maximum number of steps.
+ """
+ self._cfg = cfg
+ self._init_flag = False
+ self._env_name = cfg.env_name
+ self._flat_obs = cfg.flat_obs
+ self._save_replay = False
+ self._max_step = cfg.max_step
+
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Reset the environment and return the initial observation.
+ Returns:
+ - obs (:obj:`np.ndarray`): Initial observation from the environment.
+ """
+ if not self._init_flag:
+ if self._save_replay:
+ self._env = gym.make(self._env_name, render_mode="rgb_array")
+ else:
+ self._env = gym.make(self._env_name)
+ # NOTE: customize the max step of the env
+ self._env.max_steps = self._max_step
+
+            if self._env_name in ['MiniGrid-AKTDT-13x13-v0', 'MiniGrid-AKTDT-13x13-1-v0']:
+                # customize the agent's field-of-view size; note this must be an odd number
+                # This is also related to the observation space; see gym_minigrid.wrappers for more details
+ self._env = ViewSizeWrapper(self._env, agent_view_size=5)
+ if self._env_name == 'MiniGrid-AKTDT-7x7-1-v0':
+ self._env = ViewSizeWrapper(self._env, agent_view_size=3)
+ if self._flat_obs:
+ self._env = FlatObsWrapper(self._env)
+ # self._env = ImgObsWrapper(self._env)
+ # self._env = RGBImgPartialObsWrapper(self._env)
+ if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
+ self._env = ObsPlusPrevActRewWrapper(self._env)
+ self._init_flag = True
+ if self._flat_obs:
+ self._observation_space = gym.spaces.Box(0, 1, shape=(2835, ))
+ else:
+ self._observation_space = self._env.observation_space
+ # to be compatible with subprocess env manager
+ if isinstance(self._observation_space, gym.spaces.Dict):
+ self._observation_space['obs'].dtype = np.dtype('float32')
+ else:
+ self._observation_space.dtype = np.dtype('float32')
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._seed = self._seed + np_seed
+ obs, _ = self._env.reset(seed=self._seed) # using the reset method of Gymnasium env
+ elif hasattr(self, '_seed'):
+ obs, _ = self._env.reset(seed=self._seed)
+ else:
+ obs, _ = self._env.reset()
+ obs = to_ndarray(obs)
+ self._eval_episode_return = 0
+ self._current_step = 0
+ if self._save_replay:
+ self._frames = []
+
+ action_mask = np.ones(self.action_space.n, 'int8')
+ obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ if isinstance(action, np.ndarray) and action.shape == (1, ):
+ action = action.squeeze() # 0-dim array
+ if self._save_replay:
+ self._frames.append(self._env.render())
+ # using the step method of Gymnasium env, return is (observation, reward, terminated, truncated, info)
+ obs, rew, done, _, info = self._env.step(action)
+ rew = float(rew)
+ self._eval_episode_return += rew
+ self._current_step += 1
+ if self._current_step >= self._max_step:
+ done = True
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ info['current_step'] = self._current_step
+ info['max_step'] = self._max_step
+ if self._save_replay:
+ path = os.path.join(
+ self._replay_path, '{}_episode_{}.gif'.format(self._env_name, self._save_replay_count)
+ )
+ self.display_frames_as_gif(self._frames, path)
+ self._save_replay_count += 1
+ obs = to_ndarray(obs)
+ rew = to_ndarray([rew]) # wrapped to be transferred to an array with shape (1,)
+
+ action_mask = np.ones(self.action_space.n, 'int8')
+ obs = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._save_replay = True
+ self._replay_path = replay_path
+ self._save_replay_count = 0
+
+ @staticmethod
+ def display_frames_as_gif(frames: list, path: str) -> None:
+ patch = plt.imshow(frames[0])
+ plt.axis('off')
+
+ def animate(i):
+ patch.set_data(frames[i])
+
+ anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)
+ anim.save(path, writer='imagemagick', fps=20)
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_env_num = cfg.pop('collector_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = True
+ return [cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_env_num = cfg.pop('evaluator_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = False
+ return [cfg for _ in range(evaluator_env_num)]
+
+ def __repr__(self) -> str:
+ return "LightZero MiniGrid Env({})".format(self._cfg.env_name)
\ No newline at end of file
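
A minimal usage sketch for the new MiniGridEnvLightZero class, assuming the gymnasium/minigrid/ding/dizoo dependencies imported above are installed; this illustration is not part of the diff.

from zoo.minigrid.envs.minigrid_lightzero_env import MiniGridEnvLightZero

# Build a config from the class defaults and run one random episode.
cfg = MiniGridEnvLightZero.default_config()
cfg.env_name = 'MiniGrid-Empty-8x8-v0'
env = MiniGridEnvLightZero(cfg)
env.seed(0)
obs = env.reset()
print(obs['observation'].shape, obs['to_play'])  # flat observation, single-player marker

done = False
while not done:
    timestep = env.step(env.random_action())
    done = timestep.done
print('episode return:', timestep.info['eval_episode_return'])
env.close()
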