diff --git a/lzero/entry/train_unizero_reanalyze.py b/lzero/entry/train_unizero_reanalyze.py index 282f72405..aa3fc6aae 100644 --- a/lzero/entry/train_unizero_reanalyze.py +++ b/lzero/entry/train_unizero_reanalyze.py @@ -57,8 +57,7 @@ def train_unizero_reanalyze( assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'" # Import the correct GameBuffer class based on the policy type - # game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'} - game_buffer_classes = {'unizero': 'UniZeroReGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'} + game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'} GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]), game_buffer_classes[create_cfg.policy.type]) diff --git a/lzero/mcts/buffer/__init__.py b/lzero/mcts/buffer/__init__.py index dad94f24f..d7ccb0678 100644 --- a/lzero/mcts/buffer/__init__.py +++ b/lzero/mcts/buffer/__init__.py @@ -4,7 +4,6 @@ from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer from .game_buffer_sampled_muzero import SampledMuZeroGameBuffer from .game_buffer_sampled_unizero import SampledUniZeroGameBuffer -from .game_buffer_rezero_uz import UniZeroReGameBuffer from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer from .game_buffer_rezero_mz import ReZeroMZGameBuffer diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index d0520a48a..bad188912 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -125,9 +125,6 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: probs /= probs.sum() # sample according to transition index - # TODO(pu): replace=True - # print(f"num transitions is {num_of_transitions}") - # print(f"length of probs is {len(probs)}") batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) if self._cfg.reanalyze_outdated is True: @@ -147,9 +144,7 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: game_segment_list.append(game_segment) # pos_in_game_segment_list.append(pos_in_game_segment) - - # pos_in_game_segment_list.append(max(pos_in_game_segment, self._cfg.game_segment_length - self._cfg.num_unroll_steps)) - # TODO + # TODO: check if pos_in_game_segment > self._cfg.game_segment_length - self._cfg.num_unroll_steps: pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps + 1, 1).item() pos_in_game_segment_list.append(pos_in_game_segment) @@ -160,7 +155,7 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) return orig_data - def _sample_orig_reanalyze_data_uz(self, batch_size: int) -> Tuple: + def _sample_orig_reanalyze_batch_data(self, batch_size: int) -> Tuple: """ Overview: sample orig_data that contains: @@ -176,16 +171,14 @@ def _sample_orig_reanalyze_data_uz(self, batch_size: int) -> Tuple: assert self._beta > 0 train_sample_num = (self.get_num_of_transitions()//self._cfg.num_unroll_steps) - # TODO: 只选择前 3/4 的样本 valid_sample_num = int(train_sample_num * self._cfg.reanalyze_partition) - # TODO: 动态调整衰减率,假设你希望衰减率与 valid_sample_num 成反比 - base_decay_rate = 5 # 基础衰减率,可以根据经验设定 - decay_rate = base_decay_rate / 
valid_sample_num # 随着样本数量增加,衰减率变小 - # 生成指数衰减的权重 (仅对前 3/4 的样本) + base_decay_rate = 5 + # decay rate becomes smaller as the number of samples increases + decay_rate = base_decay_rate / valid_sample_num + # Generate exponentially decaying weights (only for the first 3/4 of the samples) weights = np.exp(-decay_rate * np.arange(valid_sample_num)) - # 将权重归一化为概率分布 + # Normalize the weights to a probability distribution probabilities = weights / np.sum(weights) - # 按照概率分布进行采样 (仅在前 3/4 中采样) batch_index_list = np.random.choice(valid_sample_num, batch_size, replace=False, p=probabilities) if self._cfg.reanalyze_outdated is True: @@ -203,12 +196,6 @@ def _sample_orig_reanalyze_data_uz(self, batch_size: int) -> Tuple: game_segment_list.append(game_segment) pos_in_game_segment_list.append(pos_in_game_segment) - # pos_in_game_segment_list.append(max(pos_in_game_segment, self._cfg.game_segment_length - self._cfg.num_unroll_steps)) - # # TODO - # if pos_in_game_segment > self._cfg.game_segment_length - self._cfg.num_unroll_steps: - # pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps + 1, 1).item() - # pos_in_game_segment_list.append(pos_in_game_segment) - make_time = [time.time() for _ in range(len(batch_index_list))] @@ -250,10 +237,6 @@ def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple: game_segment_list.append(game_segment) pos_in_game_segment_list.append(pos_in_game_segment) - # TODO - # if pos_in_game_segment > self._cfg.game_segment_length - self._cfg.num_unroll_steps: - # pos_in_game_segment = np.random.choice(self._cfg.game_segment_length - self._cfg.num_unroll_steps + 1, 1).item() - # pos_in_game_segment_list.append(pos_in_game_segment) make_time = [time.time() for _ in range(len(batch_index_list))] diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index dd1a8d614..2f0acfdca 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -88,11 +88,11 @@ def reanalyze_buffer( policy._target_model.eval() self.policy = policy # obtain the current_batch and prepare target context - policy_re_context = self._make_batch_for_reanalyze(batch_size, 1) + policy_re_context = self._make_batch_for_reanalyze(batch_size) # target policy self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model) - def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: + def _make_batch_for_reanalyze(self, batch_size: int) -> Tuple[Any]: """ Overview: first sample orig_data through ``_sample_orig_data()``, @@ -103,12 +103,11 @@ def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> current_batch: the inputs of batch Arguments: - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. 
- - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed) Returns: - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch """ # obtain the batch context from replay buffer - orig_data = self._sample_orig_reanalyze_data_uz(batch_size) + orig_data = self._sample_orig_reanalyze_batch_data(batch_size) game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data batch_size = len(batch_index_list) # obtain the context of reanalyzed policy targets diff --git a/lzero/mcts/buffer/game_buffer_rezero_uz.py b/lzero/mcts/buffer/game_buffer_rezero_uz.py deleted file mode 100644 index 003256ffd..000000000 --- a/lzero/mcts/buffer/game_buffer_rezero_uz.py +++ /dev/null @@ -1,580 +0,0 @@ -from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional - -import numpy as np -import torch -from ding.utils import BUFFER_REGISTRY - -from lzero.mcts.tree_search.mcts_ctree import UniZeroMCTSCtree as MCTSCtree -from lzero.mcts.utils import prepare_observation -from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform -from .game_buffer_muzero import MuZeroGameBuffer -from ding.utils import BUFFER_REGISTRY, EasyTimer - -if TYPE_CHECKING: - from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy - - -@BUFFER_REGISTRY.register('game_buffer_unizero_re') -class UniZeroReGameBuffer(MuZeroGameBuffer): - """ - Overview: - The specific game buffer for MuZero policy. - """ - - def __init__(self, cfg: dict): - super().__init__(cfg) - """ - Overview: - Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key - in the default configuration, the user-provided value will override the default configuration. Otherwise, - the default configuration will be used. - """ - default_config = self.default_config() - default_config.update(cfg) - self._cfg = default_config - assert self._cfg.env_type in ['not_board_games', 'board_games'] - self.replay_buffer_size = self._cfg.replay_buffer_size - self.batch_size = self._cfg.batch_size - self._alpha = self._cfg.priority_prob_alpha - self._beta = self._cfg.priority_prob_beta - - self.keep_ratio = 1 - self.model_update_interval = 10 - self.num_of_collected_episodes = 0 - self.base_idx = 0 - self.clear_time = 0 - - self.game_segment_buffer = [] - self.game_pos_priorities = [] - self.game_segment_game_pos_look_up = [] - # self.task_id = self._cfg.task_id - self.sample_type = self._cfg.sample_type # 'transition' or 'episode' - - def reanalyze_buffer( - self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] - ) -> List[Any]: - """ - Overview: - sample data from ``GameBuffer`` and prepare the current and target batch for training. - Arguments: - - batch_size (:obj:`int`): batch size. - - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy. - Returns: - - train_data (:obj:`List`): List of train data, including current_batch and target_batch. 
- """ - policy._target_model.to(self._cfg.device) - policy._target_model.eval() - self.policy = policy - - # obtain the current_batch and prepare target context - policy_re_context, current_batch = self._make_batch_for_reanalyze(batch_size, 1) - # target policy - self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, current_batch[1]) - - - def sample( - self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] - ) -> List[Any]: - """ - Overview: - sample data from ``GameBuffer`` and prepare the current and target batch for training. - Arguments: - - batch_size (:obj:`int`): batch size. - - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy. - Returns: - - train_data (:obj:`List`): List of train data, including current_batch and target_batch. - """ - policy._target_model.to(self._cfg.device) - policy._target_model.eval() - self.policy = policy - - # obtain the current_batch and prepare target context - reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( - batch_size, self._cfg.reanalyze_ratio - ) - - # current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] - - # target reward, target value - batch_rewards, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model, current_batch[1] # current_batch[1] is action_batch - ) - # target policy - batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, - current_batch[1]) - batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed( - policy_non_re_context, self._cfg.model.action_space_size - ) - - # fusion of batch_target_policies_re and batch_target_policies_non_re to batch_target_policies - if 0 < self._cfg.reanalyze_ratio < 1: - batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re]) - elif self._cfg.reanalyze_ratio == 1: - batch_target_policies = batch_target_policies_re - elif self._cfg.reanalyze_ratio == 0: - batch_target_policies = batch_target_policies_non_re - - target_batch = [batch_rewards, batch_target_values, batch_target_policies] - - # a batch contains the current_batch and the target_batch - train_data = [current_batch, target_batch] - return train_data - - def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: - """ - Overview: - first sample orig_data through ``_sample_orig_data()``, - then prepare the context of a batch: - reward_value_context: the context of reanalyzed value targets - policy_re_context: the context of reanalyzed policy targets - policy_non_re_context: the context of non-reanalyzed policy targets - current_batch: the inputs of batch - Arguments: - - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. 
- - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed) - Returns: - - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch - """ - # obtain the batch context from replay buffer - if self.sample_type == 'transition': - orig_data = self._sample_orig_data(batch_size) - elif self.sample_type == 'episode': - orig_data = self._sample_orig_data_episode(batch_size) - game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data - batch_size = len(batch_index_list) - obs_list, action_list, mask_list = [], [], [] - # prepare the inputs of a batch - for i in range(batch_size): - game = game_segment_list[i] - pos_in_game_segment = pos_in_game_segment_list[i] - - actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + - self._cfg.num_unroll_steps].tolist() - # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid - mask_tmp = [1. for i in range(len(actions_tmp))] - mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] - - # pad random action - actions_tmp += [ - np.random.randint(0, game.action_space_size) - for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) - ] - - # obtain the input observations - # pad if length of obs in game_segment is less than stack+num_unroll_steps - # e.g. stack+num_unroll_steps = 4+5 - obs_list.append( - game_segment_list[i].get_unroll_obs( - pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True - ) - ) - action_list.append(actions_tmp) - mask_list.append(mask_tmp) - - # formalize the input observations - obs_list = prepare_observation(obs_list, self._cfg.model.model_type) - - # formalize the inputs of a batch - current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] - for i in range(len(current_batch)): - current_batch[i] = np.asarray(current_batch[i]) - - total_transitions = self.get_num_of_transitions() - - # obtain the context of value targets - reward_value_context = self._prepare_reward_value_context( - batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions - ) - """ - only reanalyze recent reanalyze_ratio (e.g. 
50%) data - if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps - 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy - """ - reanalyze_num = max(int(batch_size * reanalyze_ratio), 1) if reanalyze_ratio > 0 else 0 - # print(f'reanalyze_ratio: {reanalyze_ratio}, reanalyze_num: {reanalyze_num}') - self.reanalyze_num = reanalyze_num - # reanalyzed policy - if reanalyze_num > 0: - # obtain the context of reanalyzed policy targets - policy_re_context = self._prepare_policy_reanalyzed_context( - batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num], - pos_in_game_segment_list[:reanalyze_num] - ) - else: - policy_re_context = None - - # non reanalyzed policy - if reanalyze_num < batch_size: - # obtain the context of non-reanalyzed policy targets - policy_non_re_context = self._prepare_policy_non_reanalyzed_context( - batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:], - pos_in_game_segment_list[reanalyze_num:] - ) - else: - policy_non_re_context = None - - context = reward_value_context, policy_re_context, policy_non_re_context, current_batch - return context - - def _make_batch_for_reanalyze(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: - """ - Overview: - first sample orig_data through ``_sample_orig_data()``, - then prepare the context of a batch: - reward_value_context: the context of reanalyzed value targets - policy_re_context: the context of reanalyzed policy targets - policy_non_re_context: the context of non-reanalyzed policy targets - current_batch: the inputs of batch - Arguments: - - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. - - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed) - Returns: - - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch - """ - # obtain the batch context from replay buffer - if self.sample_type == 'transition': - orig_data = self._sample_orig_reanalyze_data_uz(batch_size) - # elif self.sample_type == 'episode': # TODO - # orig_data = self._sample_orig_data_episode(batch_size) - game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data - batch_size = len(batch_index_list) - obs_list, action_list, mask_list = [], [], [] - # prepare the inputs of a batch - for i in range(batch_size): - game = game_segment_list[i] - pos_in_game_segment = pos_in_game_segment_list[i] - - actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + - self._cfg.num_unroll_steps].tolist() - # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid - mask_tmp = [1. for i in range(len(actions_tmp))] - mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] - - # pad random action - actions_tmp += [ - np.random.randint(0, game.action_space_size) - for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) - ] - - # obtain the input observations - # pad if length of obs in game_segment is less than stack+num_unroll_steps - # e.g. 
stack+num_unroll_steps = 4+5 - obs_list.append( - game_segment_list[i].get_unroll_obs( - pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True - ) - ) - action_list.append(actions_tmp) - mask_list.append(mask_tmp) - - # formalize the input observations - obs_list = prepare_observation(obs_list, self._cfg.model.model_type) - - # formalize the inputs of a batch - current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] - for i in range(len(current_batch)): - current_batch[i] = np.asarray(current_batch[i]) - - # reanalyzed policy - # obtain the context of reanalyzed policy targets - policy_re_context = self._prepare_policy_reanalyzed_context( - batch_index_list, game_segment_list, - pos_in_game_segment_list - ) - - context = policy_re_context, current_batch - self.reanalyze_num = batch_size - return context - - def _prepare_policy_reanalyzed_context( - self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] - ) -> List[Any]: - """ - Overview: - prepare the context of policies for calculating policy target in reanalyzing part. - Arguments: - - batch_index_list (:obj:'list'): start transition index in the replay buffer - - game_segment_list (:obj:'list'): list of game segments - - pos_in_game_segment_list (:obj:'list'): position of transition index in one game history - Returns: - - policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, indices, - child_visits, game_segment_lens, action_mask_segment, to_play_segment - """ - zero_obs = game_segment_list[0].zero_obs() - with torch.no_grad(): - # for policy - policy_obs_list = [] - policy_mask = [] - # 0 -> Invalid target policy for padding outside of game segments, - # 1 -> Previous target policy for game segments. 
- rewards, child_visits, game_segment_lens = [], [], [] - # for board games - action_mask_segment, to_play_segment = [], [] - for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list): - game_segment_len = len(game_segment) - game_segment_lens.append(game_segment_len) - rewards.append(game_segment.reward_segment) - # for board games - action_mask_segment.append(game_segment.action_mask_segment) - to_play_segment.append(game_segment.to_play_segment) - - child_visits.append(game_segment.child_visit_segment) - # prepare the corresponding observations - game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps) - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - - if current_index < game_segment_len: - policy_mask.append(1) - beg_index = current_index - state_index - end_index = beg_index + self._cfg.model.frame_stack_num - obs = game_obs[beg_index:end_index] - else: - policy_mask.append(0) - obs = zero_obs - policy_obs_list.append(obs) - - policy_re_context = [ - policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, - action_mask_segment, to_play_segment - ] - return policy_re_context - - def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, action_batch) -> np.ndarray: - """ - Overview: - prepare policy targets from the reanalyzed context of policies - Arguments: - - policy_re_context (:obj:`List`): List of policy context to reanalyzed - Returns: - - batch_target_policies_re - """ - if policy_re_context is None: - return [] - batch_target_policies_re = [] - - # for board games - policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \ - to_play_segment = policy_re_context # noqa - transition_batch_size = len(policy_obs_list) - game_segment_batch_size = len(pos_in_game_segment_list) - - to_play, action_mask = self._preprocess_to_play_and_action_mask( - game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - ) - - if self._cfg.model.continuous_action_space is True: - # when the action space of the environment is continuous, action_mask[:] is None. 
- action_mask = [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size) - ] - # NOTE: in continuous action space env: we set all legal_actions as -1 - legal_actions = [ - [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size) - ] - else: - legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - - # NOTE: TODO - model.world_model.reanalyze_phase = True - - with torch.no_grad(): - policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) - network_output = [] - m_obs = torch.from_numpy(policy_obs_list).to(self._cfg.device) - - # =============== NOTE: The key difference with MuZero ================= - # calculate the target value - # action_batch.shape (32, 10) - # m_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352 - m_output = model.initial_inference(m_obs, action_batch[:self.reanalyze_num]) # NOTE: :self.reanalyze_num - # ======================================================================= - - if not model.training: - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - - network_output.append(m_output) - - _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero') - reward_pool = reward_pool.squeeze().tolist() - policy_logits_pool = policy_logits_pool.tolist() - noises = [ - np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size - ).astype(np.float32).tolist() for _ in range(transition_batch_size) - ] - if self._cfg.mcts_ctree: - # cpp mcts_tree - roots = MCTSCtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play) - else: - # python mcts_tree - roots = MCTSPtree.roots(transition_batch_size, legal_actions) - roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play) - # do MCTS for a new policy with the recent target model - MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play) - - roots_legal_actions_list = legal_actions - roots_distributions = roots.get_distributions() - policy_index = 0 - for state_index, child_visit, game_index in zip(pos_in_game_segment_list, child_visits, batch_index_list): - target_policies = [] - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - distributions = roots_distributions[policy_index] - if policy_mask[policy_index] == 0: - # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0 - target_policies.append([0 for _ in range(self._cfg.model.action_space_size)]) - else: - # NOTE: It is very important to use the latest MCTS visit count distribution. 
- sum_visits = sum(distributions) - child_visit[current_index] = [visit_count / sum_visits for visit_count in distributions] - - if distributions is None: - # if at some obs, the legal_action is None, add the fake target_policy - target_policies.append( - list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size) - ) - else: - if self._cfg.env_type == 'not_board_games': - # for atari/classic_control/box2d environments that only have one player. - sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - target_policies.append(policy) - else: - # for board games that have two players and legal_actions is dy - policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)] - # to make sure target_policies have the same dimension - sum_visits = sum(distributions) - policy = [visit_count / sum_visits for visit_count in distributions] - for index, legal_action in enumerate(roots_legal_actions_list[policy_index]): - policy_tmp[legal_action] = policy[index] - target_policies.append(policy_tmp) - - policy_index += 1 - - batch_target_policies_re.append(target_policies) - - batch_target_policies_re = np.array(batch_target_policies_re) - - # NOTE: TODO - model.world_model.reanalyze_phase = False - - return batch_target_policies_re - - - # _compute_target_reward_value_v1 去掉了action_mask legal_action相关的处理 - def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[ - Any, Any]: - """ - Overview: - prepare reward and value targets from the context of rewards and values. - Arguments: - - reward_value_context (:obj:'list'): the reward value context - - model (:obj:'torch.tensor'):model of the target model - Returns: - - batch_value_prefixs (:obj:'np.ndarray): batch of value prefix - - batch_target_values (:obj:'np.ndarray): batch of value estimation - """ - value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \ - to_play_segment = reward_value_context # noqa - # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1) - transition_batch_size = len(value_obs_list) - - batch_target_values, batch_rewards = [], [] - with torch.no_grad(): - value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) - network_output = [] - m_obs = torch.from_numpy(value_obs_list).to(self._cfg.device) - - # =============== NOTE: The key difference with MuZero ================= - # calculate the target value - # m_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352 - m_output = model.initial_inference(m_obs, action_batch) - # ====================================================================== - - # if not model.training: - if self._cfg.device == 'cuda': - # if not in training, obtain the scalars of the value/reward - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - elif self._cfg.device == 'cpu': - # TODO - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) - - network_output.append(m_output) - - # use the predicted values - value_numpy = concat_output_value(network_output) - - # get last state value - if self._cfg.env_type == 'board_games' and 
to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - value_numpy = value_numpy.reshape(-1) * np.array( - [ - self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % - 2 == 0 else -self._cfg.discount_factor ** - td_steps_list[i] - for i in range(transition_batch_size) - ] - ) - else: - value_numpy = value_numpy.reshape(-1) * ( - np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list - ) - - value_numpy= value_numpy * np.array(value_mask) - value_list = value_numpy.tolist() - horizon_id, value_index = 0, 0 - - for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list, - pos_in_game_segment_list, - to_play_segment): - target_values = [] - target_rewards = [] - base_index = state_index - for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1): - bootstrap_index = current_index + td_steps_list[value_index] - for i, reward in enumerate(reward_list[current_index:bootstrap_index]): - if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # TODO(pu): for board_games, very important, to check - if to_play_list[base_index] == to_play_list[i]: - value_list[value_index] += reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += -reward * self._cfg.discount_factor ** i - else: - value_list[value_index] += reward * self._cfg.discount_factor ** i - horizon_id += 1 - - if current_index < game_segment_len_non_re: - target_values.append(value_list[value_index]) - target_rewards.append(reward_list[current_index]) - else: - target_values.append(np.array(0.)) - target_rewards.append(np.array(0.)) - value_index += 1 - - batch_rewards.append(target_rewards) - batch_target_values.append(target_values) - - batch_rewards = np.asarray(batch_rewards) - batch_target_values = np.asarray(batch_target_values) - - return batch_rewards, batch_target_values diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index 6fc985352..9787bb21c 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -63,7 +63,6 @@ def sample( """ policy._target_model.to(self._cfg.device) policy._target_model.eval() - self.policy = policy # obtain the current_batch and prepare target context reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch( @@ -74,7 +73,7 @@ def sample( # target reward, target value batch_rewards, batch_target_values = self._compute_target_reward_value( - reward_value_context, policy._target_model, current_batch[1] # current_batch[1] is action_batch + reward_value_context, policy._target_model, current_batch[1] # current_batch[1] is batch_action ) # target policy batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, @@ -193,6 +192,94 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: context = reward_value_context, policy_re_context, policy_non_re_context, current_batch return context + def reanalyze_buffer( + self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] + ) -> List[Any]: + """ + Overview: + sample data from ``GameBuffer`` and prepare the current and target batch for training. + Arguments: + - batch_size (:obj:`int`): batch size. + - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy. 
+ Returns: + - None: the reanalyzed policy targets are written back into the sampled game segments in place. + """ + policy._target_model.to(self._cfg.device) + policy._target_model.eval() + + # obtain the current_batch and prepare target context + policy_re_context, current_batch = self._make_batch_for_reanalyze(batch_size) + # target policy + self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model, current_batch[1]) + + def _make_batch_for_reanalyze(self, batch_size: int) -> Tuple[Any]: + """ + Overview: + first sample orig_data through ``_sample_orig_reanalyze_batch_data()``, + then prepare the context of a batch: + policy_re_context: the context of reanalyzed policy targets + current_batch: the inputs of batch + note that, unlike ``_make_batch``, no value context or non-reanalyzed + policy context is prepared in this method + Arguments: + - batch_size (:obj:`int`): the batch size of orig_data from replay buffer. + Returns: + - context (:obj:`Tuple`): policy_re_context, current_batch + """ + # obtain the batch context from replay buffer + if self.sample_type == 'transition': + orig_data = self._sample_orig_reanalyze_batch_data(batch_size) + # elif self.sample_type == 'episode': # TODO + # orig_data = self._sample_orig_data_episode(batch_size) + game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data + batch_size = len(batch_index_list) + obs_list, action_list, mask_list = [], [], [] + # prepare the inputs of a batch + for i in range(batch_size): + game = game_segment_list[i] + pos_in_game_segment = pos_in_game_segment_list[i] + + actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment + + self._cfg.num_unroll_steps].tolist() + # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid + mask_tmp = [1. for i in range(len(actions_tmp))] + mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))] + + # pad random action + actions_tmp += [ + np.random.randint(0, game.action_space_size) + for _ in range(self._cfg.num_unroll_steps - len(actions_tmp)) + ] + + # obtain the input observations + # pad if length of obs in game_segment is less than stack+num_unroll_steps + # e.g.
stack+num_unroll_steps = 4+5 + obs_list.append( + game_segment_list[i].get_unroll_obs( + pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True + ) + ) + action_list.append(actions_tmp) + mask_list.append(mask_tmp) + + # formalize the input observations + obs_list = prepare_observation(obs_list, self._cfg.model.model_type) + + # formalize the inputs of a batch + current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list] + for i in range(len(current_batch)): + current_batch[i] = np.asarray(current_batch[i]) + + # reanalyzed policy + # obtain the context of reanalyzed policy targets + policy_re_context = self._prepare_policy_reanalyzed_context( + batch_index_list, game_segment_list, + pos_in_game_segment_list + ) + + context = policy_re_context, current_batch + self.reanalyze_num = batch_size + return context def _prepare_policy_reanalyzed_context( self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] @@ -247,7 +334,7 @@ def _prepare_policy_reanalyzed_context( ] return policy_re_context - def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, action_batch) -> np.ndarray: + def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, batch_action) -> np.ndarray: """ Overview: prepare policy targets from the reanalyzed context of policies @@ -288,13 +375,13 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: with torch.no_grad(): policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type) network_output = [] - m_obs = torch.from_numpy(policy_obs_list).to(self._cfg.device) + batch_obs = torch.from_numpy(policy_obs_list).to(self._cfg.device) # =============== NOTE: The key difference with MuZero ================= # calculate the target value - # action_batch.shape (32, 10) - # m_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352 - m_output = model.initial_inference(m_obs, action_batch[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + # batch_action.shape (32, 10) + # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352 + m_output = model.initial_inference(batch_obs, batch_action[:self.reanalyze_num]) # NOTE: :self.reanalyze_num # ======================================================================= if not model.training: @@ -376,135 +463,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: return batch_target_policies_re - # 可以直接替换game_buffer_muzero中相应函数 - # _compute_target_policy_non_reanalyzed_v2 - # def _compute_target_policy_non_reanalyzed( - # self, policy_non_re_context: List[Any], policy_shape: Optional[int] - # ) -> np.ndarray: - # """ - # Overview: - # prepare policy targets from the non-reanalyzed context of policies - # Arguments: - # - policy_non_re_context (:obj:`List`): List containing: - # - pos_in_game_segment_list - # - child_visits - # - game_segment_lens - # - action_mask_segment - # - to_play_segment - # - policy_shape: self._cfg.model.action_space_size - # Returns: - # - batch_target_policies_non_re - # """ - # if policy_non_re_context is None: - # return np.array([]) - - # pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment = policy_non_re_context - # game_segment_batch_size = len(pos_in_game_segment_list) - # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1) - - # if self._cfg.action_type != 'fixed_action_space': - # to_play, 
action_mask = self._preprocess_to_play_and_action_mask( - # game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list - # ) - - # if self._cfg.model.continuous_action_space: - # action_mask = np.ones((transition_batch_size, self._cfg.model.action_space_size), dtype=np.int8) - # legal_actions = np.full((transition_batch_size, self._cfg.model.action_space_size), -1) - # else: - # if self._cfg.action_type != 'fixed_action_space': - # legal_actions = np.array([[i for i, x in enumerate(mask) if x == 1] for mask in action_mask]) - - # batch_target_policies_non_re = np.zeros((game_segment_batch_size, self._cfg.num_unroll_steps + 1, policy_shape)) - - # for i, (game_segment_len, child_visit, state_index) in enumerate(zip(game_segment_lens, child_visits, pos_in_game_segment_list)): - # valid_steps = min(game_segment_len - state_index, self._cfg.num_unroll_steps + 1) - - # for j in range(valid_steps): - # current_index = state_index + j - # distributions = child_visit[current_index] - - # if self._cfg.action_type == 'fixed_action_space': - # batch_target_policies_non_re[i, j] = distributions - # else: - # policy_tmp = np.zeros(policy_shape) - # policy_tmp[legal_actions[i * (self._cfg.num_unroll_steps + 1) + j]] = distributions - # batch_target_policies_non_re[i, j] = policy_tmp - - # return batch_target_policies_non_re - - # _compute_target_reward_value_v2 去掉了action_mask legal_action相关的处理, 并额外进行了优化 - # def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[np.ndarray, np.ndarray]: - # value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, to_play_segment = reward_value_context - # transition_batch_size = len(value_obs_list) - - # with torch.no_grad(): - # value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) - # m_obs = torch.from_numpy(value_obs_list).to(self._cfg.device) - - # if self._cfg.use_augmentation: - # # NOTE: TODO - # m_obs = self.policy.image_transforms.transform(m_obs) - - # m_output = model.initial_inference(m_obs, action_batch) - - # if self._cfg.device == 'cuda': - # [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - # [ - # m_output.latent_state, - # inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - # m_output.policy_logits - # ] - # ) - # elif self._cfg.device == 'cpu': - # [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - # [ - # m_output.latent_state, - # inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - # m_output.policy_logits - # ] - # ) - - # value_numpy = concat_output_value([m_output]).reshape(-1) - - # if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # value_numpy = value_numpy * np.array( - # [ - # self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) % 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i] - # for i in range(transition_batch_size) - # ] - # ) - # else: - # value_numpy = value_numpy * (np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list) - - # value_numpy = value_numpy * np.array(value_mask) - - # batch_rewards = np.zeros((len(game_segment_lens), self._cfg.num_unroll_steps + 1)) - # batch_target_values = np.zeros((len(game_segment_lens), self._cfg.num_unroll_steps + 1)) - - # value_index = 0 - # for i, (game_segment_len_non_re, reward_list, state_index, to_play_list) 
in enumerate(zip(game_segment_lens, rewards_list, pos_in_game_segment_list, to_play_segment)): - # base_index = state_index - # for j in range(state_index, min(state_index + self._cfg.num_unroll_steps + 1, game_segment_len_non_re + 1)): - # bootstrap_index = min(j + td_steps_list[value_index], len(reward_list)) - # discount_factors = self._cfg.discount_factor ** np.arange(bootstrap_index - j) - - # if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]: - # rewards = np.where(to_play_list[base_index] == to_play_list[j:bootstrap_index], reward_list[j:bootstrap_index], -reward_list[j:bootstrap_index]) - # else: - # rewards = reward_list[j:bootstrap_index] - - # value_numpy[value_index] += np.sum(rewards * discount_factors[:len(rewards)]) - - # if j < game_segment_len_non_re: - # batch_target_values[i, j - state_index] = value_numpy[value_index] - # batch_rewards[i, j - state_index] = reward_list[j] - - # value_index += 1 - - # return batch_rewards, batch_target_values - - # _compute_target_reward_value_v1 去掉了action_mask legal_action相关的处理 - def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[ + def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, batch_action) -> Tuple[ Any, Any]: """ Overview: @@ -525,16 +484,15 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A with torch.no_grad(): value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type) network_output = [] - m_obs = torch.from_numpy(value_obs_list).to(self._cfg.device) + batch_obs = torch.from_numpy(value_obs_list).to(self._cfg.device) # =============== NOTE: The key difference with MuZero ================= # calculate the target value - # m_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352 - m_output = model.initial_inference(m_obs, action_batch) + # batch_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352 + m_output = model.initial_inference(batch_obs, batch_action) # ====================================================================== - # if not model.training: - if self._cfg.device == 'cuda': + if not model.training: # if not in training, obtain the scalars of the value/reward [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( [ @@ -543,15 +501,6 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A m_output.policy_logits ] ) - elif self._cfg.device == 'cpu': - # TODO - [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy( - [ - m_output.latent_state, - inverse_scalar_transform(m_output.value, self._cfg.model.support_scale), - m_output.policy_logits - ] - ) network_output.append(m_output)
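
Reviewer note on `_sample_orig_reanalyze_batch_data` (game_buffer.py): start indices are drawn only from the first `reanalyze_partition` fraction of the buffered unroll segments, with exponentially decaying weights whose decay rate shrinks as the buffer grows, so the earliest (most outdated) targets are refreshed most often. A minimal standalone sketch of that weighting follows; the function name, helper arguments, and example numbers are illustrative assumptions, not part of the patch.

import numpy as np

def sample_reanalyze_start_indices(num_transitions: int,
                                   num_unroll_steps: int,
                                   batch_size: int,
                                   reanalyze_partition: float = 0.75,
                                   base_decay_rate: float = 5.0,
                                   rng=None) -> np.ndarray:
    # One candidate start index per unroll segment in the buffer.
    train_sample_num = num_transitions // num_unroll_steps
    # Only the first `reanalyze_partition` fraction of segments is eligible.
    valid_sample_num = int(train_sample_num * reanalyze_partition)
    # The decay rate shrinks as the number of eligible samples grows.
    decay_rate = base_decay_rate / valid_sample_num
    # Exponentially decaying weights over the eligible segments (index 0 gets the largest weight).
    weights = np.exp(-decay_rate * np.arange(valid_sample_num))
    probabilities = weights / np.sum(weights)
    rng = rng or np.random.default_rng()
    # Sampling without replacement requires batch_size <= valid_sample_num.
    return rng.choice(valid_sample_num, size=batch_size, replace=False, p=probabilities)

# Example: a buffer of 10000 transitions with unroll length 10, reanalyzing 64 segment starts.
print(sample_reanalyze_start_indices(10000, 10, 64))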
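
Reviewer note on `_compute_target_reward_value` (shared by the deleted `UniZeroReGameBuffer` and the retained `UniZeroGameBuffer`): in the single-player case the value target is an n-step discounted return, i.e. the real rewards up to `td_steps` ahead plus the target model's value at the bootstrap position, discounted accordingly. Below is a simplified per-position sketch assuming no episode-end truncation (the buffer handles truncation via `value_mask`); the function and argument names are illustrative only.

def n_step_value_target(rewards, bootstrap_value, t, td_steps, discount):
    # Simplified single-player value target: discounted sum of the next
    # `td_steps` real rewards plus the discounted value predicted by the
    # target model at position t + td_steps.
    assert t + td_steps <= len(rewards), "sketch assumes no episode-end truncation"
    n_step_return = sum(discount ** i * r for i, r in enumerate(rewards[t:t + td_steps]))
    return n_step_return + discount ** td_steps * bootstrap_value

# Example: constant reward 1.0, bootstrap value 0.5, 5-step target at t = 0.
print(n_step_value_target([1.0] * 20, 0.5, 0, 5, 0.997))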
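
Reviewer note on `_compute_target_policy_reanalyzed`: after the fresh MCTS search with the target model, the root visit counts are normalized into a policy target for each unrolled position; padded positions outside the game segment receive an all-zero target so their cross-entropy loss vanishes, and for board games the normalized counts are scattered back onto the full action space through the legal-action indices. A small sketch of that conversion follows; the names are illustrative, not the buffer's API.

def visit_counts_to_policy_target(visit_counts, legal_actions, action_space_size, valid):
    # Padded position outside the game segment: an all-zero target makes the
    # corresponding cross-entropy loss zero.
    if not valid:
        return [0.0] * action_space_size
    total = sum(visit_counts)
    policy = [0.0] * action_space_size
    # Scatter the normalized visit counts onto the full (fixed-size) action space.
    for count, action in zip(visit_counts, legal_actions):
        policy[action] = count / total
    return policy

# Example: 4 legal actions out of a 6-dimensional action space.
print(visit_counts_to_policy_target([10, 30, 50, 10], [0, 2, 3, 5], 6, valid=True))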