diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py
index 200b91441..12c7ed78a 100644
--- a/lzero/entry/__init__.py
+++ b/lzero/entry/__init__.py
@@ -7,4 +7,5 @@
 from .train_muzero_with_gym_env import train_muzero_with_gym_env
 from .train_muzero_with_reward_model import train_muzero_with_reward_model
 from .train_rezero import train_rezero
+from .train_rezero_uz import train_rezero_uz
 from .train_unizero import train_unizero
diff --git a/lzero/entry/train_rezero_uz.py b/lzero/entry/train_rezero_uz.py
new file mode 100644
index 000000000..e4ada9f66
--- /dev/null
+++ b/lzero/entry/train_rezero_uz.py
@@ -0,0 +1,219 @@
+import logging
+import os
+from functools import partial
+from typing import Tuple, Optional
+
+import torch
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import set_pkg_seed, get_rank
+from ding.worker import BaseLearner
+from tensorboardX import SummaryWriter
+
+from lzero.entry.utils import log_buffer_memory_usage
+from lzero.policy import visit_count_temperature
+from lzero.policy.random_policy import LightZeroRandomPolicy
+from lzero.worker import MuZeroSegmentCollector as Collector  # this entry uses the segment-based collector
+from lzero.worker import MuZeroEvaluator as Evaluator
+from .utils import random_collect
+
+
+def train_rezero_uz(
+        input_cfg: Tuple[dict, dict],
+        seed: int = 0,
+        model: Optional[torch.nn.Module] = None,
+        model_path: Optional[str] = None,
+        max_train_iter: Optional[int] = int(1e10),
+        max_env_step: Optional[int] = int(1e10),
+) -> 'Policy':
+    """
+    Overview:
+        The train entry for UniZero with ReZero-style periodic buffer reanalysis.
+        UniZero, proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models,
+        aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found
+        in MuZero-style algorithms, particularly in environments requiring the capture of long-term dependencies.
+        More details can be found in https://arxiv.org/abs/2406.10667.
+    Arguments:
+        - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
+            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+        - seed (:obj:`int`): Random seed.
+        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+        - model_path (:obj:`Optional[str]`): The pretrained model path, which should
+            point to the ckpt file of the pretrained model, and an absolute path is recommended.
+            In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+    Returns:
+        - policy (:obj:`Policy`): Converged policy.
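+    Examples:
+        >>> # Minimal usage sketch. ``main_config`` and ``create_config`` are assumed to be an Atari
+        >>> # UniZero-style config pair (see zoo/atari/config/atari_unizero_sgement_config.py in this diff).
+        >>> from lzero.entry import train_rezero_uz
+        >>> train_rezero_uz([main_config, create_config], seed=0, max_env_step=int(2e5))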
+ """ + + cfg, create_cfg = input_cfg + + # Ensure the specified policy type is supported + assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'" + + # Import the correct GameBuffer class based on the policy type + # game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'} + game_buffer_classes = {'unizero': 'UniZeroReGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'} + + GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]), + game_buffer_classes[create_cfg.policy.type]) + + # Set device based on CUDA availability + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # Compile the configuration + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available()) + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load pretrained model if specified + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # Create worker components: learner, collector, evaluator, replay buffer, commander + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # MCTS+RL algorithms related core code + policy_config = cfg.policy + replay_buffer = GameBuffer(policy_config) + collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name, + policy_config=policy_config) + evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode, + tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config) + + # Learner's before_run hook + learner.call_hook('before_run') + + # Collect random data before training + if cfg.policy.random_collect_episode_num > 0: + random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + + batch_size = policy._cfg.batch_size + + # TODO: for visualize + # stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + + buffer_reanalyze_count = 0 + train_epoch = 0 + reanalyze_batch_size = cfg.policy.reanalyze_batch_size + + while True: + # Log buffer memory usage + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + # Set temperature for visit count distributions + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + 
+                policy_config.threshold_training_steps_for_final_temperature,
+                trained_steps=learner.train_iter
+            ),
+            'epsilon': 0.0  # Default epsilon value
+        }
+
+        # Configure epsilon for epsilon-greedy exploration
+        if policy_config.eps.eps_greedy_exploration_in_collect:
+            epsilon_greedy_fn = get_epsilon_greedy_fn(
+                start=policy_config.eps.start,
+                end=policy_config.eps.end,
+                decay=policy_config.eps.decay,
+                type_=policy_config.eps.type
+            )
+            collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+
+        # Evaluate policy performance
+        if evaluator.should_eval(learner.train_iter):
+            stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            if stop:
+                break
+
+        # Collect new data
+        new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+
+        # Determine updates per collection
+        update_per_collect = cfg.policy.update_per_collect
+        if update_per_collect is None:
+            collected_transitions_num = sum(len(game_segment) for game_segment in new_data[0])
+            update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+
+        # Update replay buffer
+        replay_buffer.push_game_segments(new_data)
+        replay_buffer.remove_oldest_data_to_fit()
+
+        # Periodically reanalyze the buffer
+        if cfg.policy.buffer_reanalyze_freq >= 1:
+            # Reanalyze the buffer <buffer_reanalyze_freq> times within one training epoch
+            reanalyze_interval = update_per_collect // cfg.policy.buffer_reanalyze_freq
+        else:
+            # Reanalyze the buffer once every <1/buffer_reanalyze_freq> training epochs
+            if train_epoch % (1//cfg.policy.buffer_reanalyze_freq) == 0 and replay_buffer.get_num_of_transitions() > reanalyze_batch_size:
+                # When reanalyzing the buffer, the samples in the entire buffer are processed in mini-batches with a
+                # batch size of reanalyze_batch_size, an empirically chosen value for efficiency.
+                replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
+                buffer_reanalyze_count += 1
+                logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
+
+        # Train the policy if sufficient data is available
+        if collector.envstep > cfg.policy.train_start_after_envsteps:
+            if cfg.policy.sample_type == 'episode':
+                data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
+            else:
+                data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
+            if not data_sufficient:
+                logging.warning(
+                    f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+                    f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
+                )
+                continue
+
+            for i in range(update_per_collect):
+
+                if cfg.policy.buffer_reanalyze_freq >= 1:
+                    # Reanalyze the buffer <buffer_reanalyze_freq> times within one training epoch
+                    if i % reanalyze_interval == 0 and replay_buffer.get_num_of_transitions() > reanalyze_batch_size:
+                        # When reanalyzing the buffer, the samples in the entire buffer are processed in mini-batches
+                        # with a batch size of reanalyze_batch_size, an empirically chosen value for efficiency.
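+                        # Worked example (assumed values, not taken from the configs below): with
+                        # update_per_collect=40 and buffer_reanalyze_freq=2, reanalyze_interval=20,
+                        # so reanalysis runs at i=0 and i=20 of this loop.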
+                        replay_buffer.reanalyze_buffer(reanalyze_batch_size, policy)
+                        buffer_reanalyze_count += 1
+                        logging.info(f'Buffer reanalyze count: {buffer_reanalyze_count}')
+
+                train_data = replay_buffer.sample(batch_size, policy)
+                if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
+                    # Clear caches and precompute positional embedding matrices
+                    policy.recompute_pos_emb_diff_and_clear_cache()  # TODO
+
+                train_data.append({'train_which_component': 'transformer'})
+                log_vars = learner.train(train_data, collector.envstep)
+
+                if cfg.policy.use_priority:
+                    replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+        policy.recompute_pos_emb_diff_and_clear_cache()
+        # One training epoch finished; this counter drives the fractional buffer_reanalyze_freq branch above.
+        train_epoch += 1
+
+        # Check stopping criteria
+        if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            break
+
+    learner.call_hook('after_run')
+    return policy
diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py
index d7ccb0678..dad94f24f 100644
--- a/lzero/mcts/buffer/__init__.py
+++ b/lzero/mcts/buffer/__init__.py
@@ -4,6 +4,7 @@
 from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
 from .game_buffer_sampled_muzero import SampledMuZeroGameBuffer
 from .game_buffer_sampled_unizero import SampledUniZeroGameBuffer
+from .game_buffer_rezero_uz import UniZeroReGameBuffer
 from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer
 from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer
 from .game_buffer_rezero_mz import ReZeroMZGameBuffer
diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py
index 4f4736faa..63094d257 100644
--- a/lzero/mcts/buffer/game_buffer.py
+++ b/lzero/mcts/buffer/game_buffer.py
@@ -195,6 +195,10 @@ def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple:
             game_segment_list.append(game_segment)
             pos_in_game_segment_list.append(pos_in_game_segment)
 
+            # TODO
+            # if pos_in_game_segment > self._cfg.game_segment_length - self._cfg.num_unroll_steps:
+            #     pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps + 1, 1).item()
+            #     pos_in_game_segment_list.append(pos_in_game_segment)
 
         make_time = [time.time() for _ in range(len(batch_index_list))]
diff --git a/lzero/mcts/buffer/game_buffer_rezero_uz.py b/lzero/mcts/buffer/game_buffer_rezero_uz.py
new file mode 100644
index 000000000..e0f7febca
--- /dev/null
+++ b/lzero/mcts/buffer/game_buffer_rezero_uz.py
@@ -0,0 +1,514 @@
+from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional
+
+import numpy as np
+import torch
+from ding.utils import BUFFER_REGISTRY, EasyTimer
+
+from lzero.mcts.tree_search.mcts_ctree import UniZeroMCTSCtree as MCTSCtree
+from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree  # needed by the python-tree branch below
+from lzero.mcts.utils import prepare_observation
+from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
+from .game_buffer_muzero import MuZeroGameBuffer
+
+if TYPE_CHECKING:
+    from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy
+
+
+@BUFFER_REGISTRY.register('game_buffer_unizero_re')
+class UniZeroReGameBuffer(MuZeroGameBuffer):
+    """
+    Overview:
+        The specific game buffer for the UniZero policy with ReZero-style buffer reanalysis.
+    """
+
+    def __init__(self, cfg: dict):
+        super().__init__(cfg)
+        """
+        Overview:
+            Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
+            in the default configuration, the user-provided value will override the default configuration. Otherwise,
+            the default configuration will be used.
+ """ + default_config = self.default_config() + default_config.update(cfg) + self._cfg = default_config + assert self._cfg.env_type in ['not_board_games', 'board_games'] + self.replay_buffer_size = self._cfg.replay_buffer_size + self.batch_size = self._cfg.batch_size + self._alpha = self._cfg.priority_prob_alpha + self._beta = self._cfg.priority_prob_beta + + self.keep_ratio = 1 + self.model_update_interval = 10 + self.num_of_collected_episodes = 0 + self.base_idx = 0 + self.clear_time = 0 + + self.game_segment_buffer = [] + self.game_pos_priorities = [] + self.game_segment_game_pos_look_up = [] + # self.task_id = self._cfg.task_id + self.sample_type = self._cfg.sample_type # 'transition' or 'episode' + + def reanalyze_buffer( + self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] + ) -> List[Any]: + """ + Overview: + sample data from ``GameBuffer`` and prepare the current and target batch for training. + Arguments: + - batch_size (:obj:`int`): batch size. + - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy. + Returns: + - train_data (:obj:`List`): List of train data, including current_batch and target_batch. + """ + policy._target_model.to(self._cfg.device) + policy._target_model.eval() + self.policy = policy + + # obtain the current_batch and prepare target context + reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( + batch_size, 1 + ) + # target policy + batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, + current_batch[1]) + + + def sample( + self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] + ) -> List[Any]: + """ + Overview: + sample data from ``GameBuffer`` and prepare the current and target batch for training. + Arguments: + - batch_size (:obj:`int`): batch size. + - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy. + Returns: + - train_data (:obj:`List`): List of train data, including current_batch and target_batch. 
+ """ + policy._target_model.to(self._cfg.device) + policy._target_model.eval() + self.policy = policy + + # obtain the current_batch and prepare target context + reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( + batch_size, self._cfg.reanalyze_ratio + ) + + # current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] + + # target reward, target value + batch_rewards, batch_target_values = self._compute_target_reward_value( + reward_value_context, policy._target_model, current_batch[1] # current_batch[1] is action_batch + ) + # target policy + batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, + current_batch[1]) + batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( + policy_non_re_context, self._cfg.model.action_space_size + ) + + # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies + if 0 < self._cfg.reanalyze_ratio < 1: + batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) + elif self._cfg.reanalyze_ratio == 1: + batch_target_policies = batch_target_policies_re + elif self._cfg.reanalyze_ratio == 0: + batch_target_policies = batch_target_policies_non_re + + target_batch = [batch_rewards, batch_target_values, batch_target_policies] + + # a batch contains the current_batch and the target_batch + train_data = [current_batch, target_batch] + return train_data + + def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: + """ + Overview: + first sample orig_data through ``_sample_orig_data()``, + then prepare the context of a batch: + reward_value_context: the context of reanalyzed value targets + policy_re_context: the context of reanalyzed policy targets + policy_non_re_context: the context of non-reanalyzed policy targets + current_batch: the inputs of batch + Arguments: + - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. + - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed) + Returns: + - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch + """ + # obtain the batch context from replay buffer + if self.sample_type == 'transition': + orig_data = self._sample_orig_data(batch_size) + elif self.sample_type == 'episode': + orig_data = self._sample_orig_data_episode(batch_size) + game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data + batch_size = len(batch_index_list) + obs_list, action_list, mask_list = [], [], [] + # prepare the inputs of a batch + for i in range(batch_size): + game = game_segment_list[i] + pos_in_game_segment = pos_in_game_segment_list[i] + + actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid + mask_tmp = [1. for i in range(len(actions_tmp))] + mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + + # pad random action + actions_tmp += [ + np.random.randint(0, game.action_space_size) + for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) + ] + + # obtain the input observations + # pad if length of obs in game_segment is less than stack+num_unroll_steps + # e.g. 
stack+num_unroll_steps = 4+5 + obs_list.append( + game_segment_list[i].get_unroll_obs( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + ) + ) + action_list.append(actions_tmp) + mask_list.append(mask_tmp) + + # formalize the input observations + obs_list = prepare_observation(obs_list, self._cfg.model.model_type) + + # formalize the inputs of a batch + current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] + for i in range(len(current_batch)): + current_batch[i] = np.asarray(current_batch[i]) + + total_transitions = self.get_num_of_transitions() + + # obtain the context of value targets + reward_value_context = self._prepare_reward_value_context( + batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions + ) + """ + only reanalyze recent reanalyze_ratio (e.g. 50%) data + if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps + 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy + """ + reanalyze_num = max(int(batch_size * reanalyze_ratio), 1) if reanalyze_ratio > 0 else 0 + # print(f'reanalyze_ratio: {reanalyze_ratio}, reanalyze_num: {reanalyze_num}') + self.reanalyze_num = reanalyze_num + # reanalyzed policy + if reanalyze_num > 0: + # obtain the context of reanalyzed policy targets + policy_re_context = self._prepare_policy_reanalyzed_context( + batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num], + pos_in_game_segment_list[:reanalyze_num] + ) + else: + policy_re_context = None + + # non reanalyzed policy + if reanalyze_num < batch_size: + # obtain the context of non-reanalyzed policy targets + policy_non_re_context = self._prepare_policy_non_reanalyzed_context( + batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:], + pos_in_game_segment_list[reanalyze_num:] + ) + else: + policy_non_re_context = None + + context = reward_value_context, policy_re_context, policy_non_re_context, current_batch + return context + + + def _prepare_policy_reanalyzed_context( + self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] + ) -> List[Any]: + """ + Overview: + prepare the context of policies for calculating policy target in reanalyzing part. + Arguments: + - batch_index_list (:obj:'list'): start transition index in the replay buffer + - game_segment_list (:obj:'list'): list of game segments + - pos_in_game_segment_list (:obj:'list'): position of transition index in one game history + Returns: + - policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, indices, + child_visits, game_segment_lens, action_mask_segment, to_play_segment + """ + zero_obs = game_segment_list[0].zero_obs() + with torch.no_grad(): + # for policy + policy_obs_list = [] + policy_mask = [] + # 0 -> Invalid target policy for padding outside of game segments, + # 1 -> Previous target policy for game segments. 
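+            # Per-segment metadata gathered below: rewards, the child-visit distributions currently stored in the
+            # buffer, segment lengths, and (for board games) action masks / to_play flags.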
+            rewards, child_visits, game_segment_lens = [], [], []
+            # for board games
+            action_mask_segment, to_play_segment = [], []
+            for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list):
+                game_segment_len = len(game_segment)
+                game_segment_lens.append(game_segment_len)
+                rewards.append(game_segment.reward_segment)
+                # for board games
+                action_mask_segment.append(game_segment.action_mask_segment)
+                to_play_segment.append(game_segment.to_play_segment)
+
+                child_visits.append(game_segment.child_visit_segment)
+                # prepare the corresponding observations
+                game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps)
+                for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+
+                    if current_index < game_segment_len:
+                        policy_mask.append(1)
+                        beg_index = current_index - state_index
+                        end_index = beg_index + self._cfg.model.frame_stack_num
+                        obs = game_obs[beg_index:end_index]
+                    else:
+                        policy_mask.append(0)
+                        obs = zero_obs
+                    policy_obs_list.append(obs)
+
+        policy_re_context = [
+            policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens,
+            action_mask_segment, to_play_segment
+        ]
+        return policy_re_context
+
+    def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, action_batch) -> np.ndarray:
+        """
+        Overview:
+            prepare policy targets from the reanalyzed context of policies
+        Arguments:
+            - policy_re_context (:obj:`List`): List of policy context to be reanalyzed
+        Returns:
+            - batch_target_policies_re
+        """
+        if policy_re_context is None:
+            return []
+        batch_target_policies_re = []
+
+        # for board games
+        policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \
+            to_play_segment = policy_re_context  # noqa
+        transition_batch_size = len(policy_obs_list)
+        game_segment_batch_size = len(pos_in_game_segment_list)
+
+        to_play, action_mask = self._preprocess_to_play_and_action_mask(
+            game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
+        )
+
+        if self._cfg.model.continuous_action_space is True:
+            # when the action space of the environment is continuous, action_mask[:] is None.
+            action_mask = [
+                list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
+            ]
+            # NOTE: in continuous action space env: we set all legal_actions as -1
+            legal_actions = [
+                [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
+            ]
+        else:
+            legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+
+        # NOTE: switch the world model into the reanalyze phase for the following inference and search. TODO: refine this flag.
+        model.world_model.reanalyze_phase = True
+
+        with torch.no_grad():
+            policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)
+            network_output = []
+            m_obs = torch.from_numpy(policy_obs_list).to(self._cfg.device)
+
+            # =============== NOTE: The key difference with MuZero =================
+            # calculate the target value
+            # action_batch.shape (32, 10)
+            # m_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352
+            m_output = model.initial_inference(m_obs, action_batch[:self.reanalyze_num])  # NOTE: :self.reanalyze_num
+            # =======================================================================
+
+            if not model.training:
+                # if not in training, obtain the scalars of the value/reward
+                [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                    [
+                        m_output.latent_state,
+                        inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                        m_output.policy_logits
+                    ]
+                )
+
+            network_output.append(m_output)
+
+            _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero')
+            reward_pool = reward_pool.squeeze().tolist()
+            policy_logits_pool = policy_logits_pool.tolist()
+            noises = [
+                np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size
+                                    ).astype(np.float32).tolist() for _ in range(transition_batch_size)
+            ]
+            if self._cfg.mcts_ctree:
+                # cpp mcts_tree
+                roots = MCTSCtree.roots(transition_batch_size, legal_actions)
+                roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
+                # do MCTS for a new policy with the recent target model
+                MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+            else:
+                # python mcts_tree
+                roots = MCTSPtree.roots(transition_batch_size, legal_actions)
+                roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
+                # do MCTS for a new policy with the recent target model
+                MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+
+            roots_legal_actions_list = legal_actions
+            roots_distributions = roots.get_distributions()
+            policy_index = 0
+            for state_index, child_visit, game_index in zip(pos_in_game_segment_list, child_visits, batch_index_list):
+                target_policies = []
+                for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+                    distributions = roots_distributions[policy_index]
+                    if policy_mask[policy_index] == 0:
+                        # NOTE: the invalid padding target policy; all zeros make sure the corresponding cross_entropy_loss=0
+                        target_policies.append([0 for _ in range(self._cfg.model.action_space_size)])
+                    else:
+                        # NOTE: It is very important to use the latest MCTS visit count distribution.
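+                        # The freshly searched distribution overwrites the stored child-visit statistics below,
+                        # which is what actually refreshes the policy targets kept in the buffer.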
+                        sum_visits = sum(distributions)
+                        child_visit[current_index] = [visit_count / sum_visits for visit_count in distributions]
+
+                        if distributions is None:
+                            # if at some obs, the legal_action is None, add the fake target_policy
+                            target_policies.append(
+                                list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size)
+                            )
+                        else:
+                            if self._cfg.env_type == 'not_board_games':
+                                # for atari/classic_control/box2d environments that only have one player.
+                                sum_visits = sum(distributions)
+                                policy = [visit_count / sum_visits for visit_count in distributions]
+                                target_policies.append(policy)
+                            else:
+                                # for board games that have two players and whose legal_actions are dynamic
+                                policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)]
+                                # to make sure target_policies have the same dimension
+                                sum_visits = sum(distributions)
+                                policy = [visit_count / sum_visits for visit_count in distributions]
+                                for index, legal_action in enumerate(roots_legal_actions_list[policy_index]):
+                                    policy_tmp[legal_action] = policy[index]
+                                target_policies.append(policy_tmp)
+
+                    policy_index += 1
+
+                batch_target_policies_re.append(target_policies)
+
+        batch_target_policies_re = np.array(batch_target_policies_re)
+
+        # NOTE: leave the reanalyze phase. TODO: refine this flag.
+        model.world_model.reanalyze_phase = False
+
+        return batch_target_policies_re
+
+    # Compared with _compute_target_reward_value_v1, the action_mask / legal_action related handling is removed here.
+    def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[
+            Any, Any]:
+        """
+        Overview:
+            prepare reward and value targets from the context of rewards and values.
+        Arguments:
+            - reward_value_context (:obj:`list`): the reward value context
+            - model (:obj:`torch.nn.Module`): the target model
+            - action_batch (:obj:`np.ndarray`): batch of actions fed to the world model's initial inference
+        Returns:
+            - batch_rewards (:obj:`np.ndarray`): batch of rewards
+            - batch_target_values (:obj:`np.ndarray`): batch of value estimations
+        """
+        value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \
+            to_play_segment = reward_value_context  # noqa
+        # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1)
+        transition_batch_size = len(value_obs_list)
+
+        batch_target_values, batch_rewards = [], []
+        with torch.no_grad():
+            value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type)
+            network_output = []
+            m_obs = torch.from_numpy(value_obs_list).to(self._cfg.device)
+
+            # =============== NOTE: The key difference with MuZero =================
+            # calculate the target value
+            # m_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352
+            m_output = model.initial_inference(m_obs, action_batch)
+            # ======================================================================
+
+            # obtain the scalars of the value/reward (this conversion is needed on both CUDA and CPU devices)
+            [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+                [
+                    m_output.latent_state,
+                    inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+                    m_output.policy_logits
+                ]
+            )
+
+            network_output.append(m_output)
+
+            # use the predicted values
+            value_numpy = concat_output_value(network_output)
+
+            # get last state value
+            if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
+                # TODO(pu): for board_games, very important, to check
+                value_numpy = value_numpy.reshape(-1) * np.array(
+                    [
+                        self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) %
+                        2 == 0 else -self._cfg.discount_factor ** td_steps_list[i]
+                        for i in range(transition_batch_size)
+                    ]
+                )
+            else:
+                value_numpy = value_numpy.reshape(-1) * (
+                    np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list
+                )
+
+            value_numpy = value_numpy * np.array(value_mask)
+            value_list = value_numpy.tolist()
+            horizon_id, value_index = 0, 0
+
+            for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list,
+                                                                                       pos_in_game_segment_list,
+                                                                                       to_play_segment):
+                target_values = []
+                target_rewards = []
+                base_index = state_index
+                for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+                    bootstrap_index = current_index + td_steps_list[value_index]
+                    for i, reward in enumerate(reward_list[current_index:bootstrap_index]):
+                        if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
+                            # TODO(pu): for board_games, very important, to check
+                            if to_play_list[base_index] == to_play_list[i]:
+                                value_list[value_index] += reward * self._cfg.discount_factor ** i
+                            else:
+                                value_list[value_index] += -reward * self._cfg.discount_factor ** i
+                        else:
+                            value_list[value_index] += reward * self._cfg.discount_factor ** i
+                    horizon_id += 1
+
+                    if current_index < game_segment_len_non_re:
+                        target_values.append(value_list[value_index])
+                        target_rewards.append(reward_list[current_index])
+                    else:
+                        target_values.append(np.array(0.))
+                        target_rewards.append(np.array(0.))
+                    value_index += 1
+
+                batch_rewards.append(target_rewards)
+                batch_target_values.append(target_values)
+
+        batch_rewards = np.asarray(batch_rewards)
+        batch_target_values = np.asarray(batch_target_values)
+
+        return batch_rewards, batch_target_values
diff --git a/zoo/atari/config/atari_rezero_mz_config.py b/zoo/atari/config/atari_rezero_mz_config.py
index b3b42afc1..580591a7c 100644
--- a/zoo/atari/config/atari_rezero_mz_config.py
+++ b/zoo/atari/config/atari_rezero_mz_config.py
@@ -19,6 +19,17 @@
 reuse_search = True
 collect_with_pure_policy = True
 buffer_reanalyze_freq = 1
+
+# ====== only for debug =====
+collector_env_num = 8
+num_segments = 8
+evaluator_env_num = 2
+num_simulations = 5
+max_env_step = int(2e5)
+reanalyze_ratio = 0.1
+batch_size = 64
+num_unroll_steps = 10
+replay_ratio = 0.01
 # ==============================================================
 # end of the most frequently changed config specified by the user
 # ==============================================================
@@ -33,6 +44,9 @@
         evaluator_env_num=evaluator_env_num,
         n_evaluator_episode=evaluator_env_num,
         manager=dict(shared_memory=False, ),
+        # TODO: only for debug
+        collect_max_episode_steps=int(20),
+        eval_max_episode_steps=int(20),
     ),
     policy=dict(
         model=dict(
diff --git a/zoo/atari/config/atari_unizero_sgement_config.py b/zoo/atari/config/atari_unizero_sgement_config.py
index ba37271fd..0de6541e8 100644
--- a/zoo/atari/config/atari_unizero_sgement_config.py
+++ b/zoo/atari/config/atari_unizero_sgement_config.py
@@ -5,7 +5,6 @@
 # env_id = 'SeaquestNoFrameskip-v4'  # You can specify any Atari game here
 # env_id = 'QbertNoFrameskip-v4'  # You can specify any Atari game here
-
 action_space_size = atari_env_action_space_map[env_id]
 
@@ -13,36 +12,20 @@
 # ==============================================================
 update_per_collect = None
 replay_ratio = 0.25
-# replay_ratio = 0.1
-
-# replay_ratio = 1
-
 collector_env_num = 8
 num_segments = 8
-
-# collector_env_num = 4
-# num_segments = 4
-
-
-# num_segments = 1
 game_segment_length=20
-# game_segment_length=15
-# game_segment_length=50
-# game_segment_length=100
-# game_segment_length=400
 evaluator_env_num = 3
 num_simulations = 50
 max_env_step = int(2e5)
-
-# reanalyze_ratio = 0.1
 reanalyze_ratio = 0.
-
 batch_size = 64
 num_unroll_steps = 10
 infer_context_length = 4
-
 num_layers = 2
+buffer_reanalyze_freq = 1/10 # modify according to num_segments
+reanalyze_batch_size = 2000
 
 # ====== only for debug =====
 collector_env_num = 8
@@ -53,7 +36,9 @@
 reanalyze_ratio = 0.
 batch_size = 64
 num_unroll_steps = 10
-replay_ratio = 0.05
+# buffer_reanalyze_freq = 1
+buffer_reanalyze_freq = 1/2
+reanalyze_batch_size = 20
 
 # ==============================================================
 # end of the most frequently changed config specified by the user
@@ -63,7 +48,6 @@
     env=dict(
         stop_value=int(1e6),
         env_id=env_id,
-        # observation_shape=(3, 64, 64),
         observation_shape=(3, 96, 96),
         gray_scale=False,
         collector_env_num=collector_env_num,
@@ -127,6 +111,9 @@
         eval_freq=int(5e3),
         collector_env_num=collector_env_num,
        evaluator_env_num=evaluator_env_num,
+        # ============= The key ReZero-specific parameters =============
+        buffer_reanalyze_freq=buffer_reanalyze_freq,  # 1 means reanalyze the buffer once per training epoch; 1/2 means once every two epochs
+        reanalyze_batch_size=reanalyze_batch_size,
     ),
 )
 atari_unizero_config = EasyDict(atari_unizero_config)
@@ -160,12 +147,14 @@ for seed in seeds:
     # Update exp_name to include the current seed
     # main_config.exp_name = f'data_efficiency0829_plus_tune-uz_0920/{env_id[:-14]}/{env_id[:-14]}_uz_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_temp025_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}-infer{infer_context_length}_bs{batch_size}_seed{seed}'
-    # main_config.exp_name = f'data_efficiency0829_plus_tune-uz_0917/numsegments-{num_segments}_gsl{game_segment_length}_origin-target-value-policy_pew0_fixsample_temp025_useprio/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}_nlayer2'
     main_config.exp_name = f'data_efficiency0829_plus_tune-uz_debug/numsegments-{num_segments}_gsl{game_segment_length}_fix/obshape96_use-augmentation-obsw10/{env_id[:-14]}_stack1_unizero_upc{update_per_collect}-rr{replay_ratio}_H{num_unroll_steps}_bs{batch_size}_seed{seed}_nlayer2'
 
-    from lzero.entry import train_unizero
-    train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
+    # from lzero.entry import train_unizero
+    # train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
+
+    from lzero.entry import train_rezero_uz
+    train_rezero_uz([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
 
 
     # from lzero.entry import train_unizero
diff --git a/zoo/atari/config/atari_unizero_sgement_config_batch.py b/zoo/atari/config/atari_unizero_sgement_config_batch.py
index a6fe992dd..12f5e9961 100644
--- a/zoo/atari/config/atari_unizero_sgement_config_batch.py
+++ b/zoo/atari/config/atari_unizero_sgement_config_batch.py
@@ -10,17 +10,18 @@ def main(env_id, seed):
     # ==============================================================
     update_per_collect = None
     replay_ratio = 0.25
-    # collector_env_num = 8 # TODO
-    # num_segments = 8
+    collector_env_num = 8 # TODO
+    num_segments = 8
+
     # collector_env_num = 4 # TODO
     # num_segments = 4
     # game_segment_length=10
+    # collector_env_num = 1 # TODO
+    # num_segments = 1
-    collector_env_num = 1 # TODO
-    num_segments = 1
     game_segment_length=20
-    evaluator_env_num = 8 # TODO
+    evaluator_env_num = 5 # TODO
     num_simulations = 50
     max_env_step = int(5e5) # TODO
@@ -34,6 +35,8 @@ def main(env_id, seed):
     infer_context_length = 4
     num_layers = 2
+    buffer_reanalyze_freq = 1/10 # modify according to num_segments
+    reanalyze_batch_size = 2000
 
     # ====== only for debug =====
     # collector_env_num = 8
@@ -117,6 +120,9 @@ def main(env_id, seed):
             eval_freq=int(5e3),
             collector_env_num=collector_env_num,
             evaluator_env_num=evaluator_env_num,
+            # ============= The key ReZero-specific parameters =============
+            buffer_reanalyze_freq=buffer_reanalyze_freq,  # 1 means reanalyze the buffer once per training epoch; 1/2 means once every two epochs
+            reanalyze_batch_size=reanalyze_batch_size,
         ),
     )
     atari_unizero_config = EasyDict(atari_unizero_config)
@@ -144,10 +150,14 @@ def main(env_id, seed):
     atari_unizero_create_config = EasyDict(atari_unizero_create_config)
     create_config = atari_unizero_create_config
 
-    main_config.exp_name = f'data_efficiency0829_plus_tune-uz_0920/{env_id[:-14]}/{env_id[:-14]}_uz_nlayer{num_layers}_eval8_collect{collector_env_num}-numsegments-{num_segments}_gsl{game_segment_length}_temp025_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}-infer{infer_context_length}_bs{batch_size}_seed{seed}'
-    from lzero.entry import train_unizero
-    train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
+    main_config.exp_name = f'data_efficiency0829_plus_tune-uz_0920/{env_id[:-14]}/{env_id[:-14]}_uz_brf{buffer_reanalyze_freq}_nlayer{num_layers}_eval5_collect{collector_env_num}-numsegments-{num_segments}_gsl{game_segment_length}_temp025_upc{update_per_collect}-rr{replay_ratio}_rer{reanalyze_ratio}_H{num_unroll_steps}-infer{infer_context_length}_bs{batch_size}_seed{seed}'
+
+    # from lzero.entry import train_unizero
+    # train_unizero([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
+
+    from lzero.entry import train_rezero_uz
+    train_rezero_uz([main_config, create_config], seed=seed, model_path=main_config.policy.model_path, max_env_step=max_env_step)
diff --git a/zoo/atari/config/sco_acp_mbq_uz_batch.sh b/zoo/atari/config/sco_acp_mbq_uz_batch.sh
index 11b02c037..23c8ad1aa 100644
--- a/zoo/atari/config/sco_acp_mbq_uz_batch.sh
+++ b/zoo/atari/config/sco_acp_mbq_uz_batch.sh
@@ -64,7 +64,7 @@ for env in "${envs[@]}"; do
         sco acp jobs create --workspace-name=fb1861da-1c6c-42c7-87ed-e08d8b314a99 \
         --aec2-name=eb37789e-90bb-418d-ad4a-19ce4b81ab0c\
-        --job-name="uz-nlayer2-H10-seg1-gsl20-$env-s$seed" \
+        --job-name="uz-nlayer2-H10-seg8-gsl20-brf1-5-$env-s$seed" \
         --container-image-url='registry.cn-sh-01.sensecore.cn/basemodel-ccr/aicl-b27637a9-660e-4927:20231222-17h24m12s' \
         --training-framework=pytorch \
         --enable-mpi \