
Commit 6abef12

Merge remote-tracking branch 'origin/dev-unizero-multitask-v2' into dev-multitask-v3

puyuan1996 committed Oct 16, 2024
2 parents b4ae014 + 2495d60

Showing 25 changed files with 3,911 additions and 72 deletions.
1 change: 1 addition & 0 deletions lzero/entry/__init__.py
@@ -9,3 +9,4 @@
from .train_rezero import train_rezero
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment
from .train_unizero_multitask import train_unizero_multitask
1 change: 0 additions & 1 deletion lzero/entry/train_unizero.py
@@ -176,7 +176,6 @@ def train_unizero(
# Clear caches and precompute positional embedding matrices
policy.recompute_pos_emb_diff_and_clear_cache() # TODO

train_data.append({'train_which_component': 'transformer'})
log_vars = learner.train(train_data, collector.envstep)

if cfg.policy.use_priority:
256 changes: 256 additions & 0 deletions lzero/entry/train_unizero_multitask.py
@@ -0,0 +1,256 @@
import logging
import os
from functools import partial
from typing import Tuple, Optional, List

import torch
import numpy as np
from ding.config import compile_config
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import set_pkg_seed, get_rank
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.policy import visit_count_temperature
from lzero.worker import MuZeroCollector as Collector, MuZeroEvaluator as Evaluator
from lzero.mcts import UniZeroGameBuffer as GameBuffer

from line_profiler import line_profiler

#@profile
def train_unizero_multitask(
input_cfg_list: List[Tuple[int, Tuple[dict, dict]]],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
model_path: Optional[str] = None,
max_train_iter: Optional[int] = int(1e10),
max_env_step: Optional[int] = int(1e10),
) -> 'Policy':
"""
Overview:
The train entry for UniZero, proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models.
UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms,
particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
Arguments:
- input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): List of configurations for different tasks.
- seed (int): Random seed.
- model (Optional[torch.nn.Module]): Instance of torch.nn.Module.
- model_path (Optional[str]): The pretrained model path, which should point to the ckpt file of the pretrained model.
- max_train_iter (Optional[int]): Maximum policy update iterations in training.
- max_env_step (Optional[int]): Maximum collected environment interaction steps.
Returns:
- policy (Policy): Converged policy.
"""
cfgs = []
game_buffers = []
collector_envs = []
evaluator_envs = []
collectors = []
evaluators = []

task_id, [cfg, create_cfg] = input_cfg_list[0]

# Ensure the specified policy type is supported
    assert create_cfg.policy.type in ['unizero_multitask'], "train_unizero_multitask entry now only supports the 'unizero_multitask' policy type"

# Set device based on CUDA availability
cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
logging.info(f'cfg.policy.device: {cfg.policy.device}')

# Compile the configuration
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
# Create shared policy for all tasks
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])

# Load pretrained model if specified
if model_path is not None:
logging.info(f'Loading model from {model_path} begin...')
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
logging.info(f'Loading model from {model_path} end!')

# Create SummaryWriter for TensorBoard logging
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
# Create shared learner for all tasks
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

    # NOTE: batch_size defaults to the first task's (task_id = 0) setting; per-task batch sizes are read inside the training loop.
policy_config = cfg.policy
batch_size = policy_config.batch_size[0]

for task_id, input_cfg in input_cfg_list:
if task_id > 0:
# Get the configuration for each task
cfg, create_cfg = input_cfg
cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu'
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
policy_config = cfg.policy
policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode
policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode

env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
collector_env.seed(cfg.seed + task_id)
evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False)
set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda)

        # ===== NOTE: Create a separate game buffer, collector, and evaluator for each task =====
        # TODO: share a single replay buffer across all tasks
replay_buffer = GameBuffer(policy_config)
collector = Collector(
env=collector_env,
policy=policy.collect_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config,
task_id=task_id
)
evaluator = Evaluator(
eval_freq=cfg.policy.eval_freq,
n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value,
env=evaluator_env,
policy=policy.eval_mode,
tb_logger=tb_logger,
exp_name=cfg.exp_name,
policy_config=policy_config,
task_id=task_id
)

cfgs.append(cfg)
replay_buffer.batch_size = cfg.policy.batch_size[task_id]
game_buffers.append(replay_buffer)
collector_envs.append(collector_env)
evaluator_envs.append(evaluator_env)
collectors.append(collector)
evaluators.append(evaluator)

learner.call_hook('before_run')
value_priority_tasks = {}

while True:
# Precompute positional embedding matrices for collect/eval (not training)
policy._collect_model.world_model.precompute_pos_emb_diff_kv()
policy._target_model.world_model.precompute_pos_emb_diff_kv()

# Collect data for each task
for task_id, (cfg, collector, evaluator, replay_buffer) in enumerate(
zip(cfgs, collectors, evaluators, game_buffers)):
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)

collect_kwargs = {
'temperature': visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
),
'epsilon': 0.0 # Default epsilon value
}

if policy_config.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=policy_config.eps.start,
end=policy_config.eps.end,
decay=policy_config.eps.decay,
type_=policy_config.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)

if evaluator.should_eval(learner.train_iter):
print('=' * 20)
print(f'evaluate task_id: {task_id}...')
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
if stop:
break

print('=' * 20)
print(f'collect task_id: {task_id}...')

# Reset initial data before each collection
collector._policy.reset(reset_init_data=True)
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# Determine updates per collection
update_per_collect = cfg.policy.update_per_collect
if update_per_collect is None:
collected_transitions_num = sum(len(game_segment) for game_segment in new_data[0])
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)

# Update replay buffer
replay_buffer.push_game_segments(new_data)
replay_buffer.remove_oldest_data_to_fit()

not_enough_data = any(replay_buffer.get_num_of_transitions() < batch_size for replay_buffer in game_buffers)

# Learn policy from collected data.
if not not_enough_data:
# Learner will train ``update_per_collect`` times in one iteration.
for i in range(update_per_collect):
train_data_multi_task = []
envstep_multi_task = 0
for task_id, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)):
envstep_multi_task += collector.envstep
if replay_buffer.get_num_of_transitions() > batch_size:
batch_size = cfg.policy.batch_size[task_id]
train_data = replay_buffer.sample(batch_size, policy)
if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
policy.recompute_pos_emb_diff_and_clear_cache()
# Append task_id to train_data
train_data.append(task_id)
train_data_multi_task.append(train_data)
else:
logging.warning(
f'The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, replay_buffer: {replay_buffer}'
)
break

if train_data_multi_task:
log_vars = learner.train(train_data_multi_task, envstep_multi_task)

if cfg.policy.use_priority:
for task_id, replay_buffer in enumerate(game_buffers):
# Update the priority for the task-specific replay buffer.
replay_buffer.update_priority(train_data_multi_task[task_id], log_vars[0][f'value_priority_task{task_id}'])

# Retrieve the updated priorities for the current task.
current_priorities = log_vars[0][f'value_priority_task{task_id}']

# Calculate statistics: mean, running mean, standard deviation for the priorities.
mean_priority = np.mean(current_priorities)
std_priority = np.std(current_priorities)

# Using exponential moving average for running mean (alpha is the smoothing factor).
alpha = 0.1 # You can adjust this smoothing factor as needed.
if f'running_mean_priority_task{task_id}' not in value_priority_tasks:
# Initialize running mean if it does not exist.
value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority
else:
# Update running mean.
value_priority_tasks[f'running_mean_priority_task{task_id}'] = (
alpha * mean_priority + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}']
)

# Calculate the normalized priority using the running mean.
running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}']
normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6)

# Store the normalized priorities back to the replay buffer (if needed).
# replay_buffer.update_priority(train_data_multi_task[task_id], normalized_priorities)

# Log the statistics if the print_task_priority_logs flag is set.
if cfg.policy.print_task_priority_logs:
print(f"Task {task_id} - Mean Priority: {mean_priority:.8f}, "
f"Running Mean Priority: {running_mean_priority:.8f}, "
f"Standard Deviation: {std_priority:.8f}")


if all(collector.envstep >= max_env_step for collector in collectors) or learner.train_iter >= max_train_iter:
break

learner.call_hook('after_run')
return policy
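
The docstring above defines the multi-task entry's interface: a list of (task_id, (cfg, create_cfg)) pairs, one shared policy and learner, and per-task collectors, evaluators, and game buffers. A minimal calling sketch under assumed names (the config modules and variables below are placeholders, not files introduced by this commit):

from copy import deepcopy

from lzero.entry import train_unizero_multitask

# Placeholder imports: each task is assumed to expose a (main_config, create_config)
# pair in the same way the single-task UniZero configs do, with
# create_config.policy.type set to 'unizero_multitask'.
from zoo.atari.config.pong_unizero_multitask_config import main_config as pong_cfg, create_config as pong_create_cfg
from zoo.atari.config.breakout_unizero_multitask_config import main_config as breakout_cfg, create_config as breakout_create_cfg

input_cfg_list = [
    (0, (deepcopy(pong_cfg), deepcopy(pong_create_cfg))),
    (1, (deepcopy(breakout_cfg), deepcopy(breakout_create_cfg))),
]

if __name__ == "__main__":
    # One shared policy/learner is built from task 0's config; the remaining tasks
    # reuse it with their own collectors, evaluators, and game buffers.
    policy = train_unizero_multitask(
        input_cfg_list,
        seed=0,
        model_path=None,        # optionally a pretrained .ckpt path
        max_env_step=int(1e6),  # training stops once every task's collector reaches this
    )
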
45 changes: 27 additions & 18 deletions lzero/mcts/buffer/game_buffer.py
@@ -102,22 +102,23 @@ def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]:
"""
pass

def _sample_orig_data(self, batch_size: int) -> Tuple:
def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) -> Tuple:
"""
Overview:
sample orig_data that contains:
game_segment_list: a list of game segments
pos_in_game_segment_list: transition index in game (relative index)
batch_index_list: the index of start transition of sampled minibatch in replay buffer
weights_list: the weight concerning the priority
make_time: the time the batch is made (for correctly updating replay buffer when data is deleted)
Sample original data which includes:
- game_segment_list: A list of game segments.
- pos_in_game_segment_list: Transition index in the game (relative index).
- batch_index_list: The index of the start transition of the sampled mini-batch in the replay buffer.
                - weights_list: The importance-sampling weights derived from the priorities.
- make_time: The time the batch is made (for correctly updating the replay buffer when data is deleted).
Arguments:
- batch_size (:obj:`int`): batch size
- beta: float the parameter in PER for calculating the priority
- batch_size (:obj:`int`): The size of the batch.
- print_priority_logs (:obj:`bool`): Whether to print logs related to priority statistics, defaults to False.
"""
assert self._beta > 0
assert self._beta > 0, "Beta should be greater than 0"
num_of_transitions = self.get_num_of_transitions()
if self._cfg.use_priority is False:
if not self._cfg.use_priority:
# If priority is not used, set all priorities to 1
self.game_pos_priorities = np.ones_like(self.game_pos_priorities)

# +1e-6 for numerical stability
@@ -126,20 +127,21 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:

# sample according to transition index
batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False)

if self._cfg.reanalyze_outdated is True:
# NOTE: used in reanalyze part
if self._cfg.reanalyze_outdated:
# Sort the batch indices if reanalyze is enabled
batch_index_list.sort()


# Calculate weights for the sampled transitions
weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta)
weights_list /= weights_list.max()
weights_list /= weights_list.max() # Normalize weights

game_segment_list = []
pos_in_game_segment_list = []

for idx in batch_index_list:
game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx]
game_segment_idx -= self.base_idx
game_segment_idx -= self.base_idx # Adjust index based on base index
game_segment = self.game_segment_buffer[game_segment_idx]

game_segment_list.append(game_segment)
Expand All @@ -161,6 +163,12 @@ def _sample_orig_data(self, batch_size: int) -> Tuple:
make_time = [time.time() for _ in range(len(batch_index_list))]

orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time)

if print_priority_logs:
print(f"Sampled batch indices: {batch_index_list}")
print(f"Sampled priorities: {self.game_pos_priorities[batch_index_list]}")
print(f"Sampled weights: {weights_list}")

return orig_data

def _sample_orig_reanalyze_batch(self, batch_size: int) -> Tuple:
@@ -580,7 +588,8 @@ def remove_oldest_data_to_fit(self) -> None:
Overview:
remove some oldest data if the replay buffer is full.
"""
assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size"
if isinstance(self._cfg.batch_size, int):
assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size"
nums_of_game_segments = self.get_num_of_game_segments()
total_transition = self.get_num_of_transitions()
if total_transition > self.replay_buffer_size:
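
The refactored _sample_orig_data above keeps the standard prioritized-experience-replay scheme: transitions are drawn with probability proportional to their (stabilized) priorities, and the importance-sampling weights (num_of_transitions * prob) ** (-beta) are normalized by their maximum. A self-contained NumPy sketch of that weighting, assuming typical alpha/beta values rather than LightZero's exact config fields:

from typing import Optional, Tuple

import numpy as np


def per_sample(
    priorities: np.ndarray,
    batch_size: int,
    alpha: float = 0.6,
    beta: float = 0.4,
    rng: Optional[np.random.Generator] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Sketch of the prioritized sampling performed in _sample_orig_data."""
    if rng is None:
        rng = np.random.default_rng()
    num_of_transitions = len(priorities)
    probs = (priorities + 1e-6) ** alpha  # +1e-6 for numerical stability
    probs /= probs.sum()
    # Draw transition indices without replacement, proportionally to priority.
    batch_index_list = rng.choice(num_of_transitions, size=batch_size, p=probs, replace=False)
    # Importance-sampling correction, normalized so the largest weight is 1.
    weights_list = (num_of_transitions * probs[batch_index_list]) ** (-beta)
    weights_list /= weights_list.max()
    return batch_index_list, weights_list


indices, weights = per_sample(np.abs(np.random.randn(1000)), batch_size=32)
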
1 change: 1 addition & 0 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -751,6 +751,7 @@ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -
NOTE:
train_data = [current_batch, target_batch]
current_batch = [obs_list, action_list, improved_policy_list(only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list]
target_batch = [batch_rewards, batch_target_values, batch_target_policies]
"""
indices = train_data[0][-3]
metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities}
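
Given the layout documented above (train_data = [current_batch, target_batch]), the metadata consumed by update_priority sits at fixed offsets inside current_batch. A small unpacking sketch, with variable names mirroring the buffer code rather than a public API:

current_batch, target_batch = train_data
batch_index_list = current_batch[-3]  # indices of the sampled transitions in the replay buffer
make_time_list = current_batch[-1]    # batch creation times, used to update the buffer correctly after old data is removed
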