TMP: polish(pu): polish unizero efficiency and tune atari100k performance #279

Closed · wants to merge 44 commits

Commits
95886bd
polish(pu): polish quantize_state_hash and deepcopy
PaParaZz1 Aug 18, 2024
0e49a30
fix(pu): fix np.array dtype bug in buffer
PaParaZz1 Aug 18, 2024
00147f4
polish(pu): use 0 deepcopy in kv_cache operation in collect/eval phas…
PaParaZz1 Aug 19, 2024
b40c71b
polish(pu): use custom deepcopy for kv_cache
PaParaZz1 Aug 22, 2024
2cc81be
polish(pu): use value_array rather than value_list in compute_target_…
PaParaZz1 Aug 22, 2024
bc5332f
polish(pu): optimize compute_target_policy_non_re
PaParaZz1 Aug 22, 2024
a6c6a8e
polish(pu): optimize kv_caching update()
PaParaZz1 Aug 22, 2024
b5dcdcc
polish(pu): kv_cache_dict no to_cpu
PaParaZz1 Aug 22, 2024
5b0cbd4
polish(pu): optimize custom kv_cache copy
PaParaZz1 Aug 22, 2024
0035829
polish(pu): kv_cache_dict no to_cpu
PaParaZz1 Aug 22, 2024
043727b
feature(pu): add unizero ddp config
PaParaZz1 Aug 23, 2024
d568008
fix(pu): fix unizero ddp
dyyoungg Aug 23, 2024
d349137
sync code
dyyoungg Aug 25, 2024
3a344aa
polish(pu): use kv_cache deepcopy only in recur_infer load
PaParaZz1 Aug 26, 2024
40053f7
Merge branch 'dev-efficiency' of https://github.com/opendilab/LightZe…
dyyoungg Aug 26, 2024
61a1139
sync code
dyyoungg Aug 26, 2024
bb38a10
sync code
dyyoungg Aug 26, 2024
b813be7
Merge branch 'dev-efficiency' of https://github.com/opendilab/LightZe…
dyyoungg Aug 27, 2024
715d17e
polish(pu): use share_pool for kv_cache in recurrent_inference and u…
jiayilee65 Aug 27, 2024
39d6bbe
polish(pu): all kv_cache copy use predefined share_pool
jiayilee65 Aug 27, 2024
f18be2a
polish(pu): do not use decoder_net and lpips in ddp config
dyyoungg Aug 28, 2024
1d010d3
sync code
dyyoungg Aug 28, 2024
fea98ee
feature(pu): add muzero_segment_collector.py
dyyoungg Sep 5, 2024
2a376ec
fix(pu): fix self.action_mask_dict init bug
dyyoungg Sep 5, 2024
d36196e
fix(pu): fix muzero_segment_collector
dyyoungg Sep 12, 2024
51e10f2
fix(pu): uz target-value obs also use aug when use_aug=True
dyyoungg Sep 12, 2024
8615899
fix(pu): fix last_game_segment bug in muzero_segment_collector.py
dyyoungg Sep 13, 2024
4c969f6
fix(pu): one episode done then return in muzero_segment_collector.py
dyyoungg Sep 13, 2024
cf2fd81
fix(pu): fix muzero_collector
dyyoungg Sep 14, 2024
bff16f7
polish(pu): polish unizero config and polish sample from segments
dyyoungg Sep 16, 2024
f0ff953
fix(pu): fix reanalyze in uz
dyyoungg Sep 17, 2024
91d48c1
polish(pu): add batch config and bash
dyyoungg Sep 17, 2024
1c8b92b
polish(pu): polish uz configs
dyyoungg Sep 20, 2024
1eed401
feature(pu): add unizero buffer_reanalyze variant
dyyoungg Sep 20, 2024
380f693
fix(pu): fix uz reanalyze_buffer
dyyoungg Sep 20, 2024
c43cdd4
polish(pu): polish configs
dyyoungg Sep 22, 2024
05a2ec3
feature(pu): add atari_muzero_segment_config
dyyoungg Sep 23, 2024
a78fa70
polish(pu): polish configs
dyyoungg Sep 24, 2024
f634e1f
polish(pu): polish configs
dyyoungg Sep 24, 2024
b7243ea
polish(pu): polish uz related configs, segment collector, train_entry
puyuan1996 Sep 26, 2024
dba9ca7
polish(pu): polish unizero world_model
puyuan1996 Sep 26, 2024
d5fff6d
polish(pu): polish reanalyze in buffer
puyuan1996 Sep 26, 2024
29197d2
fix(pu): fix entry import and nparray object bug in buffer
puyuan1996 Sep 26, 2024
eb268ac
polish(pu): polish configs
puyuan1996 Sep 26, 2024
2 changes: 2 additions & 0 deletions lzero/entry/__init__.py
@@ -3,8 +3,10 @@
from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
from .train_alphazero import train_alphazero
from .train_muzero import train_muzero
from .train_muzero_reanalyze import train_muzero_reanalyze
from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_muzero_with_reward_model import train_muzero_with_reward_model
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_reanalyze import train_unizero_reanalyze
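
With these exports in place, both new entries resolve from the package root, mirroring the existing train_muzero / train_unizero imports; a minimal sketch:

from lzero.entry import train_muzero_reanalyze, train_unizero_reanalyze
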
2 changes: 1 addition & 1 deletion lzero/entry/train_muzero.py
@@ -18,7 +18,7 @@
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
from .utils import random_collect, initialize_zeros_batch
from .utils import random_collect


def train_muzero(
256 changes: 256 additions & 0 deletions lzero/entry/train_muzero_reanalyze.py
@@ -0,0 +1,256 @@
import logging
import os
from functools import partial
from typing import Optional, Tuple

import torch
from ding.config import compile_config
from ding.envs import create_env_manager
from ding.envs import get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import EasyTimer
from ding.utils import set_pkg_seed, get_rank
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage, log_buffer_run_time
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroSegmentCollector as Collector
from .utils import random_collect

timer = EasyTimer()


def train_muzero_reanalyze(
input_cfg: Tuple[dict, dict],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy': # noqa
"""
Overview:
The train entry for MCTS+RL algorithms with periodic replay-buffer reanalyze, including MuZero, EfficientZero, Sampled EfficientZero, and Gumbel MuZero.
Arguments:
- input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
``Tuple[dict, dict]`` type means [user_config, create_cfg].
- seed (:obj:`int`): Random seed.
- model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
- model_path (:obj:`Optional[str]`): The pretrained model path, which should
point to the ckpt file of the pretrained model, and an absolute path is recommended.
In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
- max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
- max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
Returns:
- policy (:obj:`Policy`): Converged policy.
"""

cfg, create_cfg = input_cfg
assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_context', 'muzero_rnn_full_obs', 'sampled_efficientzero', 'sampled_muzero', 'gumbel_muzero', 'stochastic_muzero'], \
"train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'"

if create_cfg.policy.type in ['muzero', 'muzero_context', 'muzero_rnn_full_obs']:
from lzero.mcts import MuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'efficientzero':
from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'sampled_efficientzero':
from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'sampled_muzero':
from lzero.mcts import SampledMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'gumbel_muzero':
from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer
elif create_cfg.policy.type == 'stochastic_muzero':
from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer

if cfg.policy.cuda and torch.cuda.is_available():
cfg.policy.device = 'cuda'
else:
cfg.policy.device = 'cpu'

cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create main components: env, policy
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])

collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)

if cfg.policy.eval_offline:
cfg.policy.learn.learner.hook.save_ckpt_after_iter = cfg.policy.eval_freq

policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# load pretrained model
if model_path is not None:
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))

# Create worker components: learner, collector, evaluator, replay buffer, commander.
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

# ==============================================================
# MCTS+RL algorithms related core code
# ==============================================================
policy_config = cfg.policy
batch_size = policy_config.batch_size
# specific game buffer for MCTS+RL algorithms
replay_buffer = GameBuffer(policy_config)
collector = Collector(
env=collector_env,
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
)
evaluator = Evaluator(
eval_freq=cfg.policy.eval_freq,
n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value,
env=evaluator_env,
policy=policy.eval_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config
)

# ==============================================================
# Main loop
# ==============================================================
# Learner's before_run hook.
learner.call_hook('before_run')

if cfg.policy.update_per_collect is not None:
update_per_collect = cfg.policy.update_per_collect

# The purpose of collecting random data before training:
# Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
# Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
if cfg.policy.random_collect_episode_num > 0:
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
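# Note: random_collect is assumed to roll out cfg.policy.random_collect_episode_num episodes with
# LightZeroRandomPolicy and push the resulting game segments into replay_buffer before the learned
# policy takes over.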
if cfg.policy.eval_offline:
eval_train_iter_list = []
eval_train_envstep_list = []

# Evaluate the random agent
# stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)

buffer_reanalyze_count = 0
train_epoch = 0
reanalyze_batch_size = cfg.policy.reanalyze_batch_size

while True:
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
log_buffer_run_time(learner.train_iter, replay_buffer, tb_logger)
collect_kwargs = {}
# Set the temperature for visit count distributions according to train_iter;
# see Appendix D of the MuZero paper for details.
collect_kwargs['temperature'] = visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
)
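# Example: with manual_temperature_decay enabled, the schedule typically steps the temperature
# 1.0 -> 0.5 -> 0.25 as trained_steps crosses 50% / 75% of threshold_training_steps_for_final_temperature
# (cf. MuZero Appendix D); otherwise fixed_temperature_value is returned throughout training.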

if policy_config.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=policy_config.eps.start,
end=policy_config.eps.end,
decay=policy_config.eps.decay,
type_=policy_config.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
else:
collect_kwargs['epsilon'] = 0.0
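# Example: start=1.0, end=0.05, decay=int(1e5), type_='exp' yields an epsilon that decays from 1.0
# toward 0.05 as collector.envstep grows (assumed semantics of ding's get_epsilon_greedy_fn).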

# Evaluate policy performance.
if evaluator.should_eval(learner.train_iter):
if cfg.policy.eval_offline:
eval_train_iter_list.append(learner.train_iter)
eval_train_envstep_list.append(collector.envstep)
else:
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break

# Collect data by default config n_sample/n_episode.
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
if cfg.policy.update_per_collect is None:
# If update_per_collect is None, set it to the number of collected transitions multiplied by replay_ratio.
collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
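# Worked example: if 4 game segments of 100 transitions each were collected and replay_ratio=0.25,
# then update_per_collect = int(400 * 0.25) = 100 gradient steps for this collect round.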
# save returned new_data collected by the collector
replay_buffer.push_game_segments(new_data)
# remove the oldest data if the replay buffer is full.
replay_buffer.remove_oldest_data_to_fit()

# Periodically reanalyze buffer
if cfg.policy.buffer_reanalyze_freq >= 1:
# Reanalyze buffer <buffer_reanalyze_freq> times in one train_epoch
reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
else:
# Reanalyze the buffer once every <1/buffer_reanalyze_freq> train_epochs
if train_epoch % (1//cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions()//cfg.policy.num_unroll_steps > int(reanalyze_batch_size/cfg.policy.reanalyze_partition):
with timer:
# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
buffer_reanalyze_count += 1
logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
logging.info(f'Buffer reanalyze time: {timer.value}')
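# Example: buffer_reanalyze_freq=2 reanalyzes the buffer twice per collect round (every
# update_per_collect // 2 training steps, see the loop below), whereas buffer_reanalyze_freq=0.5
# reanalyzes once every 2 train epochs; in both cases only when the buffer holds more than
# reanalyze_batch_size / reanalyze_partition sequences of num_unroll_steps transitions.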

# Learn policy from collected data.
for i in range(update_per_collect):

if cfg.policy.buffer_reanalyze_freq >= 1:
# Reanalyze buffer <buffer_reanalyze_freq> times in one train_epoch
if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() // cfg.policy.num_unroll_steps > int(
reanalyze_batch_size / cfg.policy.reanalyze_partition):
# Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence)
replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
buffer_reanalyze_count += 1
logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')

# Learner will train ``update_per_collect`` times in one iteration.
if replay_buffer.get_num_of_transitions() > batch_size:
train_data = replay_buffer.sample(batch_size, policy)
else:
logging.warning(
f'The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, '
f'{replay_buffer} '
f'continue to collect now ....'
)
break

# The core train steps for MCTS+RL algorithms.
log_vars = learner.train(train_data, collector.envstep)

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

train_epoch += 1

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
if cfg.policy.eval_offline:
logging.info(f'eval offline beginning...')
ckpt_dirname = './{}/ckpt'.format(learner.exp_name)
# Evaluate the performance of the pretrained model.
for train_iter, collector_envstep in zip(eval_train_iter_list, eval_train_envstep_list):
ckpt_name = 'iteration_{}.pth.tar'.format(train_iter)
ckpt_path = os.path.join(ckpt_dirname, ckpt_name)
# load the ckpt of pretrained model
policy.learn_mode.load_state_dict(torch.load(ckpt_path, map_location=cfg.policy.device))
stop, reward = evaluator.eval(learner.save_checkpoint, train_iter, collector_envstep)
logging.info(
f'eval offline at train_iter: {train_iter}, collector_envstep: {collector_envstep}, reward: {reward}')
logging.info(f'eval offline finished!')
break

# Learner's after_run hook.
learner.call_hook('after_run')
return policy
2 changes: 1 addition & 1 deletion lzero/entry/train_unizero.py
@@ -17,8 +17,8 @@
from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.policy.random_policy import LightZeroRandomPolicy
from lzero.worker import MuZeroCollector as Collector
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroCollector as Collector
from .utils import random_collect

