feature(pu): add minigrid/bsuite env and config, add muzero rnd algo. #110

Merged
merged 3 commits on Oct 26, 2023
1 change: 1 addition & 0 deletions lzero/entry/__init__.py
@@ -1,6 +1,7 @@
from .train_alphazero import train_alphazero
from .eval_alphazero import eval_alphazero
from .train_muzero import train_muzero
from .train_muzero_with_reward_model import train_muzero_with_reward_model
from .eval_muzero import eval_muzero
from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
4 changes: 2 additions & 2 deletions lzero/entry/train_muzero.py
@@ -122,8 +122,8 @@ def train_muzero(
update_per_collect = cfg.policy.update_per_collect

# The purpose of collecting random data before training:
# Exploration: The collection of random data aids the agent in exploring the environment and prevents premature convergence to a suboptimal policy.
# Comparation: The agent's performance during random action-taking can be used as a reference point to evaluate the efficacy of reinforcement learning algorithms.
# Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
# Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)

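As an aside (illustrative only, not code from this PR or from LightZero): a self-contained toy sketch of what a random warm-up phase buys you. The helper name and the `env_step` callable below are hypothetical stand-ins.

import random
from typing import Callable, List, Tuple

def random_warmup(env_step: Callable[[int], Tuple[int, float, bool]],
                  n_actions: int, n_episodes: int) -> Tuple[List[tuple], float]:
    # env_step: toy step function returning (next_obs, reward, done); assumed to
    # reset itself when an episode ends. Uniform random actions both seed the
    # buffer with diverse transitions (exploration) and yield a return that a
    # learned policy should later beat (comparison baseline).
    buffer, returns = [], []
    for _ in range(n_episodes):
        done, episode_return, obs = False, 0.0, 0
        while not done:
            action = random.randrange(n_actions)       # uniform random action
            next_obs, reward, done = env_step(action)
            buffer.append((obs, action, reward, next_obs, done))
            episode_return += reward
            obs = next_obs
        returns.append(episode_return)
    baseline = sum(returns) / len(returns)             # random-policy baseline return
    return buffer, baseline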
210 changes: 210 additions & 0 deletions lzero/entry/train_muzero_with_reward_model.py
@@ -0,0 +1,210 @@
import logging
import os
from functools import partial
from typing import Optional, Tuple

import torch
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import set_pkg_seed
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage, random_collect
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.reward_model.rnd_reward_model import RNDRewardModel
from lzero.worker import MuZeroCollector, MuZeroEvaluator


def train_muzero_with_reward_model(
input_cfg: Tuple[dict, dict],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
The train entry for MCTS+RL algorithms augmented with reward_model.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model; an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""

cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_rnd', 'sampled_efficientzero'], \
"train_muzero_with_reward_model entry now only supports the following algorithms: 'efficientzero', 'muzero', 'muzero_rnd', 'sampled_efficientzero'"

if create_cfg.policy.type in ['muzero', 'muzero_rnd']:
from lzero.mcts import MuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'efficientzero':
from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'sampled_efficientzero':
from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
else:
cfg.policy.device = 'cpu'

cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)

collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# load pretrained model
if model_path is not None:
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))

# Create worker components: learner, collector, evaluator, replay buffer.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

# ==============================================================
# MCTS+RL algorithms related core code
# ==============================================================
policy_config = cfg.policy
batch_size = policy_config.batch_size
# specific game buffer for MCTS+RL algorithms
replay_buffer = GameBuffer(policy_config)
collector = MuZeroCollector(
env=collector_env,
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
)
evaluator = MuZeroEvaluator(
eval_freq=cfg.policy.eval_freq,
n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value,
env=evaluator_env,
policy=policy.eval_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
)
# create reward_model
reward_model = RNDRewardModel(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger,
policy._learn_model.representation_network,
policy._target_model_for_intrinsic_reward.representation_network,
cfg.policy.use_momentum_representation_network
)

# ==============================================================
# Main loop
# ==============================================================
# Learner's before_run hook.
learner.call_hook('before_run')
if cfg.policy.update_per_collect is not None:
update_per_collect = cfg.policy.update_per_collect

# The purpose of collecting random data before training:
# Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
# Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)

while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
collect_kwargs = {}
# set temperature for visit count distributions according to the train_iter,
# please refer to Appendix D in MuZero paper for details.
collect_kwargs['temperature'] = visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter,
)

if policy_config.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(start=policy_config.eps.start, end=policy_config.eps.end,
decay=policy_config.eps.decay, type_=policy_config.eps.type)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
else:
collect_kwargs['epsilon'] = 0.0

# Evaluate policy performance.
if evaluator.should_eval(learner.train_iter):
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# ****** reward_model related code ******
# collect data for reward_model training
reward_model.collect_data(new_data)
# update reward_model
if reward_model.cfg.input_type == 'latent_state':
# train reward_model with latent_state
if len(reward_model.train_latent_state) > reward_model.cfg.batch_size:
reward_model.train_with_data()
elif reward_model.cfg.input_type in ['obs', 'latent_state']:
# train reward_model with obs
if len(reward_model.train_obs) > reward_model.cfg.batch_size:
reward_model.train_with_data()
# clear old data in reward_model
reward_model.clear_old_data()
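# Context note (not part of the original diff): RNDRewardModel is assumed to follow
# the standard Random Network Distillation formulation (Burda et al., 2018): a
# trainable predictor network is regressed onto a frozen, randomly initialized
# target network on the collected inputs (obs or latent_state here), and the
# per-sample prediction error ||predictor(x) - target(x)||^2 serves as an intrinsic
# reward, so rarely visited states receive a larger exploration bonus.
# reward_model.estimate(...) below is assumed to fold this bonus into the reward
# stored in train_data before the MuZero update.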

if cfg.policy.update_per_collect is None:
# If update_per_collect is None, set it to the number of collected transitions multiplied by model_update_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
# Save the new data returned by the collector into the replay buffer.
replay_buffer.push_game_segments(new_data)
# Remove the oldest data if the replay buffer is full.
replay_buffer.remove_oldest_data_to_fit()

# Learn policy from collected data.
for i in range(update_per_collect):
# Learner will train ``update_per_collect`` times in one iteration.
if replay_buffer.get_num_of_transitions() > batch_size:
train_data = replay_buffer.sample(batch_size, policy)
else:
logging.warning(
f'The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, '
f'{replay_buffer} '
f'continuing to collect data now ....'
)
break

# Augment the reward in train_data with the reward model's intrinsic reward.
train_data_augmented = reward_model.estimate(train_data)

# The core train steps for MCTS+RL algorithms.
log_vars = learner.train(train_data_augmented, collector.envstep)

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break

# Learner's after_run hook.
learner.call_hook('after_run')
return policy
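For orientation, a hedged usage sketch of the new entry. The config module path below is hypothetical; substitute the MiniGrid/BSuite RND config actually added in this PR. The config is expected to carry a `reward_model` section and a policy type accepted by the assert above (e.g. `muzero_rnd`).

from lzero.entry import train_muzero_with_reward_model
# Hypothetical config import; replace with the RND config shipped in this PR.
from zoo.minigrid.config.minigrid_muzero_rnd_config import main_config, create_config

if __name__ == "__main__":
    # (user_config, create_cfg) pair, as expected by the entry's input_cfg argument.
    train_muzero_with_reward_model(
        [main_config, create_config],
        seed=0,
        max_env_step=int(1e6),   # stop after 1M collected environment steps
    )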
45 changes: 42 additions & 3 deletions lzero/policy/muzero.py
@@ -8,6 +8,7 @@
from ding.policy.base_policy import Policy
from ding.torch_utils import to_tensor
from ding.utils import POLICY_REGISTRY
from torch.distributions import Categorical
from torch.nn import L1Loss

from lzero.mcts import MuZeroMCTSCtree as MCTSCtree
@@ -60,6 +61,8 @@ class MuZeroPolicy(Policy):
norm_type='BN',
),
# ****** common ******
# (bool) Whether to use the RND model.
use_rnd_model=False,
# (bool) Whether to use multi-gpu training.
multi_gpu=False,
# (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero)
@@ -116,6 +119,8 @@ class MuZeroPolicy(Policy):
learning_rate=0.2,
# (int) Frequency of target network update.
target_update_freq=100,
# (int) Frequency of the target network update used for computing the intrinsic (RND) reward.
target_update_freq_for_intrinsic_reward=1000,
# (float) Weight decay for training policy network.
weight_decay=1e-4,
# (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction).
@@ -138,6 +143,8 @@ class MuZeroPolicy(Policy):
value_loss_weight=0.25,
# (float) The weight of policy loss.
policy_loss_weight=1,
# (float) The weight of policy entropy loss.
policy_entropy_loss_weight=0,
# (float) The weight of ssl (self-supervised learning) loss.
ssl_loss_weight=0,
# (bool) Whether to use piecewise constant learning rate decay.
@@ -256,6 +263,22 @@ def _init_learn(self) -> None:
self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution
)

if self._cfg.use_rnd_model:
if self._cfg.target_model_for_intrinsic_reward_update_type == 'assign':
self._target_model_for_intrinsic_reward = model_wrap(
self._target_model,
wrapper_name='target',
update_type='assign',
update_kwargs={'freq': self._cfg.target_update_freq_for_intrinsic_reward}
)
elif self._cfg.target_model_for_intrinsic_reward_update_type == 'momentum':
self._target_model_for_intrinsic_reward = model_wrap(
self._target_model,
wrapper_name='target',
update_type='momentum',
update_kwargs={'theta': self._cfg.target_update_theta_for_intrinsic_reward}
)

def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
"""
Overview:
@@ -271,6 +294,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
"""
self._learn_model.train()
self._target_model.train()
if self._cfg.use_rnd_model:
self._target_model_for_intrinsic_reward.train()

current_batch, target_batch = data
obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch
@@ -340,6 +365,11 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
policy_loss = cross_entropy_loss(policy_logits, target_policy[:, 0])
value_loss = cross_entropy_loss(value, target_value_categorical[:, 0])

prob = torch.softmax(policy_logits, dim=-1)
dist = Categorical(prob)
policy_entropy_loss = -dist.entropy()


reward_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device)
consistency_loss = torch.zeros(self._cfg.batch_size, device=self._cfg.device)

@@ -383,6 +413,10 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
# ==============================================================
policy_loss += cross_entropy_loss(policy_logits, target_policy[:, step_k + 1])

prob = torch.softmax(policy_logits, dim=-1)
dist = Categorical(prob)
policy_entropy_loss += -dist.entropy()

value_loss += cross_entropy_loss(value, target_value_categorical[:, step_k + 1])
reward_loss += cross_entropy_loss(reward, target_reward_categorical[:, step_k])

@@ -406,7 +440,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
# weighted loss with masks (some invalid states which are out of trajectory.)
loss = (
self._cfg.ssl_loss_weight * consistency_loss + self._cfg.policy_loss_weight * policy_loss +
self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss
self._cfg.value_loss_weight * value_loss + self._cfg.reward_loss_weight * reward_loss +
self._cfg.policy_entropy_loss_weight * policy_entropy_loss
)
weighted_total_loss = (weights * loss).mean()

@@ -427,6 +462,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
# the core target model update step.
# ==============================================================
self._target_model.update(self._learn_model.state_dict())
if self._cfg.use_rnd_model:
self._target_model_for_intrinsic_reward.update(self._learn_model.state_dict())

if self._cfg.monitor_extra_statistics:
predicted_rewards = torch.stack(predicted_rewards).transpose(1, 0).squeeze(-1)
@@ -439,6 +476,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
'weighted_total_loss': weighted_total_loss.item(),
'total_loss': loss.mean().item(),
'policy_loss': policy_loss.mean().item(),
'policy_entropy': -policy_entropy_loss.mean().item() / (self._cfg.num_unroll_steps + 1),
'reward_loss': reward_loss.mean().item(),
'value_loss': value_loss.mean().item(),
'consistency_loss': consistency_loss.mean().item() / self._cfg.num_unroll_steps,
@@ -467,7 +505,7 @@ def _init_collect(self) -> None:
self._mcts_collect = MCTSCtree(self._cfg)
else:
self._mcts_collect = MCTSPtree(self._cfg)
self._collect_mcts_temperature = 1
self._collect_mcts_temperature = 1.
self.collect_epsilon = 0.0

def _forward_collect(
@@ -477,7 +515,7 @@ def _forward_collect(
temperature: float = 1,
to_play: List = [-1],
epsilon: float = 0.25,
ready_env_id=None
ready_env_id: List = None,
) -> Dict:
"""
Overview:
@@ -697,6 +735,7 @@ def _monitor_vars_learn(self) -> List[str]:
'weighted_total_loss',
'total_loss',
'policy_loss',
'policy_entropy',
'reward_loss',
'value_loss',
'consistency_loss',
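For reference, a self-contained illustration of the policy-entropy term introduced in this diff, using dummy tensors; it mirrors the accumulation in `_forward_learn` and the `policy_entropy` monitor variable.

import torch
from torch.distributions import Categorical

batch_size, action_space_size, num_unroll_steps = 8, 4, 5
# Accumulate the (negative) entropy over the initial step plus each unrolled step.
policy_entropy_loss = torch.zeros(batch_size)
for _ in range(num_unroll_steps + 1):
    policy_logits = torch.randn(batch_size, action_space_size)   # dummy network output
    dist = Categorical(torch.softmax(policy_logits, dim=-1))
    policy_entropy_loss += -dist.entropy()   # negative: minimizing the loss raises entropy
# Weighted into the total loss; policy_entropy_loss_weight defaults to 0 (term disabled).
policy_entropy_loss_weight = 0.0
loss_term = policy_entropy_loss_weight * policy_entropy_loss
# Logged as a positive per-step entropy, matching the 'policy_entropy' monitor variable.
policy_entropy = -policy_entropy_loss.mean().item() / (num_unroll_steps + 1)
print(f"mean policy entropy per step: {policy_entropy:.3f}")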