Commit

Merge branch 'main' of https://github.com/opendilab/LightZero into dev-cont-random-policy
puyuan1996 committed Oct 30, 2023
2 parents 4969d87 + 4de2a9e commit 4bbb309
Showing 37 changed files with 2,023 additions and 78 deletions.
63 changes: 63 additions & 0 deletions CONTRIBUTING.md
@@ -0,0 +1,63 @@
# 🚀 Welcome to LightZero! 🌟

We're thrilled that you want to contribute to LightZero. Your help is invaluable, and we appreciate your efforts to make this project even better. 😄

## 📝 How to Contribute

1. **Fork the Repository** 🍴
- Click on the "Fork" button at the top right of the [LightZero repository](https://github.com/opendilab/LightZero).

2. **Clone your Fork** 💻
- `git clone https://github.com/your-username/LightZero.git`

3. **Create a New Branch** 🌿
- `git checkout -b your-new-feature`

4. **Make Your Awesome Changes** 💥
- Add some cool features.
- Fix a bug.
- Improve the documentation.
- Anything that adds value!

5. **Commit Your Changes** 📦
- `git commit -m "Your descriptive commit message"`

6. **Push to Your Fork** 🚢
- `git push origin your-new-feature`

7. **Create a Pull Request** 🎉
- Go to the [LightZero repository](https://github.com/opendilab/LightZero).
- Click on "New Pull Request."
- Fill in the details and submit your PR.
- Please make sure your PR has a clear title and description.

8. **Review & Collaborate** 🤝
- Be prepared to answer questions or make changes to your PR as requested by the maintainers.

9. **Celebrate! 🎉** Your contribution has been added to LightZero.

## 📦 Reporting Issues

If you encounter a bug or have an idea for an improvement, please create an issue in the [Issues](https://github.com/opendilab/LightZero/issues) section. Make sure to include details about the problem and how to reproduce it.

## 🛠 Code Style and Guidelines

We follow a few simple guidelines, illustrated by the short sketch after this list:
- Keep your code clean and readable.
- Use meaningful variable and function names.
- Comment your code when necessary.
- Ensure your code adheres to existing coding styles and standards.
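
A minimal sketch of what these guidelines look like in practice; the function below is a hypothetical example used only for illustration and is not part of the LightZero codebase:

```python
from typing import List


def moving_average(values: List[float], window_size: int = 3) -> List[float]:
    """Return the simple moving average of ``values`` over ``window_size`` steps."""
    # Illustrative helper only: meaningful names, a short docstring, and a
    # comment where the intent is not obvious from the code itself.
    if window_size <= 0:
        raise ValueError("window_size must be a positive integer")
    averages = []
    for end in range(window_size, len(values) + 1):
        window = values[end - window_size:end]
        averages.append(sum(window) / window_size)
    return averages
```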

For detailed information on code style, unit testing, and code review, please refer to our documentation:

- [Code Style](https://di-engine-docs.readthedocs.io/en/latest/21_code_style/index.html)
- [Unit Test](https://di-engine-docs.readthedocs.io/en/latest/22_test/index.html)
- [Code Review](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/git_guide.html)

## 🤖 Code of Conduct

Please be kind and respectful when interacting with other contributors. We have a [Code of Conduct](LICENSE) to ensure a positive and welcoming environment for everyone.

## 🙌 Thank You! 🙏

Your contribution helps make LightZero even better. We appreciate your dedication to the project. Keep coding and stay awesome! 😃
9 changes: 6 additions & 3 deletions README.md
@@ -31,11 +31,14 @@ Updated on 2023.09.21 LightZero-v0.0.2

> LightZero is a lightweight, efficient, and easy-to-understand open-source algorithm toolkit that combines Monte Carlo Tree Search (MCTS) and Deep Reinforcement Learning (RL).
English | [简体中文](https://github.com/opendilab/LightZero/blob/main/README.zh.md) | [Paper](https://arxiv.org/pdf/2310.08348.pdf)
English | [简体中文(Simplified Chinese)](https://github.com/opendilab/LightZero/blob/main/README.zh.md) | [Paper](https://arxiv.org/pdf/2310.08348.pdf)

## Background

The method of combining Monte Carlo Tree Search and Deep Reinforcement Learning represented by AlphaZero and MuZero has achieved superhuman level in various games such as Go and Atari,and has also made gratifying progress in scientific fields such as protein structure prediction, matrix multiplication algorithm search, etc.
The integration of Monte Carlo Tree Search and Deep Reinforcement Learning,
exemplified by AlphaZero and MuZero,
has achieved unprecedented performance levels in various games, including Go and Atari.
This advanced methodology has also made significant strides in scientific domains like protein structure prediction and the search for matrix multiplication algorithms.
The following is an overview of the historical evolution of the Monte Carlo Tree Search algorithm series:
![pipeline](assets/mcts_rl_evolution_overview.png)

@@ -484,4 +487,4 @@ Special thanks to [@PaParaZz1](https://github.com/PaParaZz1), [@karroyan](https:
## License
All code within this repository is under [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
<p align="right">(<a href="#top">back to top</a>)</p>
<p align="right">(<a href="#top">Back to top</a>)</p>
5 changes: 5 additions & 0 deletions README.zh.md
@@ -1,3 +1,5 @@
<div id="top"></div>

# LightZero

<div align="center">
@@ -477,3 +479,6 @@ python3 -u zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py
## 许可证
本仓库中的所有代码都符合 [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0)。
<p align="right">(<a href="#top">回到顶部</a>)</p>
1 change: 1 addition & 0 deletions lzero/entry/__init__.py
@@ -1,6 +1,7 @@
from .train_alphazero import train_alphazero
from .eval_alphazero import eval_alphazero
from .train_muzero import train_muzero
from .train_muzero_with_reward_model import train_muzero_with_reward_model
from .eval_muzero import eval_muzero
from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
4 changes: 2 additions & 2 deletions lzero/entry/train_muzero.py
@@ -122,8 +122,8 @@ def train_muzero(
update_per_collect = cfg.policy.update_per_collect

# The purpose of collecting random data before training:
# Exploration: The collection of random data aids the agent in exploring the environment and prevents premature convergence to a suboptimal policy.
# Comparation: The agent's performance during random action-taking can be used as a reference point to evaluate the efficacy of reinforcement learning algorithms.
# Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
# Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)

210 changes: 210 additions & 0 deletions lzero/entry/train_muzero_with_reward_model.py
@@ -0,0 +1,210 @@
import logging
import os
from functools import partial
from typing import Optional, Tuple

import torch
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage, random_collect
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.reward_model.rnd_reward_model import RNDRewardModel
from lzero.worker import MuZeroCollector, MuZeroEvaluator


def train_muzero_with_reward_model(
input_cfg: Tuple[dict, dict],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
The train entry for MCTS+RL algorithms augmented with reward_model.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""

cfg, create_cfg = input_cfg
    assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_rnd', 'sampled_efficientzero'], \
        "train_muzero_with_reward_model entry now only supports the following algorithms: 'efficientzero', 'muzero', 'muzero_rnd', 'sampled_efficientzero'"

if create_cfg.policy.type in ['muzero', 'muzero_rnd']:
from lzero.mcts import MuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'efficientzero':
from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'sampled_efficientzero':
from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
else:
cfg.policy.device = 'cpu'

cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)

collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# load pretrained model
if model_path is not None:
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

# ==============================================================
# MCTS+RL algorithms related core code
# ==============================================================
policy_config = cfg.policy
batch_size = policy_config.batch_size
# specific game buffer for MCTS+RL algorithms
replay_buffer = GameBuffer(policy_config)
collector = MuZeroCollector(
env=collector_env,
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
)
evaluator = MuZeroEvaluator(
eval_freq=cfg.policy.eval_freq,
n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value,
env=evaluator_env,
policy=policy.eval_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
)
    # Create the RND reward model, which produces an intrinsic exploration reward from the prediction error of a fixed, randomly initialized target network.
reward_model = RNDRewardModel(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger,
policy._learn_model.representation_network,
policy._target_model_for_intrinsic_reward.representation_network,
cfg.policy.use_momentum_representation_network
)

# ==============================================================
# Main loop
# ==============================================================
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.update_per_collect is not None:
update_per_collect = cfg.policy.update_per_collect

# The purpose of collecting random data before training:
# Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
# Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)

while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
collect_kwargs = {}
# set temperature for visit count distributions according to the train_iter,
# please refer to Appendix D in MuZero paper for details.
collect_kwargs['temperature'] = visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter,
)

if policy_config.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(start=policy_config.eps.start, end=policy_config.eps.end,
decay=policy_config.eps.decay, type_=policy_config.eps.type)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
else:
collect_kwargs['epsilon'] = 0.0

# Evaluate policy performance.
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# ****** reward_model related code ******
# collect data for reward_model training
reward_model.collect_data(new_data)
# update reward_model
if reward_model.cfg.input_type == 'latent_state':
# train reward_model with latent_state
if len(reward_model.train_latent_state) > reward_model.cfg.batch_size:
reward_model.train_with_data()
        elif reward_model.cfg.input_type == 'obs':
# train reward_model with obs
if len(reward_model.train_obs) > reward_model.cfg.batch_size:
reward_model.train_with_data()
# clear old data in reward_model
reward_model.clear_old_data()

if cfg.policy.update_per_collect is None:
            # If update_per_collect is None, set it to the number of collected transitions multiplied by model_update_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
replay_buffer.remove_oldest_data_to_fit()

# Learn policy from collected data.
for i in range(update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration.
if replay_buffer.get_num_of_transitions() > batch_size:
train_data = replay_buffer.sample(batch_size, policy)
else:
logging.warning(
f'The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, '
f'{replay_buffer} '
f'continue to collect now ....'
)
break

# update train_data reward using the augmented reward
train_data_augmented = reward_model.estimate(train_data)

# The core train steps for MCTS+RL algorithms.
log_vars = learner.train(train_data_augmented, collector.envstep)

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break

# Learner's after_run hook.
learner.call_hook('after_run')
return policy
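
For orientation, a minimal usage sketch of this entry point, following the `[user_config, create_cfg]` pattern described in its docstring. The config import path below is an assumption made for illustration and is not part of this commit:

```python
# Hypothetical launch script; the config module path is assumed for illustration.
from lzero.entry import train_muzero_with_reward_model

if __name__ == "__main__":
    # main_config and create_config form the (user_config, create_cfg) pair
    # expected by ``input_cfg``; an RND-style experiment config is assumed here.
    from zoo.atari.config.atari_muzero_rnd_config import main_config, create_config  # hypothetical path

    train_muzero_with_reward_model(
        (main_config, create_config),
        seed=0,
        max_env_step=int(1e6),
    )
```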