From dd2c95c549f568a8407ee05e4b06fe8c519bc9ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com> Date: Fri, 5 Jul 2024 16:32:14 +0800 Subject: [PATCH 01/13] feature(pu): add UniZero multitask related pipeline --- lzero/entry/__init__.py | 1 + lzero/entry/train_unizero.py | 1 - lzero/entry/train_unizero_multitask.py | 215 +++ lzero/mcts/buffer/game_buffer.py | 3 +- lzero/mcts/buffer/game_buffer_unizero.py | 12 +- lzero/mcts/tree_search/mcts_ctree.py | 5 +- lzero/model/unizero_model.py | 27 +- lzero/model/unizero_model_multitask.py | 233 ++++ lzero/model/unizero_world_models/tokenizer.py | 19 +- .../world_model_multitask.py | 1222 +++++++++++++++++ lzero/policy/unizero.py | 11 +- lzero/policy/unizero_multitask.py | 1103 +++++++++++++++ lzero/worker/muzero_collector.py | 16 +- lzero/worker/muzero_evaluator.py | 16 +- .../config/atari_unizero_multitask_config.py | 144 ++ zoo/atari/envs/atari_lightzero_env.py | 2 + zoo/atari/envs/atari_wrappers.py | 4 +- 17 files changed, 2996 insertions(+), 38 deletions(-) create mode 100644 lzero/entry/train_unizero_multitask.py create mode 100644 lzero/model/unizero_model_multitask.py create mode 100644 lzero/model/unizero_world_models/world_model_multitask.py create mode 100644 lzero/policy/unizero_multitask.py create mode 100644 zoo/atari/config/atari_unizero_multitask_config.py diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py index 200b91441..72aad548f 100644 --- a/lzero/entry/__init__.py +++ b/lzero/entry/__init__.py @@ -8,3 +8,4 @@ from .train_muzero_with_reward_model import train_muzero_with_reward_model from .train_rezero import train_rezero from .train_unizero import train_unizero +from .train_unizero_multitask import train_unizero_multitask \ No newline at end of file diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py index 969b1f947..13ca50c3e 100644 --- a/lzero/entry/train_unizero.py +++ b/lzero/entry/train_unizero.py @@ -172,7 +172,6 @@ def train_unizero( # Clear caches and precompute positional embedding matrices policy.recompute_pos_emb_diff_and_clear_cache() # TODO - train_data.append({'train_which_component': 'transformer'}) log_vars = learner.train(train_data, collector.envstep) if cfg.policy.use_priority: diff --git a/lzero/entry/train_unizero_multitask.py b/lzero/entry/train_unizero_multitask.py new file mode 100644 index 000000000..0ba9d2233 --- /dev/null +++ b/lzero/entry/train_unizero_multitask.py @@ -0,0 +1,215 @@ +import logging +import os +from functools import partial +from typing import Tuple, Optional, List + +import torch +from ding.config import compile_config +from ding.envs import create_env_manager, get_vec_env_setting +from ding.policy import create_policy +from ding.rl_utils import get_epsilon_greedy_fn +from ding.utils import set_pkg_seed, get_rank +from ding.worker import BaseLearner +from tensorboardX import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.worker import MuZeroCollector as Collector, MuZeroEvaluator as Evaluator +from lzero.mcts import UniZeroGameBuffer as GameBuffer + + +def train_unizero_multitask( + input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The train entry for UniZero, proposed in our paper UniZero: Generalized and Efficient 
Planning with Scalable Latent World Models. + UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms, + particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + Arguments: + - input_cfg_list (List[Tuple[int, Tuple[dict, dict]]]): List of configurations for different tasks. + - seed (int): Random seed. + - model (Optional[torch.nn.Module]): Instance of torch.nn.Module. + - model_path (Optional[str]): The pretrained model path, which should point to the ckpt file of the pretrained model. + - max_train_iter (Optional[int]): Maximum policy update iterations in training. + - max_env_step (Optional[int]): Maximum collected environment interaction steps. + Returns: + - policy (Policy): Converged policy. + """ + cfgs = [] + game_buffers = [] + collector_envs = [] + evaluator_envs = [] + collectors = [] + evaluators = [] + + task_id, [cfg, create_cfg] = input_cfg_list[0] + + # Ensure the specified policy type is supported + assert create_cfg.policy.type in ['unizero_multitask'], "train_unizero entry now only supports 'unizero'" + + # Set device based on CUDA availability + cfg.policy.device = cfg.policy.model.world_model.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # Compile the configuration + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + # Create shared policy for all tasks + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load pretrained model if specified + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # Create SummaryWriter for TensorBoard logging + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + # Create shared learner for all tasks + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # TODO task_id = 0: + policy_config = cfg.policy + batch_size = policy_config.batch_size[0] + + for task_id, input_cfg in input_cfg_list: + if task_id > 0: + # Get the configuration for each task + cfg, create_cfg = input_cfg + cfg.policy.device = 'cuda' if cfg.policy.cuda and torch.cuda.is_available() else 'cpu' + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + policy_config = cfg.policy + policy.collect_mode.get_attribute('cfg').n_episode = policy_config.n_episode + policy.eval_mode.get_attribute('cfg').n_episode = policy_config.n_episode + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + collector_env.seed(cfg.seed + task_id) + evaluator_env.seed(cfg.seed + task_id, dynamic_seed=False) + set_pkg_seed(cfg.seed + task_id, use_cuda=cfg.policy.cuda) + + # ===== NOTE: Create different game buffer, collector, evaluator for each task ==== + # TODO: share replay buffer for all tasks + replay_buffer = GameBuffer(policy_config) + collector = Collector( + env=collector_env, + 
policy=policy.collect_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + evaluator = Evaluator( + eval_freq=cfg.policy.eval_freq, + n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, + env=evaluator_env, + policy=policy.eval_mode, + tb_logger=tb_logger, + exp_name=cfg.exp_name, + policy_config=policy_config, + task_id=task_id + ) + + cfgs.append(cfg) + replay_buffer.batch_size = cfg.policy.batch_size[task_id] + game_buffers.append(replay_buffer) + collector_envs.append(collector_env) + evaluator_envs.append(evaluator_env) + collectors.append(collector) + evaluators.append(evaluator) + + learner.call_hook('before_run') + + while True: + # Precompute positional embedding matrices for collect/eval (not training) + policy._collect_model.world_model.precompute_pos_emb_diff_kv() + policy._target_model.world_model.precompute_pos_emb_diff_kv() + + # Collect data for each task + for task_id, (cfg, collector, evaluator, replay_buffer) in enumerate( + zip(cfgs, collectors, evaluators, game_buffers)): + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # Default epsilon value + } + + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + if evaluator.should_eval(learner.train_iter): + print('=' * 20) + print(f'evaluate task_id: {task_id}...') + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + print('=' * 20) + print(f'collect task_id: {task_id}...') + + # Reset initial data before each collection + collector._policy.reset(reset_init_data=True) + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # Determine updates per collection + update_per_collect = cfg.policy.update_per_collect + if update_per_collect is None: + collected_transitions_num = sum(len(game_segment) for game_segment in new_data[0]) + update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio) + + # Update replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + not_enough_data = any(replay_buffer.get_num_of_transitions() < batch_size for replay_buffer in game_buffers) + + # Learn policy from collected data. + if not not_enough_data: + # Learner will train ``update_per_collect`` times in one iteration. 
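+ # In each of these iterations, one mini-batch is sampled from every task's replay buffer,
+ # tagged with its task_id, and the list of per-task batches is passed to the shared learner,
+ # so the shared policy is updated on all tasks together.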
+ for i in range(update_per_collect): + train_data_multi_task = [] + envstep_multi_task = 0 + for task_id, (cfg, collector, replay_buffer) in enumerate(zip(cfgs, collectors, game_buffers)): + envstep_multi_task += collector.envstep + if replay_buffer.get_num_of_transitions() > batch_size: + batch_size = cfg.policy.batch_size[task_id] + train_data = replay_buffer.sample(batch_size, policy) + if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0: + policy.recompute_pos_emb_diff_and_clear_cache() + # Append task_id to train_data + train_data.append(task_id) + train_data_multi_task.append(train_data) + else: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}' + ) + break + + if train_data_multi_task: + log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + if all(collector.envstep >= max_env_step for collector in collectors) or learner.train_iter >= max_train_iter: + break + + learner.call_hook('after_run') + return policy \ No newline at end of file diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index f9fbde1f8..ec1babb6b 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -463,7 +463,8 @@ def remove_oldest_data_to_fit(self) -> None: Overview: remove some oldest data if the replay buffer is full. """ - assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" + if isinstance(self._cfg.batch_size, int): + assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size" nums_of_game_segments = self.get_num_of_game_segments() total_transition = self.get_num_of_transitions() if total_transition > self.replay_buffer_size: diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index f7b6c4c9a..2803b0213 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -47,9 +47,15 @@ def __init__(self, cfg: dict): self.game_segment_buffer = [] self.game_pos_priorities = [] self.game_segment_game_pos_look_up = [] - # self.task_id = self._cfg.task_id self.sample_type = self._cfg.sample_type # 'transition' or 'episode' + if hasattr(self._cfg, 'task_id'): + self.task_id = self._cfg.task_id + print(f"Task ID is set to {self.task_id}.") + else: + self.task_id = None + print("No task_id found in configuration. 
Task ID is set to None.") + def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -289,7 +295,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: # calculate the target value # action_batch.shape (32, 10) # m_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352 - m_output = model.initial_inference(m_obs, action_batch[:self.reanalyze_num]) # NOTE: :self.reanalyze_num + m_output = model.initial_inference(m_obs, action_batch[:self.reanalyze_num], task_id=self.task_id) # NOTE: :self.reanalyze_num # ======================================================================= if not model.training: @@ -410,7 +416,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A # =============== NOTE: The key difference with MuZero ================= # calculate the target value # m_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352 - m_output = model.initial_inference(m_obs, action_batch) + m_output = model.initial_inference(m_obs, action_batch, task_id=self.task_id) # ====================================================================== if not model.training: diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index eb7616d19..f19410f3e 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -74,7 +74,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m # @profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, - List[Any]] + List[Any]], task_id=None ) -> None: """ Overview: @@ -144,7 +144,7 @@ def search( At the end of the simulation, the statistics along the trajectory are updated. 
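+ The optional ``task_id`` is forwarded to ``model.recurrent_inference`` so that a multi-task
+ world model can condition the simulation on the corresponding task (``None`` in the single-task case).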
""" # for UniZero - network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path) + network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state) network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits) @@ -169,6 +169,7 @@ def search( min_max_stats_lst, results, virtual_to_play_batch ) + class MuZeroMCTSCtree(object): """ Overview: diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index df50088ce..fad454a1a 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -88,16 +88,19 @@ def __init__( print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') print('==' * 20) elif world_model_cfg.obs_type == 'image': - self.representation_network = RepresentationNetworkUniZero( - observation_shape, - num_res_blocks, - num_channels, - self.downsample, - activation=self.activation, - norm_type=norm_type, - embedding_dim=world_model_cfg.embed_dim, - group_size=world_model_cfg.group_size, - ) + self.representation_network = nn.ModuleList() + # for task_id in range(self.task_num): # N independent encoder + for task_id in range(1): # one share encoder + self.representation_network.append(RepresentationNetworkUniZero( + observation_shape, + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=world_model_cfg.embed_dim, + group_size=world_model_cfg.group_size, + )) # TODO: we should change the output_shape to the real observation shape self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64)) @@ -153,7 +156,7 @@ def __init__( print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') print('==' * 20) - def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None) -> MZNetworkOutput: + def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput: """ Overview: Initial inference of UniZero model, which is the first step of the UniZero model. 
@@ -190,7 +193,7 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_ ) def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0, - latent_state_index_in_search_path=[]) -> MZNetworkOutput: + latent_state_index_in_search_path=[], task_id=None) -> MZNetworkOutput: """ Overview: Recurrent inference of UniZero model.To perform the recurrent inference, we concurrently predict the latent dynamics (reward/next_latent_state) diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py new file mode 100644 index 000000000..257bcb56c --- /dev/null +++ b/lzero/model/unizero_model_multitask.py @@ -0,0 +1,233 @@ +from typing import Optional + +import torch +import torch.nn as nn +from ding.utils import MODEL_REGISTRY, SequenceType +from easydict import EasyDict + +from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook +from .unizero_world_models.tokenizer import Tokenizer +from .unizero_world_models.world_model_multitask import WorldModelMT + + +# use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. +@MODEL_REGISTRY.register('UniZeroMTModel') +class UniZeroMTModel(nn.Module): + + def __init__( + self, + observation_shape: SequenceType = (4, 64, 64), + action_space_size: int = 6, + num_res_blocks: int = 1, + num_channels: int = 64, + activation: nn.Module = nn.GELU(approximate='tanh'), + downsample: bool = True, + norm_type: Optional[str] = 'BN', + world_model_cfg: EasyDict = None, + task_num: int = 1, + *args, + **kwargs + ): + """ + Overview: + The definition of data procession in the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), including two main parts: + - initial_inference, which is used to predict the value, policy, and latent state based on the current observation. + - recurrent_inference, which is used to predict the value, policy, reward, and next latent state based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + Arguments: + - observation_shape (:obj:`SequenceType`): Observation space shape, e.g. [C, W, H]=[3, 64, 64] for Atari. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in UniZero model. + - num_channels (:obj:`int`): The channels of hidden states in representation network. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - world_model_cfg (:obj:`EasyDict`): The configuration of the world model, including the following keys: + - obs_type (:obj:`str`): The type of observation, which can be 'image', 'vector', or 'image_memory'. 
+ - embed_dim (:obj:`int`): The dimension of the embedding. + - group_size (:obj:`int`): The group size of the transformer. + - max_blocks (:obj:`int`): The maximum number of blocks in the transformer. + - max_tokens (:obj:`int`): The maximum number of tokens in the transformer. + - context_length (:obj:`int`): The context length of the transformer. + - device (:obj:`str`): The device of the model, which can be 'cuda' or 'cpu'. + - action_space_size (:obj:`int`): The shape of the action. + - num_layers (:obj:`int`): The number of layers in the transformer. + - num_heads (:obj:`int`): The number of heads in the transformer. + - policy_entropy_weight (:obj:`float`): The weight of the policy entropy. + - analysis_sim_norm (:obj:`bool`): Whether to analyze the similarity of the norm. + """ + super(UniZeroMTModel, self).__init__() + self.action_space_size = action_space_size + + # for multi-task + self.action_space_size = 18 + self.task_num = task_num + + self.activation = activation + self.downsample = downsample + world_model_cfg.norm_type = norm_type + assert world_model_cfg.max_tokens == 2 * world_model_cfg.max_blocks, 'max_tokens should be 2 * max_blocks, because each timestep has 2 tokens: obs and action' + + if world_model_cfg.obs_type == 'vector': + self.representation_network = RepresentationNetworkMLP( + observation_shape, + hidden_channels=world_model_cfg.embed_dim, + layer_num=2, + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + # TODO: only for MemoryEnv now + self.decoder_network = VectorDecoderForMemoryEnv(embedding_dim=world_model_cfg.embed_dim, output_shape=25) + self.tokenizer = Tokenizer(encoder=self.representation_network, + decoder_network=self.decoder_network, with_lpips=False) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) + elif world_model_cfg.obs_type == 'image': + self.representation_network = nn.ModuleList() + # for task_id in range(self.task_num): # N independent encoder + for task_id in range(1): # one share encoder + self.representation_network.append(RepresentationNetworkUniZero( + observation_shape, + num_res_blocks, + num_channels, + self.downsample, + activation=self.activation, + norm_type=norm_type, + embedding_dim=world_model_cfg.embed_dim, + group_size=world_model_cfg.group_size, + )) + # TODO: we should change the output_shape to the real observation shape + self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64)) + + # ====== for analysis ====== + if world_model_cfg.analysis_sim_norm: + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + self.tokenizer = Tokenizer(encoder=self.representation_network, + decoder_network=self.decoder_network, with_lpips=True,) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in 
self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') + + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') + print('==' * 20) + elif world_model_cfg.obs_type == 'image_memory': + self.representation_network = LatentEncoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[16, 32, 64], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + group_size=world_model_cfg.group_size, + ) + self.decoder_network = LatentDecoderForMemoryEnv( + image_shape=(3, 5, 5), + embedding_size=world_model_cfg.embed_dim, + channels=[64, 32, 16], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation=self.activation, + ) + + if world_model_cfg.analysis_sim_norm: + # ====== for analysis ====== + self.encoder_hook = FeatureAndGradientHook() + self.encoder_hook.setup_hooks(self.representation_network) + + self.tokenizer = Tokenizer(with_lpips=True, encoder=self.representation_network, + decoder_network=self.decoder_network) + self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)') + + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') + print('==' * 20) + + def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput: + """ + Overview: + Initial inference of UniZero model, which is the first step of the UniZero model. + To perform the initial inference, we first use the representation network to obtain the ``latent_state``. + Then we use the prediction network to predict ``value`` and ``policy_logits`` of the ``latent_state``. + Arguments: + - obs_batch (:obj:`torch.Tensor`): The 3D image observation data. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. \ + In initial inference, we set it to zero vector. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. 
+ - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + """ + batch_size = obs_batch.size(0) + obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch} + _, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, task_id=task_id) + latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value + policy_logits = policy_logits.squeeze(1) + value = value.squeeze(1) + + return MZNetworkOutput( + value, + [0. for _ in range(batch_size)], + policy_logits, + latent_state, + ) + + def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0, + latent_state_index_in_search_path=[], task_id=None) -> MZNetworkOutput: + """ + Overview: + Recurrent inference of UniZero model.To perform the recurrent inference, we concurrently predict the latent dynamics (reward/next_latent_state) + and decision-oriented quantities (value/policy) conditioned on the learned latent history in the world_model. + Arguments: + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - action (:obj:`torch.Tensor`): The predicted action to rollout. + Returns (MZNetworkOutput): + - value (:obj:`torch.Tensor`): The output value of input state to help policy improvement and evaluation. + - reward (:obj:`torch.Tensor`): The predicted reward of input state and selected action. + - policy_logits (:obj:`torch.Tensor`): The output logit to select discrete action. + - latent_state (:obj:`torch.Tensor`): The encoding latent state of input state. + - next_latent_state (:obj:`torch.Tensor`): The predicted next latent state. + Shapes: + - obs (:obj:`torch.Tensor`): :math:`(B, num_channel, obs_shape[1], obs_shape[2])`, where B is batch_size. + - action (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch_size. + - value (:obj:`torch.Tensor`): :math:`(B, value_support_size)`, where B is batch_size. + - reward (:obj:`torch.Tensor`): :math:`(B, reward_support_size)`, where B is batch_size. + - policy_logits (:obj:`torch.Tensor`): :math:`(B, action_dim)`, where B is batch_size. + - latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. + - next_latent_state (:obj:`torch.Tensor`): :math:`(B, H_, W_)`, where B is batch_size, H_ is the height of \ + latent state, W_ is the width of latent state. 
+ """ + _, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference( + state_action_history, simulation_index, latent_state_index_in_search_path, task_id=task_id) + next_latent_state, reward, policy_logits, value = logits_observations, logits_rewards, logits_policy, logits_value + policy_logits = policy_logits.squeeze(1) + value = value.squeeze(1) + reward = reward.squeeze(1) + return MZNetworkOutput(value, reward, policy_logits, next_latent_state) \ No newline at end of file diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index 0826690dd..d4fde5cf5 100644 --- a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -51,35 +51,40 @@ def __init__(self, encoder=None, decoder_network=None, with_lpips: bool = False) self.encoder = encoder self.decoder_network = decoder_network - def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor: + def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Tensor: """ Encode observations to embeddings. Arguments: - x (torch.Tensor): Input tensor of shape (B, ...). + - x (torch.Tensor): Input tensor of shape (B, ...). Returns: - torch.Tensor: Encoded embeddings of shape (B, 1, E). + - torch.Tensor: Encoded embeddings of shape (B, 1, E). """ shape = x.shape + if task_id is None: + task_id = 0 + else: + task_id = 0 # one encoder + # task_id = task_id # Process input tensor based on its dimensionality if len(shape) == 2: # Case when input is 2D (B, E) - obs_embeddings = self.encoder(x) + obs_embeddings = self.encoder[task_id](x) obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 3: # Case when input is 3D (B, T, E) x = x.contiguous().view(-1, shape[-1]) # Flatten the last two dimensions (B * T, E) - obs_embeddings = self.encoder(x) + obs_embeddings = self.encoder[task_id](x) obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 4: # Case when input is 4D (B, C, H, W) - obs_embeddings = self.encoder(x) + obs_embeddings = self.encoder[task_id](x) obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 5: # Case when input is 5D (B, T, C, H, W) x = x.contiguous().view(-1, *shape[-3:]) # Flatten the first two dimensions (B * T, C, H, W) - obs_embeddings = self.encoder(x) + obs_embeddings = self.encoder[task_id](x) obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') else: raise ValueError(f"Invalid input shape: {shape}") diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py new file mode 100644 index 000000000..f1d6733d1 --- /dev/null +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -0,0 +1,1222 @@ +import collections +import copy +import logging +from typing import Any, Tuple +from typing import Optional +from typing import Union, Dict + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from lzero.model.common import SimNorm +from lzero.model.utils import cal_dormant_ratio +from .slicer import Head +from .tokenizer import Tokenizer +from .transformer import Transformer, TransformerConfig +from .utils import LossWithIntermediateLosses, init_weights, to_device_for_kvcache +from .utils import WorldModelOutput, quantize_state + +logging.getLogger().setLevel(logging.DEBUG) + + +class WorldModelMT(nn.Module): + """ + Overview: + The 
WorldModel class is responsible for the scalable latent world model of UniZero (https://arxiv.org/abs/2406.10667), + which is used to predict the next latent state, rewards, policy, and value based on the current latent state and action. + The world model consists of three main components: + - a tokenizer, which encodes observations into embeddings, + - a transformer, which processes the input sequences, + - and heads, which generate the logits for observations, rewards, policy, and value. + """ + def __init__(self, config: TransformerConfig, tokenizer) -> None: + """ + Overview: + Initialize the WorldModel class. + Arguments: + - config (:obj:`TransformerConfig`): The configuration for the transformer. + - tokenizer (:obj:`Tokenizer`): The tokenizer. + """ + super().__init__() + self.tokenizer = tokenizer + self.config = config + self.transformer = Transformer(self.config) + + # TODO: multitask + self.task_num = config.task_num + self.task_emb = nn.Embedding(self.task_num, config.embed_dim, max_norm=1) # TODO + self.head_policy_multi_task = nn.ModuleList() + self.head_value_multi_task = nn.ModuleList() + self.head_rewards_multi_task = nn.ModuleList() + self.head_observations_multi_task = nn.ModuleList() + + # Move all modules to the specified device + print(f"self.config.device: {self.config.device}") + self.to(self.config.device) + + # Initialize configuration parameters + self._initialize_config_parameters() + + # Initialize patterns for block masks + self._initialize_patterns() + + # Position embedding + self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device) + self.precompute_pos_emb_diff_kv() + print(f"self.pos_emb.weight.device: {self.pos_emb.weight.device}") + + # Initialize action embedding table + self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) + print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") + + for task_id in range(self.task_num): # TODO + action_space_size = self.action_space_size # TODO:====================== + # action_space_size=18 # TODO:====================== + self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + self.head_policy_multi_task.append(self.head_policy) + + self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + self.head_value_multi_task.append(self.head_value) + + self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + self.head_rewards_multi_task.append(self.head_rewards) + + self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, + self.obs_per_embdding_dim, + self.sim_norm) # NOTE: we add a sim_norm to the head for observations + self.head_observations_multi_task.append(self.head_observations) + + # Head modules + # self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + # self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, + # self.sim_norm) # NOTE: we add a sim_norm to the head for observations + # self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + # self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + + # Apply weight initialization, the order is important + self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) + self._initialize_last_layer() + + # Cache structures + 
self._initialize_cache_structures() + + # Projection input dimension + self._initialize_projection_input_dim() + + # Hit count and query count statistics + self._initialize_statistics() + + # Initialize keys and values for transformer + self._initialize_transformer_keys_values() + + def _initialize_config_parameters(self) -> None: + """Initialize configuration parameters.""" + self.policy_entropy_weight = self.config.policy_entropy_weight + self.predict_latent_loss_type = self.config.predict_latent_loss_type + self.group_size = self.config.group_size + self.num_groups = self.config.embed_dim // self.group_size + self.obs_type = self.config.obs_type + self.embed_dim = self.config.embed_dim + self.num_heads = self.config.num_heads + self.gamma = self.config.gamma + self.context_length = self.config.context_length + self.dormant_threshold = self.config.dormant_threshold + self.analysis_dormant_ratio = self.config.analysis_dormant_ratio + self.num_observations_tokens = self.config.tokens_per_block - 1 + self.latent_recon_loss_weight = self.config.latent_recon_loss_weight + self.perceptual_loss_weight = self.config.perceptual_loss_weight + self.device = self.config.device + self.support_size = self.config.support_size + self.action_space_size = self.config.action_space_size + self.max_cache_size = self.config.max_cache_size + self.env_num = self.config.env_num + self.num_layers = self.config.num_layers + self.obs_per_embdding_dim = self.config.embed_dim + self.sim_norm = SimNorm(simnorm_dim=self.group_size) + + def _initialize_patterns(self) -> None: + """Initialize patterns for block masks.""" + self.all_but_last_latent_state_pattern = torch.ones(self.config.tokens_per_block) + self.all_but_last_latent_state_pattern[-2] = 0 + self.act_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.act_tokens_pattern[-1] = 1 + self.value_policy_tokens_pattern = torch.zeros(self.config.tokens_per_block) + self.value_policy_tokens_pattern[-2] = 1 + + def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None) -> Head: + """Create head modules for the transformer.""" + modules = [ + nn.Linear(self.config.embed_dim, self.config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def _initialize_last_layer(self) -> None: + """Initialize the last linear layer.""" + last_linear_layer_init_zero = True + if last_linear_layer_init_zero: + # TODO: multitask + if self.task_num == 1: + for head in [self.head_policy, self.head_value, self.head_rewards, self.head_observations]: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + elif self.task_num > 1: + for head in self.head_policy_multi_task + self.head_value_multi_task + self.head_rewards_multi_task + self.head_observations_multi_task: + for layer in reversed(head.head_module): + if isinstance(layer, nn.Linear): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + break + + def _initialize_cache_structures(self) -> None: + """Initialize cache structures for past keys and values.""" + self.past_kv_cache_recurrent_infer = collections.OrderedDict() + self.past_kv_cache_init_infer = collections.OrderedDict() + self.past_kv_cache_init_infer_envs = 
[collections.OrderedDict() for _ in range(self.env_num)] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + + def _initialize_projection_input_dim(self) -> None: + """Initialize the projection input dimension based on the number of observation tokens.""" + if self.num_observations_tokens == 16: + self.projection_input_dim = 128 + elif self.num_observations_tokens == 1: + self.projection_input_dim = self.obs_per_embdding_dim + + def _initialize_statistics(self) -> None: + """Initialize counters for hit count and query count statistics.""" + self.hit_count = 0 + self.total_query_count = 0 + self.length_largethan_maxminus5_context_cnt = 0 + self.length_largethan_maxminus7_context_cnt = 0 + self.root_hit_cnt = 0 + self.root_total_query_cnt = 0 + + def _initialize_transformer_keys_values(self) -> None: + """Initialize keys and values for the transformer.""" + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, + max_tokens=self.context_length) + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, + max_tokens=self.context_length) + + def precompute_pos_emb_diff_kv(self): + """ Precompute positional embedding differences for key and value. """ + if self.context_length <= 2: + # If context length is 2 or less, no context is present + return + + # Precompute positional embedding matrices for inference in collect/eval stages, not for training + self.positional_embedding_k = [ + self._get_positional_embedding(layer, 'key') + for layer in range(self.config.num_layers) + ] + self.positional_embedding_v = [ + self._get_positional_embedding(layer, 'value') + for layer in range(self.config.num_layers) + ] + + # Precompute all possible positional embedding differences + self.pos_emb_diff_k = [] + self.pos_emb_diff_v = [] + + for layer in range(self.config.num_layers): + layer_pos_emb_diff_k = {} + layer_pos_emb_diff_v = {} + + for start in [2]: + for end in [self.context_length - 1]: + original_pos_emb_k = self.positional_embedding_k[layer][:, :, start:end, :] + new_pos_emb_k = self.positional_embedding_k[layer][:, :, :end - start, :] + layer_pos_emb_diff_k[(start, end)] = new_pos_emb_k - original_pos_emb_k + + original_pos_emb_v = self.positional_embedding_v[layer][:, :, start:end, :] + new_pos_emb_v = self.positional_embedding_v[layer][:, :, :end - start, :] + layer_pos_emb_diff_v[(start, end)] = new_pos_emb_v - original_pos_emb_v + + self.pos_emb_diff_k.append(layer_pos_emb_diff_k) + self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + + def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: + """ + Helper function to get positional embedding for a given layer and attention type. + + Arguments: + - layer (:obj:`int`): Layer index. + - attn_type (:obj:`str`): Attention type, either 'key' or 'value'. + + Returns: + - torch.Tensor: The positional embedding tensor. 
+ """ + attn_func = getattr(self.transformer.blocks[layer].attn, attn_type) + if torch.cuda.is_available(): + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).to(self.device).detach() + else: + return attn_func(self.pos_emb.weight).view( + 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads + ).transpose(1, 2).detach() + + def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], + past_keys_values: Optional[torch.Tensor] = None, + kvcache_independent: bool = False, is_init_infer: bool = True, + valid_context_lengths: Optional[torch.Tensor] = None, task_id=0) -> WorldModelOutput: + """ + Forward pass for the model. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing observation embeddings or action tokens. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid context lengths. + Returns: + - WorldModelOutput: Model output containing logits for observations, rewards, policy, and value. + """ + # task_embeddings = self.task_emb(torch.tensor(task_id, device=self.device)) # NOTE: TODO + self.task_embeddings = torch.zeros(768, device=self.device) # NOTE:TODO no task_embeddings ============= + + # Determine previous steps based on key-value caching method + if kvcache_independent: + prev_steps = torch.tensor([0 if past_keys_values is None else past_kv.size for past_kv in past_keys_values], + device=self.device) + else: + prev_steps = 0 if past_keys_values is None else past_keys_values.size + + # Reset valid_context_lengths during initial inference + if is_init_infer: + valid_context_lengths = None + + # Process observation embeddings + if 'obs_embeddings' in obs_embeddings_or_act_tokens: + obs_embeddings = obs_embeddings_or_act_tokens['obs_embeddings'] + if len(obs_embeddings.shape) == 2: + obs_embeddings = obs_embeddings.unsqueeze(1) + num_steps = obs_embeddings.size(1) + sequences = self._add_position_embeddings(obs_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths) + # TODO: multitask + sequences = sequences + self.task_embeddings + + # Process action tokens + elif 'act_tokens' in obs_embeddings_or_act_tokens: + act_tokens = obs_embeddings_or_act_tokens['act_tokens'] + if len(act_tokens.shape) == 3: + act_tokens = act_tokens.squeeze(1) + num_steps = act_tokens.size(1) + act_embeddings = self.act_embedding_table(act_tokens) + sequences = self._add_position_embeddings(act_embeddings, prev_steps, num_steps, kvcache_independent, + is_init_infer, valid_context_lengths) + + # TODO: multitask + # TODO: 对于action_token不需要增加task_embeddings会造成歧义,反而干扰学习 + self.task_embeddings = torch.zeros(768, device=self.device) + sequences = sequences + self.task_embeddings + + # Process combined observation embeddings and action tokens + else: + sequences, num_steps = self._process_obs_act_combined(obs_embeddings_or_act_tokens, prev_steps) + + # Pass sequences through transformer + x = self._transformer_pass(sequences, past_keys_values, kvcache_independent, valid_context_lengths) + + # Generate logits + + # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 + # one head or soft_moe + logits_observations = self.head_observations(x, num_steps=num_steps, 
prev_steps=prev_steps) + logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + + # N head + # logits_observations = self.head_observations_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + # logits_rewards = self.head_rewards_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + # logits_policy = self.head_policy_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + # logits_value = self.head_value_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + + # logits_ends is None + return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + + def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, + valid_context_lengths): + """ + Add position embeddings to the input embeddings. + + Arguments: + - embeddings (:obj:`torch.Tensor`): Input embeddings. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + - num_steps (:obj:`int`): Number of steps. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - is_init_infer (:obj:`bool`): Initialize inference. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Embeddings with position information added. + """ + if kvcache_independent: + steps_indices = prev_steps + torch.arange(num_steps, device=embeddings.device) + position_embeddings = self.pos_emb(steps_indices).view(-1, num_steps, embeddings.shape[-1]) + return embeddings + position_embeddings + else: + if is_init_infer: + return embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)) + else: + valid_context_lengths = torch.tensor(self.keys_values_wm_size_list_current, device=self.device) + position_embeddings = self.pos_emb( + valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) + return embeddings + position_embeddings + + def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): + """ + Process combined observation embeddings and action tokens. + + Arguments: + - obs_embeddings_or_act_tokens (:obj:`dict`): Dictionary containing combined observation embeddings and action tokens. + - prev_steps (:obj:`torch.Tensor`): Previous steps. + Returns: + - torch.Tensor: Combined observation and action embeddings with position information added. 
+ """ + obs_embeddings, act_tokens = obs_embeddings_or_act_tokens['obs_embeddings_and_act_tokens'] + if len(obs_embeddings.shape) == 3: + obs_embeddings = obs_embeddings.view(act_tokens.shape[0], act_tokens.shape[1], self.num_observations_tokens, + -1) + + num_steps = int(obs_embeddings.size(1) * (obs_embeddings.size(2) + 1)) + act_embeddings = self.act_embedding_table(act_tokens) + + B, L, K, E = obs_embeddings.size() + obs_act_embeddings = torch.empty(B, L * (K + 1), E, device=self.device) + + for i in range(L): + # obs = obs_embeddings[:, i, :, :] + obs = obs_embeddings[:, i, :, :] + self.task_embeddings # Shape: (B, K, E) TODO: task_embeddings + act = act_embeddings[:, i, 0, :].unsqueeze(1) + obs_act = torch.cat([obs, act], dim=1) + obs_act_embeddings[:, i * (K + 1):(i + 1) * (K + 1), :] = obs_act + + return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + + def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths): + """ + Pass sequences through the transformer. + + Arguments: + - sequences (:obj:`torch.Tensor`): Input sequences. + - past_keys_values (:obj:`Optional[torch.Tensor]`): Previous keys and values for transformer. + - kvcache_independent (:obj:`bool`): Whether to use independent key-value caching. + - valid_context_lengths (:obj:`torch.Tensor`): Valid context lengths. + Returns: + - torch.Tensor: Transformer output. + """ + if kvcache_independent: + x = [self.transformer(sequences[k].unsqueeze(0), past_kv, + valid_context_lengths=valid_context_lengths[k].unsqueeze(0)) for k, past_kv in + enumerate(past_keys_values)] + return torch.cat(x, dim=0) + else: + return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + + @torch.no_grad() + def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor, task_id=0) -> torch.FloatTensor: + """ + Reset the model state based on initial observations and actions. + + Arguments: + - obs_act_dict (:obj:`torch.FloatTensor`): A dictionary containing 'obs', 'action', and 'current_obs'. + Returns: + - torch.FloatTensor: The outputs from the world model and the latent state. + """ + # Extract observations, actions, and current observations from the dictionary. + if isinstance(obs_act_dict, dict): + observations = obs_act_dict['obs'] + buffer_action = obs_act_dict['action'] + current_obs = obs_act_dict['current_obs'] + + # Encode observations to latent embeddings. 
+ obs_embeddings = self.tokenizer.encode_to_obs_embeddings(observations, task_id=task_id) + + if current_obs is not None: + # ================ Collect and Evaluation Phase ================ + # Encode current observations to latent embeddings + current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(current_obs, task_id=task_id) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + self.latent_state = current_obs_embeddings + outputs_wm = self.refresh_kvs_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action, current_obs_embeddings, task_id=task_id) + else: + # ================ calculate the target value in Train phase ================ + self.latent_state = obs_embeddings + outputs_wm = self.refresh_kvs_with_initial_latent_state_for_init_infer(obs_embeddings, buffer_action, None, task_id=task_id) + + return outputs_wm, self.latent_state + + @torch.no_grad() + def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, + buffer_action=None, + current_obs_embeddings=None, task_id=0) -> torch.FloatTensor: + """ + Refresh key-value pairs with the initial latent state for inference. + + Arguments: + - latent_state (:obj:`torch.LongTensor`): The latent state embeddings. + - buffer_action (optional): Actions taken. + - current_obs_embeddings (optional): Current observation embeddings. + Returns: + - torch.FloatTensor: The outputs from the world model. + """ + n, num_observations_tokens, _ = latent_state.shape + if n <= self.env_num: + # ================ Collect and Evaluation Phase ================ + if current_obs_embeddings is not None: + if max(buffer_action) == -1: + # First step in an episode + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, + max_tokens=self.context_length) + # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + else: + # Assume latest_state is the new latent_state, containing information from ready_env_num environments + ready_env_num = current_obs_embeddings.shape[0] + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + for i in range(ready_env_num): + # Retrieve latent state for a single environment + state_single_env = latent_state[i] + quantized_state = state_single_env.detach().cpu().numpy() + # Compute hash value using quantized state + cache_key = quantize_state(quantized_state) + # Retrieve cached value + matched_value = self.past_kv_cache_init_infer_envs[i].get(cache_key) + + self.root_total_query_cnt += 1 + if matched_value is not None: + # If a matching value is found, add it to the list + self.root_hit_cnt += 1 + # deepcopy is needed because forward modifies matched_value in place + self.keys_values_wm_list.append(copy.deepcopy(to_device_for_kvcache(matched_value, self.device))) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # Reset using zero values + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, max_tokens=self.context_length) + outputs_wm = self.forward({'obs_embeddings': state_single_env.unsqueeze(0)}, + past_keys_values=self.keys_values_wm_single_env, + is_init_infer=True, task_id=task_id) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + 
self.keys_values_wm_size_list.append(1) + + # Input self.keys_values_wm_list, output self.keys_values_wm + self.keys_values_wm_size_list_current = self.trim_and_pad_kv_cache(is_init_infer=True) + + buffer_action = buffer_action[:ready_env_num] + # if ready_env_num < self.env_num: + # print(f'init inference ready_env_num: {ready_env_num} < env_num: {self.env_num}') + act_tokens = torch.from_numpy(np.array(buffer_action)).to(latent_state.device).unsqueeze(-1) + outputs_wm = self.forward({'act_tokens': act_tokens}, past_keys_values=self.keys_values_wm, + is_init_infer=True, task_id=task_id) + + outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, + past_keys_values=self.keys_values_wm, is_init_infer=True, task_id=task_id) + + # Copy and store keys_values_wm for a single environment + self.update_cache_context(current_obs_embeddings, is_init_infer=True) + + # elif n > self.env_num and buffer_action is not None and current_obs_embeddings is None: + elif buffer_action is not None and current_obs_embeddings is None: + # ================ calculate the target value in Train phase ================ + # [192, 16, 64] -> [32, 6, 16, 64] + latent_state = latent_state.contiguous().view(buffer_action.shape[0], -1, num_observations_tokens, + self.obs_per_embdding_dim) # (BL, K) for unroll_step=1 + + latent_state = latent_state[:, :-1, :] + buffer_action = torch.from_numpy(buffer_action).to(latent_state.device) + act_tokens = rearrange(buffer_action, 'b l -> b l 1') + + # select the last timestep for each sample + # This will select the last column while keeping the dimensions unchanged, and the target policy/value in the final step itself is not used. + last_steps_act = act_tokens[:, -1:, :] + act_tokens = torch.cat((act_tokens, last_steps_act), dim=1) + + outputs_wm = self.forward({'obs_embeddings_and_act_tokens': (latent_state, act_tokens)}, task_id=task_id) + + # select the last timestep for each sample + last_steps_value = outputs_wm.logits_value[:, -1:, :] + outputs_wm.logits_value = torch.cat((outputs_wm.logits_value, last_steps_value), dim=1) + + last_steps_policy = outputs_wm.logits_policy[:, -1:, :] + outputs_wm.logits_policy = torch.cat((outputs_wm.logits_policy, last_steps_policy), dim=1) + + # Reshape your tensors + # outputs_wm.logits_value.shape (B, H, 101) = (B*H, 101) + outputs_wm.logits_value = rearrange(outputs_wm.logits_value, 'b t e -> (b t) e') + outputs_wm.logits_policy = rearrange(outputs_wm.logits_policy, 'b t e -> (b t) e') + + return outputs_wm + + @torch.no_grad() + def forward_initial_inference(self, obs_act_dict, task_id=0): + """ + Perform initial inference based on the given observation-action dictionary. + + Arguments: + - obs_act_dict (:obj:`dict`): Dictionary containing observations and actions. + Returns: + - tuple: A tuple containing output sequence, latent state, logits rewards, logits policy, and logits value. + """ + # UniZero has context in the root node + outputs_wm, latent_state = self.reset_from_initial_observations(obs_act_dict, task_id=task_id) + self.past_kv_cache_recurrent_infer.clear() + + return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, + outputs_wm.logits_policy, outputs_wm.logits_value) + + @torch.no_grad() + def forward_recurrent_inference(self, state_action_history, simulation_index=0, + latent_state_index_in_search_path=[], task_id=0): + """ + Perform recurrent inference based on the state-action history. 
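+ Two forward passes are performed per call: the latest action token is fed first to predict the
+ reward and the next latent (observation) token, and that predicted token is then fed to obtain
+ the policy and value logits for the expanded node.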
+ + Arguments: + - state_action_history (:obj:`list`): List containing tuples of state and action history. + - simulation_index (:obj:`int`, optional): Index of the current simulation. Defaults to 0. + - latent_state_index_in_search_path (:obj:`list`, optional): List containing indices of latent states in the search path. Defaults to []. + Returns: + - tuple: A tuple containing output sequence, updated latent state, reward, logits policy, and logits value. + """ + latest_state, action = state_action_history[-1] + ready_env_num = latest_state.shape[0] + + self.keys_values_wm_list = [] + self.keys_values_wm_size_list = [] + self.keys_values_wm_size_list = self.retrieve_or_generate_kvcache(latest_state, ready_env_num, simulation_index, task_id=task_id) + + latent_state_list = [] + token = action.reshape(-1, 1) + + # ======= Print statistics for debugging ============= + # min_size = min(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 5: + # self.length_largethan_maxminus5_context_cnt += len(self.keys_values_wm_size_list) + # if min_size >= self.config.max_tokens - 7: + # self.length_largethan_maxminus7_context_cnt += len(self.keys_values_wm_size_list) + # if self.total_query_count > 0 and self.total_query_count % 10000 == 0: + # self.hit_freq = self.hit_count / self.total_query_count + # print('total_query_count:', self.total_query_count) + # length_largethan_maxminus5_context_cnt_ratio = self.length_largethan_maxminus5_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus5_context:', self.length_largethan_maxminus5_context_cnt) + # print('recurrent largethan_maxminus5_context_ratio:', length_largethan_maxminus5_context_cnt_ratio) + # length_largethan_maxminus7_context_cnt_ratio = self.length_largethan_maxminus7_context_cnt / self.total_query_count + # print('recurrent largethan_maxminus7_context_ratio:', length_largethan_maxminus7_context_cnt_ratio) + # print('recurrent largethan_maxminus7_context:', self.length_largethan_maxminus7_context_cnt) + + # Trim and pad kv_cache + self.keys_values_wm_size_list = self.trim_and_pad_kv_cache(is_init_infer=False) + self.keys_values_wm_size_list_current = self.keys_values_wm_size_list + + for k in range(2): + # action_token obs_token + if k == 0: + obs_embeddings_or_act_tokens = {'act_tokens': token} + else: + obs_embeddings_or_act_tokens = {'obs_embeddings': token} + + # Perform forward pass + outputs_wm = self.forward( + obs_embeddings_or_act_tokens, + past_keys_values=self.keys_values_wm, + kvcache_independent=False, + is_init_infer=False, + task_id = task_id + ) + + self.keys_values_wm_size_list_current = [i + 1 for i in self.keys_values_wm_size_list_current] + + if k == 0: + reward = outputs_wm.logits_rewards # (B,) + + if k < self.num_observations_tokens: + token = outputs_wm.logits_observations + if len(token.shape) != 3: + token = token.unsqueeze(1) # (8,1024) -> (8,1,1024) + latent_state_list.append(token) + + del self.latent_state # Very important to minimize cuda memory usage + self.latent_state = torch.cat(latent_state_list, dim=1) # (B, K) + + self.update_cache_context( + self.latent_state, + is_init_infer=False, + simulation_index=simulation_index, + latent_state_index_in_search_path=latent_state_index_in_search_path + ) + + return (outputs_wm.output_sequence, self.latent_state, reward, outputs_wm.logits_policy, outputs_wm.logits_value) + + def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: + """ + Adjusts the key-value cache for each environment to ensure they all have the same 
size. + + In a multi-environment setting, the key-value cache (kv_cache) for each environment is stored separately. + During recurrent inference, the kv_cache sizes may vary across environments. This method pads each kv_cache + to match the largest size found among them, facilitating batch processing in the transformer forward pass. + + Arguments: + - is_init_infer (:obj:`bool`): Indicates if this is an initial inference. Default is True. + Returns: + - list: Updated sizes of the key-value caches. + """ + # Find the maximum size among all key-value caches + max_size = max(self.keys_values_wm_size_list) + + # Iterate over each layer of the transformer + for layer in range(self.num_layers): + kv_cache_k_list = [] + kv_cache_v_list = [] + + # Enumerate through each environment's key-value pairs + for idx, keys_values in enumerate(self.keys_values_wm_list): + k_cache = keys_values[layer]._k_cache._cache + v_cache = keys_values[layer]._v_cache._cache + + effective_size = self.keys_values_wm_size_list[idx] + pad_size = max_size - effective_size + + # If padding is required, trim the end and pad the beginning of the cache + if pad_size > 0: + k_cache_trimmed = k_cache[:, :, :-pad_size, :] + v_cache_trimmed = v_cache[:, :, :-pad_size, :] + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, pad_size, 0), "constant", 0) + else: + k_cache_padded = k_cache + v_cache_padded = v_cache + + kv_cache_k_list.append(k_cache_padded) + kv_cache_v_list.append(v_cache_padded) + + # Stack the caches along a new dimension and remove any extra dimensions + self.keys_values_wm._keys_values[layer]._k_cache._cache = torch.stack(kv_cache_k_list, dim=0).squeeze(1) + self.keys_values_wm._keys_values[layer]._v_cache._cache = torch.stack(kv_cache_v_list, dim=0).squeeze(1) + + # Update the cache size to the maximum size + self.keys_values_wm._keys_values[layer]._k_cache._size = max_size + self.keys_values_wm._keys_values[layer]._v_cache._size = max_size + + return self.keys_values_wm_size_list + + def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, + latent_state_index_in_search_path=[], valid_context_lengths=None): + """ + Update the cache context with the given latent state. + + Arguments: + - latent_state (:obj:`torch.Tensor`): The latent state tensor. + - is_init_infer (:obj:`bool`): Flag to indicate if this is the initial inference. + - simulation_index (:obj:`int`): Index of the simulation. + - latent_state_index_in_search_path (:obj:`list`): List of indices in the search path. + - valid_context_lengths (:obj:`list`): List of valid context lengths. + """ + if self.context_length <= 2: + # No context to update if the context length is less than or equal to 2. 
+ return + for i in range(latent_state.size(0)): + # ============ Iterate over each environment ============ + state_single_env = latent_state[i] + quantized_state = state_single_env.detach().cpu().numpy() + cache_key = quantize_state(quantized_state) + context_length = self.context_length + + if not is_init_infer: + # ============ Internal Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + current_max_context_length = max(self.keys_values_wm_size_list_current) + trim_size = current_max_context_length - self.keys_values_wm_size_list_current[i] + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + # cache shape [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + if trim_size > 0: + # Trim invalid leading zeros as per effective length + # Remove the first trim_size zero kv items + k_cache_trimmed = k_cache_current[:, trim_size:, :] + v_cache_trimmed = v_cache_current[:, trim_size:, :] + # If effective length < current_max_context_length, pad the end of cache with 'trim_size' zeros + k_cache_padded = F.pad(k_cache_trimmed, (0, 0, 0, trim_size), "constant", + 0) # Pad with 'trim_size' zeros at end of cache + v_cache_padded = F.pad(v_cache_trimmed, (0, 0, 0, trim_size), "constant", 0) + else: + k_cache_padded = k_cache_current + v_cache_padded = v_cache_current + + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = \ + self.keys_values_wm_size_list_current[i] + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = \ + self.keys_values_wm_size_list_current[i] + + # ============ NOTE: Very Important ============ + if self.keys_values_wm_single_env._keys_values[layer]._k_cache._size >= context_length - 1: + # Keep only the last self.context_length-3 timesteps of context + # For memory environments, training is for H steps, recurrent_inference might exceed H steps + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache + v_cache_current = self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + v_cache_trimmed = v_cache_current[:, :, 2:context_length - 1, :].squeeze(0) + + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). 
For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). + padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update single environment cache + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + else: + # ============ Root Node ============ + # Retrieve KV from global KV cache self.keys_values_wm to single environment KV cache self.keys_values_wm_single_env, ensuring correct positional encoding + + for layer in range(self.num_layers): + # ============ Apply trimming and padding to each layer of kv_cache ============ + + if self.keys_values_wm._keys_values[layer]._k_cache._size < context_length - 1: # Keep only the last self.context_length-1 timesteps of context + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = self.keys_values_wm._keys_values[layer]._k_cache._cache[i].unsqueeze(0) # Shape torch.Size([2, 100, 512]) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = self.keys_values_wm._keys_values[layer]._v_cache._cache[i].unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = self.keys_values_wm._keys_values[layer]._k_cache._size + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = self.keys_values_wm._keys_values[layer]._v_cache._size + else: + # Assuming cache dimension is [batch_size, num_heads, sequence_length, features] + k_cache_current = self.keys_values_wm._keys_values[layer]._k_cache._cache[i] + v_cache_current = self.keys_values_wm._keys_values[layer]._v_cache._cache[i] + + # Remove the first 2 steps, keep the last self.context_length-3 steps + k_cache_trimmed = k_cache_current[:, 2:context_length - 1, :] + v_cache_trimmed = v_cache_current[:, 2:context_length - 1, :] + + # Index pre-computed positional encoding differences + pos_emb_diff_k = self.pos_emb_diff_k[layer][(2, context_length - 1)] + pos_emb_diff_v = self.pos_emb_diff_v[layer][(2, context_length - 1)] + # ============ NOTE: Very Important ============ + # Apply positional encoding correction to k and v + k_cache_trimmed += pos_emb_diff_k.squeeze(0) + v_cache_trimmed += pos_emb_diff_v.squeeze(0) + + # Pad the last 3 steps along the third dimension with zeros + # F.pad parameters (0, 0, 0, 3) specify padding amounts for each dimension: (left, right, top, bottom). For 3D tensor, they correspond to (dim2 left, dim2 right, dim1 left, dim1 right). 
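+                    # A standalone sketch of the pad-tuple semantics for the 3D cache slice of shape
+                    # (num_heads, seq_len, features) handled below (illustrative only, not executed here):
+                    #   import torch
+                    #   import torch.nn.functional as F
+                    #   x = torch.zeros(8, 5, 64)
+                    #   F.pad(x, (0, 0, 0, 3)).shape  # torch.Size([8, 8, 64]): 3 zero steps appended along seq_len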
+ padding_size = (0, 0, 0, 3) + k_cache_padded = F.pad(k_cache_trimmed, padding_size, 'constant', 0) + v_cache_padded = F.pad(v_cache_trimmed, padding_size, 'constant', 0) + # Update cache of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._cache = k_cache_padded.unsqueeze(0) + self.keys_values_wm_single_env._keys_values[layer]._v_cache._cache = v_cache_padded.unsqueeze(0) + # Update size of self.keys_values_wm_single_env + self.keys_values_wm_single_env._keys_values[layer]._k_cache._size = context_length - 3 + self.keys_values_wm_single_env._keys_values[layer]._v_cache._size = context_length - 3 + + if is_init_infer: + # Store the latest key-value cache for initial inference + self.past_kv_cache_init_infer_envs[i][cache_key] = copy.deepcopy(to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) + else: + # Store the latest key-value cache for recurrent inference + self.past_kv_cache_recurrent_infer[cache_key] = copy.deepcopy(to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) + + def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, + simulation_index: int = 0, task_id=0) -> list: + """ + Retrieves or generates key-value caches for each environment based on the latent state. + + For each environment, this method either retrieves a matching cache from the predefined + caches if available, or generates a new cache if no match is found. The method updates + the internal lists with these caches and their sizes. + + Arguments: + - latent_state (:obj:`list`): List of latent states for each environment. + - ready_env_num (:obj:`int`): Number of environments ready for processing. + - simulation_index (:obj:`int`, optional): Index for simulation tracking. Default is 0. + Returns: + - list: Sizes of the key-value caches for each environment. 
+ """ + for i in range(ready_env_num): + self.total_query_count += 1 + state_single_env = latent_state[i] # Get the latent state for a single environment + cache_key = quantize_state(state_single_env) # Compute the hash value using the quantized state + + # Try to retrieve the cached value from past_kv_cache_init_infer_envs + matched_value = self.past_kv_cache_init_infer_envs[i].get(cache_key) + + # If not found, try to retrieve from past_kv_cache_recurrent_infer + if matched_value is None: + matched_value = self.past_kv_cache_recurrent_infer.get(cache_key) + + if matched_value is not None: + # If a matching cache is found, add it to the lists + self.hit_count += 1 + # Perform a deep copy because the transformer's forward pass might modify matched_value in-place + self.keys_values_wm_list.append(copy.deepcopy(to_device_for_kvcache(matched_value, self.device))) + self.keys_values_wm_size_list.append(matched_value.size) + else: + # If no matching cache is found, generate a new one using zero reset + self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values( + n=1, max_tokens=self.context_length + ) + self.forward( + {'obs_embeddings': torch.from_numpy(state_single_env).unsqueeze(0).to(self.device)}, + past_keys_values=self.keys_values_wm_single_env, is_init_infer=True, task_id=task_id + ) + self.keys_values_wm_list.append(self.keys_values_wm_single_env) + self.keys_values_wm_size_list.append(1) + + return self.keys_values_wm_size_list + + def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, task_id=0, **kwargs: Any) -> LossWithIntermediateLosses: + # Encode observations into latent state representations + obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + # ========= for visual analysis ========= + # Uncomment the lines below for visual analysis in Pong + # self.plot_latent_tsne_each_and_all_for_pong(obs_embeddings, suffix='pong_H10_H4_tsne') + # self.save_as_image_with_timestep(batch['observations'], suffix='pong_H10_H4_tsne') + # Uncomment the lines below for visual analysis in visual match + # self.plot_latent_tsne_each_and_all(obs_embeddings, suffix='visual_match_memlen1-60-15_tsne') + # self.save_as_image_with_timestep(batch['observations'], suffix='visual_match_memlen1-60-15_tsne') + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the encoder + shape = batch['observations'].shape # (..., C, H, W) + inputs = batch['observations'].contiguous().view(-1, *shape[-3:]) # (32,5,3,64,64) -> (160,3,64,64) + dormant_ratio_encoder = cal_dormant_ratio(self.tokenizer.encoder, inputs.detach(), + percentage=self.dormant_threshold) + self.past_kv_cache_init_infer.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_encoder = torch.tensor(0.) 
+ + # Calculate the L2 norm of the latent state roots + latent_state_l2_norms = torch.norm(obs_embeddings, p=2, dim=2).mean() + + if self.obs_type == 'image': + # Reconstruct observations from latent state representations + reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # original_images, reconstructed_images = batch['observations'], reconstructed_images + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + # perceptual_loss = self.tokenizer.perceptual_loss(batch['observations'].reshape(-1, 3, 64, 64), reconstructed_images) # NOTE: for stack=1 + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'vector': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), + # reconstructed_images) + latent_recon_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + elif self.obs_type == 'image_memory': + # Reconstruct observations from latent state representations + reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) + original_images, reconstructed_images = batch['observations'], reconstructed_images + + # ========== for visualization ========== + # Uncomment the lines below for visual analysis + # target_policy = batch['target_policy'] + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ========== for visualization ========== + + # Calculate reconstruction loss and perceptual loss + latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 3, 5, 5), + reconstructed_images) + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=batch['observations'].dtype) + + # Action tokens + act_tokens = rearrange(batch['actions'], 'b l -> b l 1') + + # Forward pass to obtain predictions for observations, rewards, and policies + outputs = self.forward({'obs_embeddings_and_act_tokens': (obs_embeddings, act_tokens)}, task_id=task_id) + + # ========= logging for analysis ========= + if self.analysis_dormant_ratio: + # Calculate dormant ratio of the world model + 
dormant_ratio_world_model = cal_dormant_ratio(self, { + 'obs_embeddings_and_act_tokens': (obs_embeddings.detach(), act_tokens.detach())}, + percentage=self.dormant_threshold) + self.past_kv_cache_init_infer.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + torch.cuda.empty_cache() + else: + dormant_ratio_world_model = torch.tensor(0.) + + # ========== for visualization ========== + # Uncomment the lines below for visualization + # predict_policy = outputs.logits_policy + # predict_policy = F.softmax(outputs.logits_policy, dim=-1) + # predict_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # predict_rewards = inverse_scalar_transform_handle(outputs.logits_rewards.reshape(-1, 101)).reshape(batch['observations'].shape[0], batch['observations'].shape[1], 1) + # import pdb; pdb.set_trace() + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=[], suffix='pong_H10_H4_0613') + + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_success_episode') + # visualize_reward_value_img_policy(original_images, reconstructed_images, target_predict_value, true_rewards, target_policy, predict_value, predict_rewards, predict_policy, not_plot_timesteps=list(np.arange(4,60)), suffix='visual_match_memlen1-60-15/one_fail_episode') + # ========== for visualization ========== + + # For training stability, use target_tokenizer to compute the true next latent state representations + with torch.no_grad(): + target_obs_embeddings = target_tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) + + # Compute labels for observations, rewards, and ends + labels_observations, labels_rewards, labels_ends = self.compute_labels_world_model(target_obs_embeddings, + batch['rewards'], + batch['ends'], + batch['mask_padding']) + + # Reshape the logits and labels for observations + logits_observations = rearrange(outputs.logits_observations[:, :-1], 'b t o -> (b t) o') + labels_observations = labels_observations.reshape(-1, self.projection_input_dim) + + # Compute prediction loss for observations. 
Options: MSE and Group KL + if self.predict_latent_loss_type == 'mse': + # MSE loss, directly compare logits and labels + loss_obs = torch.nn.functional.mse_loss(logits_observations, labels_observations, reduction='none').mean( + -1) + elif self.predict_latent_loss_type == 'group_kl': + # Group KL loss, group features and calculate KL divergence within each group + batch_size, num_features = logits_observations.shape + epsilon = 1e-6 + logits_reshaped = logits_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + labels_reshaped = labels_observations.reshape(batch_size, self.num_groups, self.group_size) + epsilon + + loss_obs = F.kl_div(logits_reshaped.log(), labels_reshaped, reduction='none').sum(dim=-1).mean(dim=-1) + + # ========== for debugging ========== + # print('loss_obs:', loss_obs.mean()) + # assert not torch.isnan(loss_obs).any(), "loss_obs contains NaN values" + # assert not torch.isinf(loss_obs).any(), "loss_obs contains Inf values" + # for name, param in self.tokenizer.encoder.named_parameters(): + # print('name, param.mean(), param.std():', name, param.mean(), param.std()) + + # Apply mask to loss_obs + mask_padding_expanded = batch['mask_padding'][:, 1:].contiguous().view(-1) + loss_obs = (loss_obs * mask_padding_expanded) + + # Compute labels for policy and value + labels_policy, labels_value = self.compute_labels_world_model_value_policy(batch['target_value'], + batch['target_policy'], + batch['mask_padding']) + + # Compute losses for rewards, policy, and value + loss_rewards = self.compute_cross_entropy_loss(outputs, labels_rewards, batch, element='rewards') + loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, + element='policy') + loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + + # Compute timesteps + timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) + # Compute discount coefficients for each timestep + discounts = self.gamma ** timesteps + + # Group losses into first step, middle step, and last step + first_step_losses = {} + middle_step_losses = {} + last_step_losses = {} + # batch['mask_padding'] indicates mask status for future H steps, exclude masked losses to maintain accurate mean statistics + # Group losses for each loss item + for loss_name, loss_tmp in zip( + ['loss_obs', 'loss_rewards', 'loss_value', 'loss_policy', 'orig_policy_loss', 'policy_entropy'], + [loss_obs, loss_rewards, loss_value, loss_policy, orig_policy_loss, policy_entropy] + ): + if loss_name == 'loss_obs': + seq_len = batch['actions'].shape[1] - 1 + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, 1:seq_len] + else: + seq_len = batch['actions'].shape[1] + # Get the corresponding mask_padding + mask_padding = batch['mask_padding'][:, :seq_len] + + # Adjust loss shape to (batch_size, seq_len) + loss_tmp = loss_tmp.view(-1, seq_len) + + # First step loss + first_step_mask = mask_padding[:, 0] + first_step_losses[loss_name] = loss_tmp[:, 0][first_step_mask].mean() + + # Middle step loss + middle_step_index = seq_len // 2 + middle_step_mask = mask_padding[:, middle_step_index] + middle_step_losses[loss_name] = loss_tmp[:, middle_step_index][middle_step_mask].mean() + + # Last step loss + last_step_mask = mask_padding[:, -1] + last_step_losses[loss_name] = loss_tmp[:, -1][last_step_mask].mean() + + # Discount reconstruction loss and perceptual loss + discounted_latent_recon_loss = latent_recon_loss + 
discounted_perceptual_loss = perceptual_loss + + # Calculate overall discounted loss + discounted_loss_obs = (loss_obs.view(-1, batch['actions'].shape[1] - 1) * discounts[1:]).mean() + discounted_loss_rewards = (loss_rewards.view(-1, batch['actions'].shape[1]) * discounts).mean() + discounted_loss_value = (loss_value.view(-1, batch['actions'].shape[1]) * discounts).mean() + discounted_loss_policy = (loss_policy.view(-1, batch['actions'].shape[1]) * discounts).mean() + discounted_orig_policy_loss = (orig_policy_loss.view(-1, batch['actions'].shape[1]) * discounts).mean() + discounted_policy_entropy = (policy_entropy.view(-1, batch['actions'].shape[1]) * discounts).mean() + + return LossWithIntermediateLosses( + latent_recon_loss_weight=self.latent_recon_loss_weight, + perceptual_loss_weight=self.perceptual_loss_weight, + loss_obs=discounted_loss_obs, + loss_rewards=discounted_loss_rewards, + loss_value=discounted_loss_value, + loss_policy=discounted_loss_policy, + latent_recon_loss=discounted_latent_recon_loss, + perceptual_loss=discounted_perceptual_loss, + orig_policy_loss=discounted_orig_policy_loss, + policy_entropy=discounted_policy_entropy, + first_step_losses=first_step_losses, + middle_step_losses=middle_step_losses, + last_step_losses=last_step_losses, + dormant_ratio_encoder=dormant_ratio_encoder, + dormant_ratio_world_model=dormant_ratio_world_model, + latent_state_l2_norms=latent_state_l2_norms, + ) + + def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): + # Assume outputs is an object with logits attributes like 'rewards', 'policy', and 'value'. + # labels is a target tensor for comparison. batch is a dictionary with a mask indicating valid timesteps. + + logits = getattr(outputs, f'logits_{element}') + + # Reshape your tensors + logits = rearrange(logits, 'b t e -> (b t) e') + labels = labels.reshape(-1, labels.shape[-1]) # Assume labels initially have shape [batch, time, dim] + + # Reshape your mask. True indicates valid data. 
+ mask_padding = rearrange(batch['mask_padding'], 'b t -> (b t)') + + # Compute cross-entropy loss + loss = -(torch.log_softmax(logits, dim=1) * labels).sum(1) + loss = (loss * mask_padding) + + if element == 'policy': + # Compute policy entropy loss + policy_entropy = self.compute_policy_entropy_loss(logits, mask_padding) + # Combine losses with specified weight + combined_loss = loss - self.policy_entropy_weight * policy_entropy + return combined_loss, loss, policy_entropy + + return loss + + def compute_policy_entropy_loss(self, logits, mask): + # Compute entropy of the policy + probs = torch.softmax(logits, dim=1) + log_probs = torch.log_softmax(logits, dim=1) + entropy = -(probs * log_probs).sum(1) + # Apply mask and return average entropy loss + entropy_loss = (entropy * mask) + return entropy_loss + + def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag + mask_fill = torch.logical_not(mask_padding) + + # Prepare observation labels + labels_observations = obs_embeddings.contiguous().view(rewards.shape[0], -1, self.projection_input_dim)[:, 1:] + + # Fill the masked areas of rewards + mask_fill_rewards = mask_fill.unsqueeze(-1).expand_as(rewards) + labels_rewards = rewards.masked_fill(mask_fill_rewards, -100) + + # Fill the masked areas of ends + labels_ends = ends.masked_fill(mask_fill, -100) + + return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) + + def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, + mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ Compute labels for value and policy predictions. """ + mask_fill = torch.logical_not(mask_padding) + + # Fill the masked areas of policy + mask_fill_policy = mask_fill.unsqueeze(-1).expand_as(target_policy) + labels_policy = target_policy.masked_fill(mask_fill_policy, -100) + + # Fill the masked areas of value + mask_fill_value = mask_fill.unsqueeze(-1).expand_as(target_value) + labels_value = target_value.masked_fill(mask_fill_value, -100) + + return labels_policy.reshape(-1, self.action_space_size), labels_value.reshape(-1, self.support_size) + + def clear_caches(self): + """ + Clears the caches of the world model. + """ + self.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in self.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + self.past_kv_cache_recurrent_infer.clear() + self.keys_values_wm_list.clear() + print(f'Cleared {self.__class__.__name__} past_kv_cache.') + + def __repr__(self) -> str: + return "transformer-based latent world_model of UniZero" diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index 964b33529..0f1c0ad21 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -557,7 +557,8 @@ def _forward_collect( temperature: float = 1, to_play: List = [-1], epsilon: float = 0.25, - ready_env_id: np.array = None + ready_env_id: np.array = None, + task_id: int = None, ) -> Dict: """ Overview: @@ -569,6 +570,7 @@ def _forward_collect( - temperature (:obj:`float`): The temperature of the policy. - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. 
Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ @@ -592,7 +594,7 @@ def _forward_collect( output = {i: None for i in ready_env_id} with torch.no_grad(): - network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data) + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() @@ -683,7 +685,7 @@ def _init_eval(self) -> None: self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, - ready_env_id: np.array = None) -> Dict: + ready_env_id: np.array = None, task_id: int = None,) -> Dict: """ Overview: The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. @@ -693,6 +695,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. - to_play (:obj:`int`): The player to play. - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + - task_id (:obj:`int`): The task id. Default is None, which means UniZero is in the single-task mode. Shape: - data (:obj:`torch.Tensor`): - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ @@ -711,7 +714,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 ready_env_id = np.arange(active_eval_env_num) output = {i: None for i in ready_env_id} with torch.no_grad(): - network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data) + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) if not self._eval_model.training: diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py new file mode 100644 index 000000000..6360f8dd7 --- /dev/null +++ b/lzero/policy/unizero_multitask.py @@ -0,0 +1,1103 @@ +import copy +import sys +from collections import defaultdict +from typing import List, Dict, Tuple, Union + +import numpy as np +import torch +from ding.model import model_wrap +from ding.utils import POLICY_REGISTRY + +from lzero.entry.utils import initialize_zeros_batch +from lzero.mcts import UniZeroMCTSCtree as MCTSCtree +from lzero.model import ImageTransforms +from lzero.policy import prepare_obs_stack4_for_unizero +from lzero.policy import scalar_transform, InverseScalarTransform, phi_transform, \ + DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs +from lzero.policy.unizero import UniZeroPolicy +from .utils import configure_optimizers_nanogpt + +sys.path.append('/Users/puyuan/code/LibMTL/') +# from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect +# from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect +from LibMTL.weighting.FAMO_unizero import FAMO as GradCorrect + + +# from LibMTL.weighting.abstract_weighting import AbsWeighting + +def generate_task_loss_dict(multi_task_losses, task_name_template): + """ + 生成每个任务的损失字典 + :param 
multi_task_losses: A list containing the loss of each task.
+    :param task_name_template: Template for the task name, e.g. 'obs_loss_task{}'.
+    :return: A dict mapping each formatted task name to its loss value.
+    """
+    task_loss_dict = {}
+    for task_idx, task_loss in enumerate(multi_task_losses):
+        task_name = task_name_template.format(task_idx)
+        task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss
+    return task_loss_dict
+
+
+class WrappedModel:
+    def __init__(self, tokenizer, transformer):
+        self.tokenizer = tokenizer
+        self.transformer = transformer
+
+    def parameters(self):
+        # pos_emb.weight
+        # task_emb.weight
+        # act_embedding_table.weight
+        # Return the parameters of the tokenizer and the transformer
+        return list(self.tokenizer.parameters()) + list(self.transformer.parameters())
+
+    def zero_grad(self, set_to_none=False):
+        # Zero the gradients of the tokenizer and the transformer
+        self.tokenizer.zero_grad(set_to_none=set_to_none)
+        self.transformer.zero_grad(set_to_none=set_to_none)
+
+
+class WrappedModelV2:
+    def __init__(self, tokenizer, transformer, pos_emb, task_emb, act_embedding_table):
+        self.tokenizer = tokenizer
+        self.transformer = transformer
+        self.pos_emb = pos_emb
+        self.task_emb = task_emb
+        self.act_embedding_table = act_embedding_table
+
+    def parameters(self):
+        # Return the parameters of the tokenizer, the transformer, and all embedding layers
+        return (list(self.tokenizer.parameters()) +
+                list(self.transformer.parameters()) +
+                list(self.pos_emb.parameters()) +
+                list(self.task_emb.parameters()) +
+                list(self.act_embedding_table.parameters()))
+
+    def zero_grad(self, set_to_none=False):
+        # Zero the gradients of the tokenizer, the transformer, and all embedding layers
+        self.tokenizer.zero_grad(set_to_none=set_to_none)
+        self.transformer.zero_grad(set_to_none=set_to_none)
+        self.pos_emb.zero_grad(set_to_none=set_to_none)
+        self.task_emb.zero_grad(set_to_none=set_to_none)
+        self.act_embedding_table.zero_grad(set_to_none=set_to_none)
+
+
+class WrappedModelV3:
+    def __init__(self, world_model):
+        self.world_model = world_model
+
+    def parameters(self):
+        # Return all parameters of the world model
+        return self.world_model.parameters()
+
+    def zero_grad(self, set_to_none=False):
+        # Zero the gradients of the world model
+        self.world_model.zero_grad(set_to_none=set_to_none)
+
+
+class WrappedModelV4:
+    def __init__(self, transformer, pos_emb, task_emb, act_embedding_table):
+        self.transformer = transformer
+        self.pos_emb = pos_emb
+        self.task_emb = task_emb
+        self.act_embedding_table = act_embedding_table
+
+    def parameters(self):
+        # Return the parameters of the transformer and all embedding layers (the tokenizer is excluded)
+        return (list(self.transformer.parameters()) +
+                list(self.pos_emb.parameters()) +
+                list(self.task_emb.parameters()) +
+                list(self.act_embedding_table.parameters()))
+
+    def zero_grad(self, set_to_none=False):
+        # Zero the gradients of the transformer and all embedding layers
+        # self.tokenizer.zero_grad(set_to_none=set_to_none)
+        self.transformer.zero_grad(set_to_none=set_to_none)
+        self.pos_emb.zero_grad(set_to_none=set_to_none)
+        self.task_emb.zero_grad(set_to_none=set_to_none)
+        self.act_embedding_table.zero_grad(set_to_none=set_to_none)
+
+
+@POLICY_REGISTRY.register('unizero_multitask')
+class UniZeroMTPolicy(UniZeroPolicy):
+    """
+    Overview:
+        The multi-task policy class for UniZero, the official implementation of the paper "UniZero: Generalized and
+        Efficient Planning with Scalable Latent World Models". UniZero aims to enhance the planning capabilities of
+        reinforcement learning agents by addressing the limitations found in MuZero-style algorithms, particularly in
+        environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
+ """ + + # The default_config for UniZero policy. + config = dict( + type='unizero_multitask', + model=dict( + # (str) The model type. For 1-dimensional vector obs, we use mlp model. For the image obs, we use conv model. + model_type='conv', # options={'mlp', 'conv'} + # (bool) If True, the action space of the environment is continuous, otherwise discrete. + continuous_action_space=False, + # (tuple) The obs shape. + observation_shape=(3, 64, 64), + # (bool) Whether to use the self-supervised learning loss. + self_supervised_learning_loss=True, + # (bool) Whether to use discrete support to represent categorical distribution for value/reward/value_prefix. + categorical_distribution=True, + # (int) The image channel in image observation. + image_channel=3, + # (int) The number of frames to stack together. + frame_stack_num=1, + # (int) The number of res blocks in MuZero model. + num_res_blocks=1, + # (int) The number of channels of hidden states in MuZero model. + num_channels=64, + # (int) The scale of supports used in categorical distribution. + # This variable is only effective when ``categorical_distribution=True``. + support_scale=50, + # (bool) whether to learn bias in the last linear layer in value and policy head. + bias=True, + # (bool) whether to use res connection in dynamics. + res_connection_in_dynamics=True, + # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'. + norm_type='BN', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (int) The save interval of the model. + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=10000, ), ), ), + world_model_cfg=dict( + # (int) The number of tokens per block. + tokens_per_block=2, + # (int) The maximum number of blocks. + max_blocks=10, + # (int) The maximum number of tokens, calculated as tokens per block multiplied by max blocks. + max_tokens=2 * 10, + # (int) The context length, usually calculated as twice the number of some base unit. + context_length=2 * 4, + # (bool) Whether to use GRU gating mechanism. + gru_gating=False, + # (str) The device to be used for computation, e.g., 'cpu' or 'cuda'. + device='cpu', + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to analyze dormant ratio. + analysis_dormant_ratio=False, + # (int) The shape of the action space. + action_space_size=6, + # (int) The size of the group, related to simulation normalization. + group_size=8, # NOTE: sim_norm + # (str) The type of attention mechanism used. Options could be ['causal']. + attention='causal', + # (int) The number of layers in the model. + num_layers=4, + # (int) The number of attention heads. + num_heads=8, + # (int) The dimension of the embedding. + embed_dim=768, + # (float) The dropout probability for the embedding layer. + embed_pdrop=0.1, + # (float) The dropout probability for the residual connections. + resid_pdrop=0.1, + # (float) The dropout probability for the attention mechanism. + attn_pdrop=0.1, + # (int) The size of the support set for value and reward heads. + support_size=101, + # (int) The maximum size of the cache. + max_cache_size=5000, + # (int) The number of environments. + env_num=8, + # (float) The weight of the latent reconstruction loss. + latent_recon_loss_weight=0., + # (float) The weight of the perceptual loss. + perceptual_loss_weight=0., + # (float) The weight of the policy entropy. + policy_entropy_weight=1e-4, + # (str) The type of loss for predicting latent variables. 
Options could be ['group_kl', 'mse']. + predict_latent_loss_type='group_kl', + # (str) The type of observation. Options are ['image', 'vector']. + obs_type='image', + # (float) The discount factor for future rewards. + gamma=1, + # (float) The threshold for a dormant neuron. + dormant_threshold=0.025, + ), + ), + # ****** common ****** + # (bool) whether to use rnd model. + use_rnd_model=False, + # (bool) Whether to use multi-gpu training. + multi_gpu=False, + # (bool) Whether to enable the sampled-based algorithm (e.g. Sampled EfficientZero) + # this variable is used in ``collector``. + sampled_algo=False, + # (bool) Whether to enable the gumbel-based algorithm (e.g. Gumbel Muzero) + gumbel_algo=False, + # (bool) Whether to use C++ MCTS in policy. If False, use Python implementation. + mcts_ctree=True, + # (bool) Whether to use cuda for network. + cuda=True, + # (int) The number of environments used in collecting data. + collector_env_num=8, + # (int) The number of environments used in evaluating policy. + evaluator_env_num=3, + # (str) The type of environment. Options are ['not_board_games', 'board_games']. + env_type='not_board_games', + # (str) The type of action space. Options are ['fixed_action_space', 'varied_action_space']. + action_type='fixed_action_space', + # (str) The type of battle mode. Options are ['play_with_bot_mode', 'self_play_mode']. + battle_mode='play_with_bot_mode', + # (bool) Whether to monitor extra statistics in tensorboard. + monitor_extra_statistics=True, + # (int) The transition number of one ``GameSegment``. + game_segment_length=400, + # (bool) Whether to analyze simulation normalization. + analysis_sim_norm=False, + # (bool) Whether to use the pure policy to collect data. + collect_with_pure_policy=False, + # (int) The evaluation frequency. + eval_freq=int(2e3), + # (str) The sample type. Options are ['episode', 'transition']. + sample_type='transition', + + # ****** observation ****** + # (bool) Whether to transform image to string to save memory. + transform2string=False, + # (bool) Whether to use gray scale image. + gray_scale=False, + # (bool) Whether to use data augmentation. + use_augmentation=False, + # (list) The style of augmentation. + augmentation=['shift', 'intensity'], + + # ******* learn ****** + # (bool) Whether to ignore the done flag in the training data. Typically, this value is set to False. + # However, for some environments with a fixed episode length, to ensure the accuracy of Q-value calculations, + # we should set it to True to avoid the influence of the done flag. + ignore_done=False, + # (int) How many updates(iterations) to train after collector's one collection. + # Bigger "update_per_collect" means bigger off-policy. + # collect data -> update policy-> collect data -> ... + # For different env, we have different episode_length, + # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. + update_per_collect=None, + # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. + model_update_ratio=0.25, + # (int) Minibatch size for one gradient descent. + batch_size=256, + # (str) Optimizer for training policy network. ['SGD', 'Adam'] + optim_type='AdamW', + # (float) Learning rate for training policy network. Initial lr for manually decay schedule. 
+ learning_rate=0.0001, + # (int) Frequency of hard target network update. + target_update_freq=100, + # (int) Frequency of soft target network update. + target_update_theta=0.05, + # (int) Frequency of target network update. + target_update_freq_for_intrinsic_reward=1000, + # (float) Weight decay for training policy network. + weight_decay=1e-4, + # (float) One-order Momentum in optimizer, which stabilizes the training process (gradient direction). + momentum=0.9, + # (float) The maximum constraint value of gradient norm clipping. + grad_clip_value=5, + # (int) The number of episodes in each collecting stage. + n_episode=8, + # (int) the number of simulations in MCTS. + num_simulations=50, + # (float) Discount factor (gamma) for returns. + discount_factor=0.997, + # (int) The number of steps for calculating target q_value. + td_steps=5, + # (int) The number of unroll steps in dynamics network. + num_unroll_steps=10, + # (float) The weight of reward loss. + reward_loss_weight=1, + # (float) The weight of value loss. + value_loss_weight=0.25, + # (float) The weight of policy loss. + policy_loss_weight=1, + # (float) The weight of policy entropy loss. + policy_entropy_loss_weight=0, + # (float) The weight of ssl (self-supervised learning) loss. + ssl_loss_weight=0, + # (bool) Whether to use piecewise constant learning rate decay. + # i.e. lr: 0.2 -> 0.02 -> 0.002 + lr_piecewise_constant_decay=False, + # (int) The number of final training iterations to control lr decay, which is only used for manually decay. + threshold_training_steps_for_final_lr=int(5e4), + # (bool) Whether to use manually decayed temperature. + manual_temperature_decay=False, + # (int) The number of final training iterations to control temperature, which is only used for manually decay. + threshold_training_steps_for_final_temperature=int(1e5), + # (float) The fixed temperature value for MCTS action selection, which is used to control the exploration. + # The larger the value, the more exploration. This value is only used when manual_temperature_decay=False. + fixed_temperature_value=0.25, + # (bool) Whether to use the true chance in MCTS in some environments with stochastic dynamics, such as 2048. + use_ture_chance_label_in_chance_encoder=False, + + # ****** Priority ****** + # (bool) Whether to use priority when sampling training data from the buffer. + use_priority=False, + # (float) The degree of prioritization to use. A value of 0 means no prioritization, + # while a value of 1 means full prioritization. + priority_prob_alpha=0.6, + # (float) The degree of correction to use. A value of 0 means no correction, + # while a value of 1 means full correction. + priority_prob_beta=0.4, + # (int) The initial Env Steps for training. + train_start_after_envsteps=int(0), + + # ****** UCB ****** + # (float) The alpha value used in the Dirichlet distribution for exploration at the root node of search tree. + root_dirichlet_alpha=0.3, + # (float) The noise weight at the root node of the search tree. + root_noise_weight=0.25, + + # ****** Explore by random collect ****** + # (int) The number of episodes to collect data randomly before training. + random_collect_episode_num=0, + + # ****** Explore by eps greedy ****** + eps=dict( + # (bool) Whether to use eps greedy exploration in collecting data. + eps_greedy_exploration_in_collect=False, + # (str) The type of decaying epsilon. Options are 'linear', 'exp'. + type='linear', + # (float) The start value of eps. + start=1., + # (float) The end value of eps. 
+ end=0.05, + # (int) The decay steps from start to end eps. + decay=int(1e5), + ), + ) + + def default_model(self) -> Tuple[str, List[str]]: + """ + Overview: + Return this algorithm default model setting for demonstration. + Returns: + - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names. + - model_type (:obj:`str`): The model type used in this algorithm, which is registered in ModelRegistry. + - import_names (:obj:`List[str]`): The model class path list used in this algorithm. + .. note:: + The user can define and use customized network model but must obey the same interface definition indicated \ + by import_names path. For MuZero, ``lzero.model.unizero_model.MuZeroModel`` + """ + # return 'UniZeroModel', ['lzero.model.unizero_model'] + + return 'UniZeroMTModel', ['lzero.model.unizero_model_multitask'] + + def _init_learn(self) -> None: + """ + Overview: + Learn mode init method. Called by ``self.__init__``. Initialize the learn model, optimizer and MCTS utils. + """ + # NOTE: nanoGPT optimizer + self._optimizer_world_model = configure_optimizers_nanogpt( + model=self._model.world_model, + learning_rate=self._cfg.learning_rate, + weight_decay=self._cfg.weight_decay, + device_type=self._cfg.device, + betas=(0.9, 0.95), + ) + + # use model_wrapper for specialized demands of different modes + self._target_model = copy.deepcopy(self._model) + # Ensure that the installed torch version is greater than or equal to 2.0 + assert int(''.join(filter(str.isdigit, torch.__version__))) >= 200, "We need torch version >= 2.0" + self._model = torch.compile(self._model) + self._target_model = torch.compile(self._target_model) + # NOTE: soft target + self._target_model = model_wrap( + self._target_model, + wrapper_name='target', + update_type='momentum', + update_kwargs={'theta': self._cfg.target_update_theta} + ) + self._learn_model = self._model + + if self._cfg.use_augmentation: + self.image_transforms = ImageTransforms( + self._cfg.augmentation, + image_shape=(self._cfg.model.observation_shape[1], self._cfg.model.observation_shape[2]) + ) + self.value_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.reward_support = DiscreteSupport(-self._cfg.model.support_scale, self._cfg.model.support_scale, delta=1) + self.inverse_scalar_transform_handle = InverseScalarTransform( + self._cfg.model.support_scale, self._cfg.device, self._cfg.model.categorical_distribution + ) + self.intermediate_losses = defaultdict(float) + self.l2_norm_before = 0. + self.l2_norm_after = 0. + self.grad_norm_before = 0. + self.grad_norm_after = 0. 
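+
+        # The DiscreteSupport/InverseScalarTransform handles created above implement the MuZero-style
+        # categorical representation of scalars: a (transformed) scalar target is split between its two
+        # neighbouring integer atoms in [-support_scale, support_scale]. A standalone sketch of that
+        # projection (illustrative only; `phi_transform` / `inverse_scalar_transform_handle` are the
+        # helpers actually used in `_forward_learn`):
+        #   import torch
+        #   atoms = torch.arange(-50, 51, dtype=torch.float32)  # DiscreteSupport(-50, 50, delta=1)
+        #   x = torch.tensor(2.3)                                # an already scalar_transform-ed target
+        #   low = torch.floor(x).long() + 50                     # index of the lower neighbouring atom
+        #   probs = torch.zeros_like(atoms)
+        #   probs[low], probs[low + 1] = 1 - (x - x.floor()), x - x.floor()  # 0.7 on atom 2, 0.3 on atom 3
+        #   recovered = (probs * atoms).sum()                    # 2.3; the inverse transform maps it back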
+
+        # Create the wrapped model that exposes the shared parameters (encoder, transformer, embeddings)
+        # to the gradient-correction method.
+        # wrapped_model = WrappedModel(
+        #     self._learn_model.world_model.tokenizer,
+        #     self._learn_model.world_model.transformer
+        # )
+        wrapped_model = WrappedModelV2(
+            # self._learn_model.world_model.tokenizer, # TODO
+            self._learn_model.world_model.tokenizer.encoder[0],  # TODO: one encoder
+            self._learn_model.world_model.transformer,
+            self._learn_model.world_model.pos_emb,
+            self._learn_model.world_model.task_emb,
+            self._learn_model.world_model.act_embedding_table,
+        )
+        # wrapped_model = WrappedModelV3(
+        #     self._learn_model.world_model,
+        # )
+        # wrapped_model = WrappedModelV4(
+        #     self._learn_model.world_model.transformer,
+        #     self._learn_model.world_model.pos_emb,
+        #     self._learn_model.world_model.task_emb,
+        #     self._learn_model.world_model.act_embedding_table,
+        # )
+        # Pass wrapped_model as the shared model to GradCorrect.
+        self.task_num = self._cfg.task_num
+        self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device)
+        self.grad_correct.init_param()  # Initialize the MoCo/FAMO parameters
+        self.grad_correct.rep_grad = False
+        # self.grad_correct.set_min_losses(torch.tensor([0. for i in range(self.task_num)], device=self._cfg.device))  # only for FAMO
+        self.curr_min_loss = torch.tensor([0. for i in range(self.task_num)], device=self._cfg.device)
+        self.grad_correct.prev_loss = self.curr_min_loss
+
+    # @profile
+    def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
+        """
+        Overview:
+            The forward function for learning the policy in learn mode, which is the core of the learning process.
+            The data is sampled from the replay buffer.
+            The loss is calculated by the loss function and is backpropagated to update the model.
+        Arguments:
+            - data (:obj:`Tuple[torch.Tensor]`): The data sampled from the replay buffer, which is a tuple of tensors.
+                The first tensor is the current_batch, the second tensor is the target_batch.
+        Returns:
+            - info_dict (:obj:`Dict[str, Union[float, int]]`): The information dict to be logged, which contains \
+                current learning loss and learning statistics.
+ """ + self._learn_model.train() + self._target_model.train() + + obs_loss_multi_task = [] + reward_loss_multi_task = [] + policy_loss_multi_task = [] + value_loss_multi_task = [] + latent_recon_loss_multi_task = [] + perceptual_loss_multi_task = [] + orig_policy_loss_multi_task = [] + policy_entropy_multi_task = [] + # weighted_total_loss = torch.tensor(0., device=self._cfg.device) + # weighted_total_loss.requires_grad = True + weighted_total_loss = 0.0 # 初始化为0,避免使用in-place操作 + + average_target_policy_entropy_multi_task = [] + + losses_list = [] # 用于存储每个任务的损失 + for task_id, data_one_task in enumerate(data): + current_batch, target_batch, task_id = data_one_task + # current_batch, target_batch, _ = data + obs_batch_ori, action_batch, mask_batch, indices, weights, make_time = current_batch + target_reward, target_value, target_policy = target_batch + + # Prepare observations based on frame stack number + if self._cfg.model.frame_stack_num == 4: + obs_batch, obs_target_batch = prepare_obs_stack4_for_unizero(obs_batch_ori, self._cfg) + else: + obs_batch, obs_target_batch = prepare_obs(obs_batch_ori, self._cfg) + + # Apply augmentations if needed + if self._cfg.use_augmentation: + obs_batch = self.image_transforms.transform(obs_batch) + if self._cfg.model.self_supervised_learning_loss: + obs_target_batch = self.image_transforms.transform(obs_target_batch) + + # Prepare action batch and convert to torch tensor + action_batch = torch.from_numpy(action_batch).to(self._cfg.device).unsqueeze( + -1).long() # For discrete action space + data_list = [mask_batch, target_reward.astype('float32'), target_value.astype('float32'), target_policy, + weights] + mask_batch, target_reward, target_value, target_policy, weights = to_torch_float_tensor(data_list, + self._cfg.device) + + target_reward = target_reward.view(self._cfg.batch_size[task_id], -1) + target_value = target_value.view(self._cfg.batch_size[task_id], -1) + + # assert obs_batch.size(0) == self._cfg.batch_size == target_reward.size(0) + + # Transform rewards and values to their scaled forms + transformed_target_reward = scalar_transform(target_reward) + transformed_target_value = scalar_transform(target_value) + + # Convert to categorical distributions + target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + + # Prepare batch for GPT model + batch_for_gpt = {} + if isinstance(self._cfg.model.observation_shape, int) or len(self._cfg.model.observation_shape) == 1: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, self._cfg.model.observation_shape) + elif len(self._cfg.model.observation_shape) == 3: + batch_for_gpt['observations'] = torch.cat((obs_batch, obs_target_batch), dim=1).reshape( + self._cfg.batch_size[task_id], -1, *self._cfg.model.observation_shape) + + batch_for_gpt['actions'] = action_batch.squeeze(-1) + batch_for_gpt['rewards'] = target_reward_categorical[:, :-1] + batch_for_gpt['mask_padding'] = mask_batch == 1.0 # 0 means invalid padding data + batch_for_gpt['mask_padding'] = batch_for_gpt['mask_padding'][:, :-1] + batch_for_gpt['observations'] = batch_for_gpt['observations'][:, :-1] + batch_for_gpt['ends'] = torch.zeros(batch_for_gpt['mask_padding'].shape, dtype=torch.long, + device=self._cfg.device) + batch_for_gpt['target_value'] = target_value_categorical[:, :-1] + batch_for_gpt['target_policy'] = target_policy[:, :-1] + + # 
+            valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']]
+            target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1)
+            average_target_policy_entropy = target_policy_entropy.mean().item()
+            average_target_policy_entropy_multi_task.append(average_target_policy_entropy)
+
+            # Update world model
+            intermediate_losses = defaultdict(float)
+            losses = self._learn_model.world_model.compute_loss(
+                batch_for_gpt, self._target_model.world_model.tokenizer, self.inverse_scalar_transform_handle, task_id=task_id
+            )
+
+            weighted_total_loss += losses.loss_total  # TODO
+            # weighted_total_loss = torch.tensor(0., device=self._cfg.device)
+
+            # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
+            # assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
+
+            losses_list.append(losses.loss_total)  # TODO: for moco
+
+            # weighted_total_loss = weighted_total_loss + losses.loss_total  # non-in-place variant
+            for loss_name, loss_value in losses.intermediate_losses.items():
+                intermediate_losses[f"{loss_name}"] = loss_value
+
+            obs_loss = intermediate_losses['loss_obs']
+            reward_loss = intermediate_losses['loss_rewards']
+            policy_loss = intermediate_losses['loss_policy']
+            orig_policy_loss = intermediate_losses['orig_policy_loss']
+            policy_entropy = intermediate_losses['policy_entropy']
+            value_loss = intermediate_losses['loss_value']
+            latent_recon_loss = intermediate_losses['latent_recon_loss']
+            perceptual_loss = intermediate_losses['perceptual_loss']
+
+            obs_loss_multi_task.append(obs_loss)
+            reward_loss_multi_task.append(reward_loss)
+            policy_loss_multi_task.append(policy_loss)
+            orig_policy_loss_multi_task.append(orig_policy_loss)
+            policy_entropy_multi_task.append(policy_entropy)
+            value_loss_multi_task.append(value_loss)
+            latent_recon_loss_multi_task.append(latent_recon_loss)
+            perceptual_loss_multi_task.append(perceptual_loss)
+
+        # Core learn model update step
+        self._optimizer_world_model.zero_grad()
+
+        # TODO MoCo
+        # Use MoCo to compute the corrected gradients and task weights
+        # ============= for CAGrad and MoCo =============
+        # lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
+
+        # ============= for FAMO ============= TODO: self.grad_correct.min_loss
+        # lambd = torch.tensor([0. for i in range(self.task_num)], device=self._cfg.device)
+        # curr_loss, _ = self.grad_correct.backward(losses=torch.tensor(losses_list, device=self._cfg.device))
+        # for i in range(self.task_num):
+        #     if losses_list[i] < self.grad_correct.min_losses[i]:
+        #         self.curr_min_loss[i] = losses_list[i]
+        # self.grad_correct.min_loss = self.curr_min_loss  # only for FAMO
+        # self.grad_correct.update(curr_loss.detach())
+        # self.grad_correct.prev_loss = curr_loss.detach()
+
+        # ============= without gradient correction =============
+        lambd = torch.tensor([0. for i in range(self.task_num)], device=self._cfg.device)
+
+        weighted_total_loss.backward()
+
+        # ========== for debugging ==========
+        # for name, param in self._learn_model.world_model.tokenizer.encoder.named_parameters():
+        #     print('name, param.mean(), param.std():', name, param.mean(), param.std())
+        #     if param.requires_grad:
+        #         print(name, param.grad.norm())
+
+        if self._cfg.analysis_sim_norm:
+            del self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after
+            self.l2_norm_before, self.l2_norm_after, self.grad_norm_before, self.grad_norm_after = self._learn_model.encoder_hook.analyze()
+            self._target_model.encoder_hook.clear_data()
+
+        if self._cfg.multi_gpu:
+            self.sync_gradients(self._learn_model)
+        total_grad_norm_before_clip_wm = torch.nn.utils.clip_grad_norm_(self._learn_model.world_model.parameters(),
+                                                                        self._cfg.grad_clip_value)
+
+        self._optimizer_world_model.step()
+        if self._cfg.lr_piecewise_constant_decay:
+            self.lr_scheduler.step()
+
+        # Core target model update step
+        self._target_model.update(self._learn_model.state_dict())
+
+        if torch.cuda.is_available():
+            torch.cuda.synchronize()
+            current_memory_allocated = torch.cuda.memory_allocated()
+            max_memory_allocated = torch.cuda.max_memory_allocated()
+            current_memory_allocated_gb = current_memory_allocated / (1024 ** 3)
+            max_memory_allocated_gb = max_memory_allocated / (1024 ** 3)
+        else:
+            current_memory_allocated_gb = 0.
+            max_memory_allocated_gb = 0.
+
+        # Build the loss dict to be logged
+        return_loss_dict = {
+            'Current_GPU': current_memory_allocated_gb,
+            'Max_GPU': max_memory_allocated_gb,
+            'collect_mcts_temperature': self._collect_mcts_temperature,
+            'collect_epsilon': self._collect_epsilon,
+            'cur_lr_world_model': self._optimizer_world_model.param_groups[0]['lr'],
+            'weighted_total_loss': weighted_total_loss.item(),
+            # 'policy_entropy': policy_entropy,
+            # 'target_policy_entropy': average_target_policy_entropy,
+            'total_grad_norm_before_clip_wm': total_grad_norm_before_clip_wm.item(),
+        }
+
+        # Dict holding the per-task losses
+        multi_task_loss_dicts = {
+            **generate_task_loss_dict(obs_loss_multi_task, 'obs_loss_task{}'),
+            **generate_task_loss_dict(latent_recon_loss_multi_task, 'latent_recon_loss_task{}'),
+            **generate_task_loss_dict(perceptual_loss_multi_task, 'perceptual_loss_task{}'),
+            **generate_task_loss_dict(policy_loss_multi_task, 'policy_loss_task{}'),
+            **generate_task_loss_dict(orig_policy_loss_multi_task, 'orig_policy_loss_task{}'),
+            **generate_task_loss_dict(policy_entropy_multi_task, 'policy_entropy_task{}'),
+            **generate_task_loss_dict(reward_loss_multi_task, 'reward_loss_task{}'),
+            **generate_task_loss_dict(value_loss_multi_task, 'value_loss_task{}'),
+            **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'target_policy_entropy_task{}'),
+            **generate_task_loss_dict(lambd, 'lambd_task{}'),
+        }
+
+        # Merge the per-task losses into the return dict
+        return_loss_dict.update(multi_task_loss_dicts)
+
+        # Return the final loss dict
+        return return_loss_dict
+
+    def monitor_weights_and_grads(self, model):
+        for name, param in model.named_parameters():
+            if param.requires_grad:
+                print(f"Layer: {name} | "
+                      f"Weight mean: {param.data.mean():.4f} | "
+                      f"Weight std: {param.data.std():.4f} | "
+                      f"Grad mean: {param.grad.mean():.4f} | "
+                      f"Grad std: {param.grad.std():.4f}")
+
+    def _init_collect(self) -> None:
+        """
+        Overview:
+            Collect mode init method. Called by ``self.__init__``. Initialize the collect model and MCTS utils.
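As a quick illustration of the buffers this method prepares, assuming the conv branch with 8 collector environments and 3x64x64 image observations (both values match the Atari configs, but are assumptions for this example):

    import torch

    collector_env_num = 8
    observation_shape = (3, 64, 64)  # RGB frames at 64x64 resolution

    # Placeholder "previous step" inputs for the first initial_inference call of an episode.
    last_batch_obs = torch.zeros([collector_env_num, observation_shape[0], 64, 64])
    last_batch_action = [-1 for _ in range(collector_env_num)]  # -1 marks "no previous action"
    print(last_batch_obs.shape)  # torch.Size([8, 3, 64, 64])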
+ """ + self._collect_model = self._model + + if self._cfg.mcts_ctree: + self._mcts_collect = MCTSCtree(self._cfg) + else: + self._mcts_collect = MCTSPtree(self._cfg) + self._collect_mcts_temperature = 1. + self._collect_epsilon = 0.0 + self.collector_env_num = self._cfg.collector_env_num + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for i in range(self.collector_env_num)] + + # @profile + def _forward_collect( + self, + data: torch.Tensor, + action_mask: list = None, + temperature: float = 1, + to_play: List = [-1], + epsilon: float = 0.25, + ready_env_id: np.array = None, + task_id=None + ) -> Dict: + """ + Overview: + The forward function for collecting data in collect mode. Use model to execute MCTS search. + Choosing the action through sampling during the collect mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - temperature (:obj:`float`): The temperature of the policy. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - temperature: :math:`(1, )`. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
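A sketch of the per-environment entry returned by this method; the keys mirror the dict constructed in the method body, while the values here are made up purely for illustration:

    example_output = {
        0: {
            'action': 2,                                  # action id in the full action space
            'visit_count_distributions': [5, 0, 30, 15],  # MCTS visit counts over legal actions
            'visit_count_distribution_entropy': 1.02,
            'searched_value': 0.73,                       # root value from the MCTS search
            'predicted_value': 0.69,                      # value head output (scalar form)
            'predicted_policy_logits': [0.1, -0.4, 1.2, 0.3],
        }
    }
    action_for_env_0 = example_output[0]['action']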
+ """ + self._collect_model.eval() + + self._collect_mcts_temperature = temperature + self._collect_epsilon = epsilon + active_collect_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_collect_env_num) + output = {i: None for i in ready_env_id} + + with torch.no_grad(): + network_output = self._collect_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_collect_env_num)] + # the only difference between collect and eval is the dirichlet noise + noises = [ + np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j])) + ).astype(np.float32).tolist() for j in range(active_collect_env_num) + ] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_collect_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_collect_env_num, legal_actions) + + roots.prepare(self._cfg.root_noise_weight, noises, reward_roots, policy_logits, to_play) + self._mcts_collect.search(roots, self._collect_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + + if self._cfg.eps.eps_greedy_exploration_in_collect: + # eps greedy collect + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=True + ) + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + if np.random.rand() < self._collect_epsilon: + action = np.random.choice(legal_actions[i]) + else: + # normal collect + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=self._collect_mcts_temperature, deterministic=False + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the entire action set. 
+ action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + # ============== TODO: only for visualize ============== + # action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + # distributions, temperature=self._collect_mcts_temperature, deterministic=True + # ) + # action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + # ============== TODO: only for visualize ============== + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + return output + + def _init_eval(self) -> None: + """ + Overview: + Evaluate mode init method. Called by ``self.__init__``. Initialize the eval model and MCTS utils. + """ + self._eval_model = self._model + if self._cfg.mcts_ctree: + self._mcts_eval = MCTSCtree(self._cfg) + else: + self._mcts_eval = MCTSPtree(self._cfg) + self.evaluator_env_num = self._cfg.evaluator_env_num + + if self._cfg.model.model_type == 'conv': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape[0], 64, 64]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + elif self._cfg.model.model_type == 'mlp': + self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) + self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + + def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, + ready_env_id: np.array = None, task_id=None) -> Dict: + """ + Overview: + The forward function for evaluating the current policy in eval mode. Use model to execute MCTS search. + Choosing the action with the highest value (argmax) rather than sampling during the eval mode. + Arguments: + - data (:obj:`torch.Tensor`): The input data, i.e. the observation. + - action_mask (:obj:`list`): The action mask, i.e. the action that cannot be selected. + - to_play (:obj:`int`): The player to play. + - ready_env_id (:obj:`list`): The id of the env that is ready to collect. + Shape: + - data (:obj:`torch.Tensor`): + - For Atari, :math:`(N, C*S, H, W)`, where N is the number of collect_env, C is the number of channels, \ + S is the number of stacked frames, H is the height of the image, W is the width of the image. + - For lunarlander, :math:`(N, O)`, where N is the number of collect_env, O is the observation space size. + - action_mask: :math:`(N, action_space_size)`, where N is the number of collect_env. + - to_play: :math:`(N, 1)`, where N is the number of collect_env. + - ready_env_id: None + Returns: + - output (:obj:`Dict[int, Any]`): Dict type data, the keys including ``action``, ``distributions``, \ + ``visit_count_distribution_entropy``, ``value``, ``pred_value``, ``policy_logits``. 
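Conceptually, evaluation picks the most-visited root action. A minimal stand-in for the deterministic branch of ``select_action`` plus the mapping back to the full action space (the real helper also returns the visit-count entropy; the mask below is hypothetical):

    import numpy as np

    visit_counts = np.array([5, 40, 3, 2])  # MCTS visit counts over the legal actions
    action_index_in_legal_action_set = int(np.argmax(visit_counts))  # deterministic: argmax

    action_mask = np.array([0, 1, 1, 0, 1, 1])  # 1 marks a legal action
    action = np.where(action_mask == 1.0)[0][action_index_in_legal_action_set]  # -> 2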
+ """ + self._eval_model.eval() + active_eval_env_num = data.shape[0] + if ready_env_id is None: + ready_env_id = np.arange(active_eval_env_num) + output = {i: None for i in ready_env_id} + with torch.no_grad(): + network_output = self._eval_model.initial_inference(self.last_batch_obs, self.last_batch_action, data, task_id=task_id) + latent_state_roots, reward_roots, pred_values, policy_logits = mz_network_output_unpack(network_output) + + if not self._eval_model.training: + # if not in training, obtain the scalars of the value/reward + pred_values = self.inverse_scalar_transform_handle(pred_values).detach().cpu().numpy() # shape(B, 1) + latent_state_roots = latent_state_roots.detach().cpu().numpy() + policy_logits = policy_logits.detach().cpu().numpy().tolist() # list shape(B, A) + + legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(active_eval_env_num)] + if self._cfg.mcts_ctree: + # cpp mcts_tree + roots = MCTSCtree.roots(active_eval_env_num, legal_actions) + else: + # python mcts_tree + roots = MCTSPtree.roots(active_eval_env_num, legal_actions) + roots.prepare_no_noise(reward_roots, policy_logits, to_play) + self._mcts_eval.search(roots, self._eval_model, latent_state_roots, to_play, task_id=task_id) + + # list of list, shape: ``{list: batch_size} -> {list: action_space_size}`` + roots_visit_count_distributions = roots.get_distributions() + roots_values = roots.get_values() # shape: {list: batch_size} + + batch_action = [] + + for i, env_id in enumerate(ready_env_id): + distributions, value = roots_visit_count_distributions[i], roots_values[i] + # print("roots_visit_count_distributions:", distributions, "root_value:", value) + + # NOTE: Only legal actions possess visit counts, so the ``action_index_in_legal_action_set`` represents + # the index within the legal action set, rather than the index in the entire action set. + # Setting deterministic=True implies choosing the action with the highest value (argmax) rather than + # sampling during the evaluation phase. + action_index_in_legal_action_set, visit_count_distribution_entropy = select_action( + distributions, temperature=1, deterministic=True + ) + # NOTE: Convert the ``action_index_in_legal_action_set`` to the corresponding ``action`` in the + # entire action set. + action = np.where(action_mask[i] == 1.0)[0][action_index_in_legal_action_set] + + output[env_id] = { + 'action': action, + 'visit_count_distributions': distributions, + 'visit_count_distribution_entropy': visit_count_distribution_entropy, + 'searched_value': value, + 'predicted_value': pred_values[i], + 'predicted_policy_logits': policy_logits[i], + } + batch_action.append(action) + + self.last_batch_obs = data + self.last_batch_action = batch_action + + return output + + def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + """ + Overview: + This method resets the collection process for a specific environment. It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, the initial data + will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. + - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. 
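The cache-clearing cadence used below reduces to a simple modulo check; a simplified sketch (the real method additionally clears the target model's caches and frees GPU memory):

    def should_clear_caches(current_steps: int, sample_type: str = 'transition') -> bool:
        # Episode-based sampling clears caches less frequently than transition-based sampling.
        clear_interval = 2000 if sample_type == 'episode' else 200
        return current_steps % clear_interval == 0

    assert should_clear_caches(400) is True
    assert should_clear_caches(2000, sample_type='episode') is True
    assert should_clear_caches(100) is False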
+ """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.collector_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.collector_env_num)] + # print('collector: last_batch_obs, last_batch_action reset()', self.last_batch_obs.shape) + + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return + + # Determine the clear interval based on the environment's sample type + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + # Clear caches if the current steps are a multiple of the clear interval + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the collect model's world model + world_model = self._collect_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print('collector: collect_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') + # TODO: check its correctness + self._reset_target_model() + + def _reset_target_model(self) -> None: + """ + Overview: + This method resets the target model. It clears caches and memory, ensuring optimal performance. + Arguments: + - None + """ + + # Clear various caches in the target_model + world_model = self._target_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + print('collector: target_model past_kv_cache.clear()') + + def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + """ + Overview: + This method resets the evaluation process for a specific environment. It clears caches and memory + when certain conditions are met, ensuring optimal performance. If reset_init_data is True, + the initial data will be reset. + Arguments: + - env_id (:obj:`int`, optional): The ID of the environment to reset. If None or list, the function returns immediately. + - current_steps (:obj:`int`, optional): The current step count in the environment. Used to determine + whether to clear caches. + - reset_init_data (:obj:`bool`, optional): Whether to reset the initial data. If True, the initial data will be reset. 
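For intuition only: a stand-in compatible with how ``initialize_zeros_batch`` is called below; the actual helper lives in the LightZero policy utilities, so this signature is inferred from the call site rather than taken from its definition:

    import torch

    def initialize_zeros_batch_sketch(observation_shape, batch_size, device):
        # Support both int (vector obs) and tuple (image obs) observation shapes.
        shape = [batch_size, observation_shape] if isinstance(observation_shape, int) \
            else [batch_size, *observation_shape]
        return torch.zeros(shape).to(device)

    obs = initialize_zeros_batch_sketch((3, 64, 64), batch_size=3, device='cpu')
    print(obs.shape)  # torch.Size([3, 3, 64, 64])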
+ """ + if reset_init_data: + self.last_batch_obs = initialize_zeros_batch( + self._cfg.model.observation_shape, + self._cfg.evaluator_env_num, + self._cfg.device + ) + self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] + # print('evaluator: last_batch_obs, last_batch_action reset()', self.last_batch_obs.shape) + + # Return immediately if env_id is None or a list + if env_id is None or isinstance(env_id, list): + return + + # Determine the clear interval based on the environment's sample type + clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 + + # Clear caches if the current steps are a multiple of the clear interval + if current_steps % clear_interval == 0: + print(f'clear_interval: {clear_interval}') + + # Clear various caches in the eval model's world model + world_model = self._eval_model.world_model + world_model.past_kv_cache_init_infer.clear() + for kv_cache_dict_env in world_model.past_kv_cache_init_infer_envs: + kv_cache_dict_env.clear() + world_model.past_kv_cache_recurrent_infer.clear() + world_model.keys_values_wm_list.clear() + + # Free up GPU memory + torch.cuda.empty_cache() + + print('evaluator: eval_model clear()') + print(f'eps_steps_lst[{env_id}]: {current_steps}') + + # TODO: num_tasks + def _monitor_vars_learn(self, num_tasks=2) -> List[str]: + """ + Overview: + Register the variables to be monitored in learn mode. The registered variables will be logged in + tensorboard according to the return value ``_forward_learn``. + If num_tasks is provided, generate monitored variables for each task. + """ + # Basic monitored variables that do not depend on the number of tasks + monitored_vars = [ + 'Current_GPU', + 'Max_GPU', + 'collect_epsilon', + 'collect_mcts_temperature', + 'cur_lr_world_model', + 'weighted_total_loss', + 'total_grad_norm_before_clip_wm', + ] + + # Variable names that will have task-specific counterparts + task_specific_vars = [ + 'obs_loss', + 'orig_policy_loss', + 'policy_loss', + 'latent_recon_loss', + 'policy_entropy', + 'target_policy_entropy', + 'reward_loss', + 'value_loss', + 'perceptual_loss', + 'lambd', + ] + + # If the number of tasks is provided, extend the monitored variables list with task-specific variables + if num_tasks is not None: + for var in task_specific_vars: + for task_idx in range(num_tasks): + monitored_vars.append(f'{var}_task{task_idx}') + else: + # If num_tasks is not provided, we assume there's only one task and keep the original variable names + monitored_vars.extend(task_specific_vars) + + return monitored_vars + + def recompute_pos_emb_diff_and_clear_cache(self) -> None: + """ + Overview: + Clear the caches and precompute positional embedding matrices in the model. 
+ """ + # NOTE: Clear caches and precompute positional embedding matrices both for the collect and target models + for model in [self._collect_model, self._target_model]: + model.world_model.precompute_pos_emb_diff_kv() + model.world_model.clear_caches() + torch.cuda.empty_cache() \ No newline at end of file diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 9933f816e..81bc46047 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -40,6 +40,7 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'collector', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: @@ -52,7 +53,9 @@ def __init__( - exp_name (:obj:`str`): Name of the experiment, used for logging and saving purposes. - instance_name (:obj:`str`): Unique identifier for this collector instance. - policy_config (:obj:`Optional[policy_config]`): Configuration object for the policy. + - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode. """ + self.task_id = task_id self._exp_name = exp_name self._instance_name = instance_name self._collect_print_freq = collect_print_freq @@ -423,7 +426,8 @@ def collect(self, # Key policy forward step # ============================================================== # print(f'ready_env_id:{ready_env_id}') - policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + # policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id) + policy_output = self._policy.forward(stack_obs, action_mask, temperature, to_play, epsilon, ready_env_id=ready_env_id, task_id=self.task_id) # Extract relevant policy outputs actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} @@ -763,7 +767,13 @@ def _output_log(self, train_iter: int) -> None: for k, v in info.items(): if k in ['each_reward']: continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + if self.task_id is None: + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + else: + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, train_iter) if k in ['total_envstep_count']: continue - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) \ No newline at end of file + if self.task_id is None: + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count) + else: + self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, self._total_envstep_count) \ No newline at end of file diff --git a/lzero/worker/muzero_evaluator.py b/lzero/worker/muzero_evaluator.py index bf13e010b..3dd604097 100644 --- a/lzero/worker/muzero_evaluator.py +++ b/lzero/worker/muzero_evaluator.py @@ -55,6 +55,7 @@ def __init__( exp_name: Optional[str] = 'default_experiment', instance_name: Optional[str] = 'evaluator', policy_config: 'policy_config' = None, # noqa + task_id: int = None, ) -> None: """ Overview: @@ -69,7 +70,9 @@ def __init__( - exp_name (:obj:`str`): Name of the experiment, used to determine output directory. - instance_name (:obj:`str`): Name of this evaluator instance. - policy_config (:obj:`Optional[dict]`): Optional configuration for the game policy. 
+ - task_id (:obj:`int`): Unique identifier for the task. If None, that means we are in the single task mode. """ + self.task_id = task_id self._eval_freq = eval_freq self._exp_name = exp_name self._instance_name = instance_name @@ -283,7 +286,8 @@ def eval( # ============================================================== # policy forward # ============================================================== - policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) + # policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id) + policy_output = self._policy.forward(stack_obs, action_mask, to_play, ready_env_id=ready_env_id, task_id=self.task_id) actions_with_env_id = {k: v['action'] for k, v in policy_output.items()} distributions_dict_with_env_id = {k: v['visit_count_distributions'] for k, v in policy_output.items()} @@ -431,8 +435,14 @@ def eval( continue if not np.isscalar(v): continue - self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) - self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + if self.task_id is None: + self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter) + self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep) + else: + self._tb_logger.add_scalar('{}_iter_task{}/'.format(self._instance_name, self.task_id) + k, v, + train_iter) + self._tb_logger.add_scalar('{}_step_task{}/'.format(self._instance_name, self.task_id) + k, v, + envstep) episode_return = np.mean(episode_return) if episode_return > self._max_episode_return: if save_ckpt_fn: diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py new file mode 100644 index 000000000..a182cc6a0 --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -0,0 +1,144 @@ +from easydict import EasyDict +from copy import deepcopy +# from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + full_action_space=True, + # ===== only for debug ===== + collect_max_episode_steps=int(50), + eval_max_episode_steps=int(50), + ), + policy=dict( + grad_correct_params=dict( + # for MoCo + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + # for CAGrad + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cpu', # 'cuda', + action_space_size=action_space_size, + num_layers=4, # NOTE + num_heads=8, + embed_dim=768, + obs_type='image', + env_num=max(collector_env_num, evaluator_env_num), + # collector_env_num=collector_env_num, + # evaluator_env_num=evaluator_env_num, + task_num=len(env_id_list), + ), + ), + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + update_per_collect=None, + 
replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed): + configs = [] + exp_name_prefix = f'data_unizero_mt_0705/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing + # evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + from lzero.entry import train_unizero_multitask + + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4' + ] + + action_space_size = 18 # Full action space + seed = 0 + collector_env_num = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(1e6) + reanalyze_ratio = 0. + batch_size = [32, 32, 32, 32] + num_unroll_steps = 10 + infer_context_length = 4 + + # ======== only for debug ======== + # collector_env_num = 3 + # n_episode = 3 + # evaluator_env_num = 2 + # num_simulations = 5 + # batch_size = [2, 2, 2, 2] + + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed) + + # Uncomment the desired training run + # train_unizero_multitask(configs[:1], seed=seed, max_env_step=max_env_step) # Pong + train_unizero_multitask(configs[:2], seed=seed, max_env_step=max_env_step) # Pong, MsPacman + # train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) # Pong, MsPacman, Seaquest, Boxing \ No newline at end of file diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index 84288feb5..f68f113d4 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -24,6 +24,8 @@ class AtariEnvLightZero(BaseEnv): _reward_space, obs, _eval_episode_return, has_reset, _seed, _dynamic_seed """ config = dict( + # (bool) Whether to use the full action space of the environment. Default is False. If set to True, the action space size is 18 for Atari. + full_action_space=False, # (int) The number of environment instances used for data collection. 
collector_env_num=8, # (int) The number of environment instances used for evaluator. diff --git a/zoo/atari/envs/atari_wrappers.py b/zoo/atari/envs/atari_wrappers.py index f38aa24d6..265ef31ac 100644 --- a/zoo/atari/envs/atari_wrappers.py +++ b/zoo/atari/envs/atari_wrappers.py @@ -93,9 +93,9 @@ def wrap_lightzero(config: EasyDict, episode_life: bool, clip_rewards: bool) -> - env (:obj:`gym.Env`): The wrapped Atari environment with the given configurations. """ if config.render_mode_human: - env = gym.make(config.env_id, render_mode='human') + env = gym.make(config.env_id, render_mode='human', full_action_space=config.full_action_space) else: - env = gym.make(config.env_id, render_mode='rgb_array') + env = gym.make(config.env_id, render_mode='rgb_array', full_action_space=config.full_action_space) assert 'NoFrameskip' in env.spec.id if hasattr(config, 'save_replay') and config.save_replay \ and hasattr(config, 'replay_path') and config.replay_path is not None: From 8769a5ce0b7ec1cc409ae345273f7e6f56d60411 Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Mon, 8 Jul 2024 16:41:44 +0800 Subject: [PATCH 02/13] polish(pu): polish unizero_multitask config --- lzero/entry/train_unizero_multitask.py | 2 +- lzero/policy/unizero_multitask.py | 12 ++++++------ zoo/atari/config/atari_unizero_multitask_config.py | 13 +++++++------ 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/lzero/entry/train_unizero_multitask.py b/lzero/entry/train_unizero_multitask.py index 0ba9d2233..e93b1ecff 100644 --- a/lzero/entry/train_unizero_multitask.py +++ b/lzero/entry/train_unizero_multitask.py @@ -54,7 +54,7 @@ def train_unizero_multitask( assert create_cfg.policy.type in ['unizero_multitask'], "train_unizero entry now only supports 'unizero'" # Set device based on CUDA availability - cfg.policy.device = cfg.policy.model.world_model.device if torch.cuda.is_available() else 'cpu' + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' logging.info(f'cfg.policy.device: {cfg.policy.device}') # Compile the configuration diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index 6360f8dd7..d752fe148 100644 --- a/lzero/policy/unizero_multitask.py +++ b/lzero/policy/unizero_multitask.py @@ -20,7 +20,7 @@ sys.path.append('/Users/puyuan/code/LibMTL/') # from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect # from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect -from LibMTL.weighting.FAMO_unizero import FAMO as GradCorrect +# from LibMTL.weighting.FAMO_unizero import FAMO as GradCorrect # from LibMTL.weighting.abstract_weighting import AbsWeighting @@ -457,12 +457,12 @@ def _init_learn(self) -> None: # ) # 将 wrapped_model 作为 share_model 传递给 GradCorrect self.task_num = self._cfg.task_num - self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device) - self.grad_correct.init_param() # 初始化MoCo参数 - self.grad_correct.rep_grad = False + # self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device) + # self.grad_correct.init_param() # 初始化MoCo参数 + # self.grad_correct.rep_grad = False # self.grad_correct.set_min_losses(torch.tensor([0. for i in range(self.task_num)], device=self._cfg.device)) # only for FAMO - self.curr_min_loss = torch.tensor([0. for i in range(self.task_num)], device=self._cfg.device) - self.grad_correct.prev_loss = self.curr_min_loss + # self.curr_min_loss = torch.tensor([0. 
for i in range(self.task_num)], device=self._cfg.device) + # self.grad_correct.prev_loss = self.curr_min_loss # @profile def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index a182cc6a0..e69077db7 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -15,8 +15,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu manager=dict(shared_memory=False, ), full_action_space=True, # ===== only for debug ===== - collect_max_episode_steps=int(50), - eval_max_episode_steps=int(50), + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), ), policy=dict( grad_correct_params=dict( @@ -39,7 +39,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu max_blocks=num_unroll_steps, max_tokens=2 * num_unroll_steps, context_length=2 * infer_context_length, - device='cpu', # 'cuda', + # device='cpu', # 'cuda', + device='cuda', # 'cuda', action_space_size=action_space_size, num_layers=4, # NOTE num_heads=8, @@ -69,7 +70,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed): configs = [] - exp_name_prefix = f'data_unizero_mt_0705/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' + exp_name_prefix = f'data_unizero_mt_0708/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' for task_id, env_id in enumerate(env_id_list): config = create_config( @@ -140,5 +141,5 @@ def create_env_manager(): # Uncomment the desired training run # train_unizero_multitask(configs[:1], seed=seed, max_env_step=max_env_step) # Pong - train_unizero_multitask(configs[:2], seed=seed, max_env_step=max_env_step) # Pong, MsPacman - # train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) # Pong, MsPacman, Seaquest, Boxing \ No newline at end of file + # train_unizero_multitask(configs[:2], seed=seed, max_env_step=max_env_step) # Pong, MsPacman + train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) # Pong, MsPacman, Seaquest, Boxing \ No newline at end of file From c342ce17af1e708784522c2c8b7bd2ef9bed0cff Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Thu, 11 Jul 2024 20:18:29 +0800 Subject: [PATCH 03/13] fix(pu): fix empty_keys_values in init_infer --- lzero/model/unizero_world_models/transformer.py | 5 ++++- lzero/model/unizero_world_models/world_model_multitask.py | 2 +- zoo/atari/config/atari_unizero_multitask_config.py | 6 ++++-- zoo/atari/envs/atari_lightzero_env.py | 1 + 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 714bc13d6..ee43056f0 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -205,7 +205,10 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, B, T, C = x.size() if kv_cache is not None: b, nh, L, c = kv_cache.shape - assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." 
+ try: + assert nh == self.num_heads and b == B and c * nh == C, "Cache dimensions do not match input dimensions." + except Exception as e: + print('debug') else: L = 0 diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index f1d6733d1..fe2128fac 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -484,7 +484,7 @@ def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: tor if current_obs_embeddings is not None: if max(buffer_action) == -1: # First step in an episode - self.keys_values_wm = self.transformer.generate_empty_keys_values(n=n, + self.keys_values_wm = self.transformer.generate_empty_keys_values(n=current_obs_embeddings.shape[0], max_tokens=self.context_length) # print(f"current_obs_embeddings.device: {current_obs_embeddings.device}") outputs_wm = self.forward({'obs_embeddings': current_obs_embeddings}, diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index e69077db7..75ee2bcac 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -17,6 +17,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # ===== only for debug ===== # collect_max_episode_steps=int(50), # eval_max_episode_steps=int(50), + # collect_max_episode_steps=int(500), + # eval_max_episode_steps=int(500), ), policy=dict( grad_correct_params=dict( @@ -70,7 +72,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed): configs = [] - exp_name_prefix = f'data_unizero_mt_0708/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' + exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' for task_id, env_id in enumerate(env_id_list): config = create_config( @@ -135,7 +137,7 @@ def create_env_manager(): # n_episode = 3 # evaluator_env_num = 2 # num_simulations = 5 - # batch_size = [2, 2, 2, 2] + # batch_size = [4, 4, 4, 4] configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed) diff --git a/zoo/atari/envs/atari_lightzero_env.py b/zoo/atari/envs/atari_lightzero_env.py index f68f113d4..d66ed383d 100644 --- a/zoo/atari/envs/atari_lightzero_env.py +++ b/zoo/atari/envs/atari_lightzero_env.py @@ -153,6 +153,7 @@ def step(self, action: int) -> BaseEnvTimestep: observation = self.observe() if done: info['eval_episode_return'] = self._eval_episode_return + print(f'one episode of {self.cfg.env_id} done') return BaseEnvTimestep(observation, self.reward, done, info) From 6eb772a789db93447675382e9d4ed054b52fc734 Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Thu, 11 Jul 2024 21:06:09 +0800 Subject: [PATCH 04/13] feature(pu): add softmoe head option in unizero_multitask --- .../world_model_multitask.py | 103 +++++++++++++++--- .../config/atari_unizero_multitask_config.py | 3 +- 2 files changed, 87 insertions(+), 19 deletions(-) diff --git a/lzero/model/unizero_world_models/world_model_multitask.py 
b/lzero/model/unizero_world_models/world_model_multitask.py index fe2128fac..60073a79b 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -53,6 +53,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.head_rewards_multi_task = nn.ModuleList() self.head_observations_multi_task = nn.ModuleList() + self.num_experts_in_softmoe = config.num_experts_in_softmoe + # Move all modules to the specified device print(f"self.config.device: {self.config.device}") self.to(self.config.device) @@ -72,29 +74,38 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") - for task_id in range(self.task_num): # TODO - action_space_size = self.action_space_size # TODO:====================== - # action_space_size=18 # TODO:====================== - self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) - self.head_policy_multi_task.append(self.head_policy) + if self.num_experts_in_softmoe == -1: + print('We use normal head') + # TODO: Normal Head + for task_id in range(self.task_num): # TODO + action_space_size = self.action_space_size # TODO:====================== + # action_space_size=18 # TODO:====================== + self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) + self.head_policy_multi_task.append(self.head_policy) + + self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) + self.head_value_multi_task.append(self.head_value) + + self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + self.head_rewards_multi_task.append(self.head_rewards) + + self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, + self.obs_per_embdding_dim, + self.sim_norm) # NOTE: we add a sim_norm to the head for observations + self.head_observations_multi_task.append(self.head_observations) + else: + print(f'We use softmoe head, self.num_experts_in_softmoe is {self.num_experts_in_softmoe}') + # Dictionary to store SoftMoE instances + self.soft_moe_instances = {} - self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) - self.head_value_multi_task.append(self.head_value) + # Create softmoe head modules + self.create_head_modules_softmoe() - self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) self.head_rewards_multi_task.append(self.head_rewards) - - self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, - self.obs_per_embdding_dim, - self.sim_norm) # NOTE: we add a sim_norm to the head for observations self.head_observations_multi_task.append(self.head_observations) - # Head modules - # self.head_rewards = self._create_head(self.act_tokens_pattern, self.support_size) - # self.head_observations = self._create_head(self.all_but_last_latent_state_pattern, self.obs_per_embdding_dim, - # self.sim_norm) # NOTE: we add a sim_norm to the head for observations - # self.head_policy = self._create_head(self.value_policy_tokens_pattern, self.action_space_size) - # self.head_value = self._create_head(self.value_policy_tokens_pattern, self.support_size) # Apply weight 
initialization, the order is important self.apply(lambda module: init_weights(module, norm_type=self.config.norm_type)) @@ -161,6 +172,62 @@ def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=Non head_module=nn.Sequential(*modules) ) + def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, soft_moe=None) -> Head: + """Create softmoe head modules for the transformer.""" + modules = [ + soft_moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + + def get_soft_moe(self, name): + """Get or create a SoftMoE instance""" + from soft_moe_pytorch import SoftMoE + if name not in self.soft_moe_instances: + self.soft_moe_instances[name] = SoftMoE( + dim=self.embed_dim, + seq_len=20, # TODO + num_experts=self.num_experts_in_softmoe, + ) + return self.soft_moe_instances[name] + + def create_head_modules_softmoe(self): + """Create all softmoe head modules""" + # Rewards head + self.head_rewards = self._create_head_softmoe( + self.act_tokens_pattern, + self.support_size, + soft_moe=self.get_soft_moe("rewards_soft_moe") + ) + + # Observations head + self.head_observations = self._create_head_softmoe( + self.all_but_last_latent_state_pattern, + self.obs_per_embdding_dim, + norm_layer=self.sim_norm, # NOTE + soft_moe=self.get_soft_moe("observations_soft_moe") + ) + + # Policy head + self.head_policy = self._create_head_softmoe( + self.value_policy_tokens_pattern, + self.action_space_size, + soft_moe=self.get_soft_moe("policy_soft_moe") + ) + + # Value head + self.head_value = self._create_head_softmoe( + self.value_policy_tokens_pattern, + self.support_size, + soft_moe=self.get_soft_moe("value_soft_moe") + ) + def _initialize_last_layer(self) -> None: """Initialize the last linear layer.""" last_linear_layer_init_zero = True diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index 75ee2bcac..fec1f4013 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -52,6 +52,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # collector_env_num=collector_env_num, # evaluator_env_num=evaluator_env_num, task_num=len(env_id_list), + num_experts_in_softmoe=4, # NOTE ), ), cuda=True, @@ -72,7 +73,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed): configs = [] - exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' + exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head-softmoe4_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' for task_id, env_id in enumerate(env_id_list): config = create_config( From 71f55b4c8da523b3ea7d63f66fe1ccee19905d7c Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Fri, 12 Jul 2024 12:54:02 +0800 Subject: [PATCH 05/13] fix(pu): fix unizero reset in muzero_collector --- lzero/policy/unizero_multitask.py | 15 ++++++++------- lzero/worker/muzero_collector.py | 4 ++-- .../config/atari_unizero_multitask_config.py | 14 ++++++++------ 3 files changed, 18 
insertions(+), 15 deletions(-) diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index d752fe148..17c653ce8 100644 --- a/lzero/policy/unizero_multitask.py +++ b/lzero/policy/unizero_multitask.py @@ -933,7 +933,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 return output - def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True) -> None: """ Overview: This method resets the collection process for a specific environment. It clears caches and memory @@ -955,8 +955,8 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in # print('collector: last_batch_obs, last_batch_action reset()', self.last_batch_obs.shape) # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # if env_id is None or isinstance(env_id, list): + # return # Determine the clear interval based on the environment's sample type clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 @@ -978,6 +978,7 @@ def _reset_collect(self, env_id: int = None, current_steps: int = None, reset_in print('collector: collect_model clear()') print(f'eps_steps_lst[{env_id}]: {current_steps}') + # TODO: check its correctness self._reset_target_model() @@ -1001,7 +1002,7 @@ def _reset_target_model(self) -> None: torch.cuda.empty_cache() print('collector: target_model past_kv_cache.clear()') - def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_data: bool = True) -> None: + def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True) -> None: """ Overview: This method resets the evaluation process for a specific environment. 
It clears caches and memory @@ -1020,11 +1021,11 @@ def _reset_eval(self, env_id: int = None, current_steps: int = None, reset_init_ self._cfg.device ) self.last_batch_action = [-1 for _ in range(self._cfg.evaluator_env_num)] - # print('evaluator: last_batch_obs, last_batch_action reset()', self.last_batch_obs.shape) + print('evaluator: last_batch_obs, last_batch_action reset()', self.last_batch_obs.shape) # Return immediately if env_id is None or a list - if env_id is None or isinstance(env_id, list): - return + # if env_id is None or isinstance(env_id, list): + # return # Determine the clear interval based on the environment's sample type clear_interval = 2000 if getattr(self._cfg, 'sample_type', '') == 'episode' else 200 diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index 81bc46047..54422cfa7 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -547,9 +547,9 @@ def collect(self, completed_value_lst[env_id] += np.mean(np.array(completed_value_dict[env_id])) eps_steps_lst[env_id] += 1 - if self._policy.get_attribute('cfg').type == 'unizero': + if self._policy.get_attribute('cfg').type in ['unizero', 'unizero_multitask']: # only for UniZero now - self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) + self._policy.reset(env_id=env_id, current_steps=eps_steps_lst[env_id], reset_init_data=False) # NOTE: reset_init_data=False total_transitions += 1 diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index fec1f4013..5e0ee7cf6 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -53,6 +53,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # evaluator_env_num=evaluator_env_num, task_num=len(env_id_list), num_experts_in_softmoe=4, # NOTE + # num_experts_in_softmoe=-1, # NOTE ), ), cuda=True, @@ -73,7 +74,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed): configs = [] - exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head-softmoe4_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head-softmoe4_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' + exp_name_prefix = f'data_unizero_mt_0711_debug/{len(env_id_list)}games_1-head-softmoe4_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' for task_id, env_id in enumerate(env_id_list): config = create_config( @@ -134,11 +136,11 @@ def create_env_manager(): infer_context_length = 4 # ======== only for debug ======== - # collector_env_num = 3 - # n_episode = 3 - # evaluator_env_num = 2 - # num_simulations = 5 - # batch_size = [4, 4, 4, 4] + collector_env_num = 3 + n_episode = 3 + evaluator_env_num = 2 + num_simulations = 5 + batch_size = [4, 4, 4, 4] configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed) From 445fd70947bc08a10321564d16c89ac2ae4193c2 Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Sun, 14 Jul 2024 18:12:24 +0800 Subject: [PATCH 06/13] polish(pu): polish unizero-multitask config --- 
lzero/model/unizero_world_models/utils.py | 10 +++- .../world_model_multitask.py | 31 ++++++----- lzero/policy/unizero.py | 2 +- lzero/policy/unizero_multitask.py | 45 +++++++++++----- .../config/atari_unizero_multitask_config.py | 53 +++++++++++++------ 5 files changed, 96 insertions(+), 45 deletions(-) diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index d6f529971..e5bb3eaf0 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -109,8 +109,14 @@ def init_weights(module, norm_type='BN'): module.bias.data.zero_() elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)): print(f"Init {module} using zero bias, 1 weight") - module.bias.data.zero_() - module.weight.data.fill_(1.0) + try: + module.bias.data.zero_() + except Exception as e: + print(e) + try: + module.weight.data.fill_(1.0) + except Exception as e: + print(e) elif isinstance(module, nn.BatchNorm2d): print(f"Init nn.BatchNorm2d using zero bias, 1 weight") module.weight.data.fill_(1.0) diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index 60073a79b..a997d3ebf 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -188,12 +188,19 @@ def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_l def get_soft_moe(self, name): """Get or create a SoftMoE instance""" - from soft_moe_pytorch import SoftMoE + # from soft_moe_pytorch import SoftMoE + # if name not in self.soft_moe_instances: + # self.soft_moe_instances[name] = SoftMoE( + # dim=self.embed_dim, + # seq_len=20, # TODO + # num_experts=self.num_experts_in_softmoe, + # ) + from soft_moe_pytorch import DynamicSlotsSoftMoE as SoftMoE if name not in self.soft_moe_instances: self.soft_moe_instances[name] = SoftMoE( dim=self.embed_dim, - seq_len=20, # TODO num_experts=self.num_experts_in_softmoe, + geglu = True ) return self.soft_moe_instances[name] @@ -406,16 +413,16 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 # one head or soft_moe - logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # N head - # logits_observations = self.head_observations_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) - # logits_rewards = self.head_rewards_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) - # logits_policy = self.head_policy_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) - # logits_value = self.head_value_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + # logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) + # logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + # logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + # logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + + # TODO: N head + logits_observations = self.head_observations_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = 
self.head_rewards_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
+        logits_policy = self.head_policy_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
+        logits_value = self.head_value_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps)
 
         # logits_ends is None
         return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value)
diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py
index 0f1c0ad21..3f074ccb6 100644
--- a/lzero/policy/unizero.py
+++ b/lzero/policy/unizero.py
@@ -152,7 +152,7 @@ class UniZeroPolicy(MuZeroPolicy):
         # (bool) Whether to use the pure policy to collect data.
         collect_with_pure_policy=False,
         # (int) The evaluation frequency.
-        eval_freq=int(2e3),
+        eval_freq=int(5e3),
         # (str) The sample type. Options are ['episode', 'transition'].
         sample_type='transition',
 
diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py
index 17c653ce8..d1945cedb 100644
--- a/lzero/policy/unizero_multitask.py
+++ b/lzero/policy/unizero_multitask.py
@@ -18,9 +18,9 @@
 from .utils import configure_optimizers_nanogpt
 
 sys.path.append('/Users/puyuan/code/LibMTL/')
-# from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect
+from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect
 # from LibMTL.weighting.CAGrad_unizero import CAGrad as GradCorrect
-# from LibMTL.weighting.FAMO_unizero import FAMO as GradCorrect
+# from LibMTL.weighting.FAMO_unizero import FAMO as GradCorrect  # NOTE: FAMO has bugs now
 
 # from LibMTL.weighting.abstract_weighting import AbsWeighting
 
@@ -159,7 +159,7 @@ class UniZeroMTPolicy(UniZeroPolicy):
             # (bool) whether to use res connection in dynamics.
             res_connection_in_dynamics=True,
             # (str) The type of normalization in MuZero model. Options are ['BN', 'LN']. Default to 'BN'.
-            norm_type='BN',
+            norm_type='LN',
             # (bool) Whether to analyze simulation normalization.
             analysis_sim_norm=False,
             # (int) The save interval of the model.
@@ -254,7 +254,7 @@ class UniZeroMTPolicy(UniZeroPolicy):
         # (bool) Whether to use the pure policy to collect data.
         collect_with_pure_policy=False,
         # (int) The evaluation frequency.
-        eval_freq=int(2e3),
+        eval_freq=int(5e3),
         # (str) The sample type. Options are ['episode', 'transition'].
         sample_type='transition',
 
@@ -434,10 +434,12 @@ def _init_learn(self) -> None:
         self.grad_norm_after = 0.
# Create the WrappedModel instance
+        # The head and nn.Embedding are not gradient-corrected
         # wrapped_model = WrappedModel(
         #     self._learn_model.world_model.tokenizer,
         #     self._learn_model.world_model.transformer
         # )
+        # The head is not gradient-corrected
         wrapped_model = WrappedModelV2(
             # self._learn_model.world_model.tokenizer,  # TODO
             self._learn_model.world_model.tokenizer.encoder[0],  # TODO: one encoder
@@ -446,21 +448,27 @@ def _init_learn(self) -> None:
             self._learn_model.world_model.task_emb,
             self._learn_model.world_model.act_embedding_table,
         )
-        # wrapped_model = WrappedModelV3(
+        # All parameters are shared, i.e., all parameters need gradient correction
+        # wrapped_model = WrappedModelV3(
         #     self._learn_model.world_model,
         # )
+        # The head and tokenizer.encoder are not gradient-corrected
         # wrapped_model = WrappedModelV4(
         #     self._learn_model.world_model.transformer,
         #     self._learn_model.world_model.pos_emb,
         #     self._learn_model.world_model.task_emb,
         #     self._learn_model.world_model.act_embedding_table,
         # )
+        # Pass wrapped_model to GradCorrect as the share_model
+        # ========= Initialize MoCo / CAGrad parameters =========
         self.task_num = self._cfg.task_num
-        # self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device)
-        # self.grad_correct.init_param()  # Initialize MoCo parameters
-        # self.grad_correct.rep_grad = False
-        # self.grad_correct.set_min_losses(torch.tensor([0. for i in range(self.task_num)], device=self._cfg.device))  # only for FAMO
+        self.grad_correct = GradCorrect(wrapped_model, self.task_num, self._cfg.device)
+        self.grad_correct.init_param()
+        self.grad_correct.rep_grad = False
+
+        # ========= only for FAMO =========
+        # self.grad_correct.set_min_losses(torch.tensor([0. for i in range(self.task_num)], device=self._cfg.device))
         # self.curr_min_loss = torch.tensor([0. for i in range(self.task_num)], device=self._cfg.device)
         # self.grad_correct.prev_loss = self.curr_min_loss
@@ -493,6 +501,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         # weighted_total_loss.requires_grad = True
         weighted_total_loss = 0.0  # Initialize to 0 to avoid in-place operations
+        latent_state_l2_norms_multi_task = []
+
         average_target_policy_entropy_multi_task = []
 
         losses_list = []  # Used to store the loss of each task
@@ -585,6 +596,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         value_loss = intermediate_losses['loss_value']
         latent_recon_loss = intermediate_losses['latent_recon_loss']
         perceptual_loss = intermediate_losses['perceptual_loss']
+        latent_state_l2_norms = intermediate_losses['latent_state_l2_norms']
+
 
         obs_loss_multi_task.append(obs_loss)
         reward_loss_multi_task.append(reward_loss)
@@ -595,12 +608,13 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
         value_loss_multi_task.append(value_loss)
         latent_recon_loss_multi_task.append(latent_recon_loss)
         perceptual_loss_multi_task.append(perceptual_loss)
+        latent_state_l2_norms_multi_task.append(latent_state_l2_norms)
 
         # Core learn model update step
         self._optimizer_world_model.zero_grad()
 
         # TODO MoCo
-        # Use MoCo to compute gradients and weights
+        # Use MoCo and CAGrad to compute gradients and weights
 
         # ============= for CAGrad and MoCo =============
         # lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params)
@@ -670,6 +684,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
             **generate_task_loss_dict(obs_loss_multi_task, 'obs_loss_task{}'),
             **generate_task_loss_dict(latent_recon_loss_multi_task, 'latent_recon_loss_task{}'),
             **generate_task_loss_dict(perceptual_loss_multi_task, 'perceptual_loss_task{}'),
+            **generate_task_loss_dict(latent_state_l2_norms_multi_task, 'latent_state_l2_norms_task{}'),
+
**generate_task_loss_dict(policy_loss_multi_task, 'policy_loss_task{}'), **generate_task_loss_dict(orig_policy_loss_multi_task, 'orig_policy_loss_task{}'), **generate_task_loss_dict(policy_entropy_multi_task, 'policy_entropy_task{}'), @@ -963,7 +979,7 @@ def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_ # Clear caches if the current steps are a multiple of the clear interval if current_steps % clear_interval == 0: - print(f'clear_interval: {clear_interval}') + # print(f'clear_interval: {clear_interval}') # Clear various caches in the collect model's world model world_model = self._collect_model.world_model @@ -978,9 +994,9 @@ def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_ print('collector: collect_model clear()') print(f'eps_steps_lst[{env_id}]: {current_steps}') - - # TODO: check its correctness - self._reset_target_model() + + # TODO: check its correctness + self._reset_target_model() def _reset_target_model(self) -> None: """ @@ -1078,6 +1094,7 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: 'reward_loss', 'value_loss', 'perceptual_loss', + 'latent_state_l2_norms', 'lambd', ] diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index 5e0ee7cf6..17fd8a710 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -2,7 +2,7 @@ from copy import deepcopy # from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map -def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length): +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type): return EasyDict(dict( env=dict( stop_value=int(1e6), @@ -37,6 +37,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu model=dict( observation_shape=(3, 64, 64), action_space_size=action_space_size, + norm_type=norm_type, world_model_cfg=dict( max_blocks=num_unroll_steps, max_tokens=2 * num_unroll_steps, @@ -52,8 +53,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # collector_env_num=collector_env_num, # evaluator_env_num=evaluator_env_num, task_num=len(env_id_list), - num_experts_in_softmoe=4, # NOTE - # num_experts_in_softmoe=-1, # NOTE + # num_experts_in_softmoe=4, # NOTE + num_experts_in_softmoe=-1, # NOTE ), ), cuda=True, @@ -72,10 +73,13 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu ), )) -def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed): +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed): configs = [] - # exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head-softmoe4_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' - exp_name_prefix = f'data_unizero_mt_0711_debug/{len(env_id_list)}games_1-head-softmoe4_1-encoder-LN_lsd768-nlayer4-nh8_seed{seed}/' + # exp_name_prefix = 
f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head-softmoe4_1-encoder-{norm_type}_lsd768-nlayer4-nh8_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_1-head-softmoe4-dynamics_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_8-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_MOCO_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' for task_id, env_id in enumerate(env_id_list): config = create_config( @@ -91,7 +95,8 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod reanalyze_ratio, batch_size, num_unroll_steps, - infer_context_length + infer_context_length, + norm_type ) config.policy.task_id = task_id config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" @@ -115,14 +120,25 @@ def create_env_manager(): if __name__ == "__main__": from lzero.entry import train_unizero_multitask - + # TODO env_id_list = [ 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', - 'BoxingNoFrameskip-v4' + 'BoxingNoFrameskip-v4', ] + # env_id_list = [ + # 'PongNoFrameskip-v4', + # 'MsPacmanNoFrameskip-v4', + # 'SeaquestNoFrameskip-v4', + # 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', + # 'CrazyClimberNoFrameskip-v4', + # 'BreakoutNoFrameskip-v4', + # 'QbertNoFrameskip-v4', + # ] + action_space_size = 18 # Full action space seed = 0 collector_env_num = 8 @@ -131,18 +147,23 @@ def create_env_manager(): num_simulations = 50 max_env_step = int(1e6) reanalyze_ratio = 0. 
- batch_size = [32, 32, 32, 32] + # batch_size = [32, 32, 32, 32] + max_batch_size = 1500 + batch_size = [1500/len(env_id_list) for i in range(len(env_id_list))] num_unroll_steps = 10 infer_context_length = 4 + norm_type = 'LN' + # norm_type = 'BN' + # ======== only for debug ======== - collector_env_num = 3 - n_episode = 3 - evaluator_env_num = 2 - num_simulations = 5 - batch_size = [4, 4, 4, 4] + # collector_env_num = 3 + # n_episode = 3 + # evaluator_env_num = 2 + # num_simulations = 5 + # batch_size = [4, 4, 4, 4] - configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, seed) + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed) # Uncomment the desired training run # train_unizero_multitask(configs[:1], seed=seed, max_env_step=max_env_step) # Pong From 4954581604bcb8e19afc51c6964e49dd163ef10b Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Tue, 16 Jul 2024 16:49:28 +0800 Subject: [PATCH 07/13] fix(pu): fix replay ratio --- lzero/entry/train_unizero_multitask.py | 2 +- lzero/policy/unizero_multitask.py | 4 ++-- zoo/atari/config/atari_unizero_multitask_config.py | 4 +++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/lzero/entry/train_unizero_multitask.py b/lzero/entry/train_unizero_multitask.py index e93b1ecff..3697d1a9f 100644 --- a/lzero/entry/train_unizero_multitask.py +++ b/lzero/entry/train_unizero_multitask.py @@ -174,7 +174,7 @@ def train_unizero_multitask( update_per_collect = cfg.policy.update_per_collect if update_per_collect is None: collected_transitions_num = sum(len(game_segment) for game_segment in new_data[0]) - update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) # Update replay buffer replay_buffer.push_game_segments(new_data) diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index d1945cedb..d24c53e1e 100644 --- a/lzero/policy/unizero_multitask.py +++ b/lzero/policy/unizero_multitask.py @@ -278,10 +278,10 @@ class UniZeroMTPolicy(UniZeroPolicy): # collect data -> update policy-> collect data -> ... # For different env, we have different episode_length, # we usually set update_per_collect = collector_env_num * episode_length / batch_size * reuse_factor. - # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.model_update_ratio automatically. + # If we set update_per_collect=None, we will set update_per_collect = collected_transitions_num * cfg.policy.replay_ratio automatically. update_per_collect=None, # (float) The ratio of the collected data used for training. Only effective when ``update_per_collect`` is not None. - model_update_ratio=0.25, + replay_ratio=0.25, # (int) Minibatch size for one gradient descent. batch_size=256, # (str) Optimizer for training policy network. 
['SGD', 'Adam'] diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index 17fd8a710..c5eb62b01 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -21,6 +21,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # eval_max_episode_steps=int(500), ), policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000,),),), # default is 10000 grad_correct_params=dict( # for MoCo MoCo_beta=0.5, @@ -76,6 +77,7 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed): configs = [] # exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head-softmoe4_1-encoder-{norm_type}_lsd768-nlayer4-nh8_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_1-head-softmoe4-dynamics_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_8-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_MOCO_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' @@ -149,7 +151,7 @@ def create_env_manager(): reanalyze_ratio = 0. # batch_size = [32, 32, 32, 32] max_batch_size = 1500 - batch_size = [1500/len(env_id_list) for i in range(len(env_id_list))] + batch_size = [int(1500/len(env_id_list)) for i in range(len(env_id_list))] num_unroll_steps = 10 infer_context_length = 4 norm_type = 'LN' From 44304bfd8fa380700157deb6bf6451b96299dd05 Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Tue, 16 Jul 2024 20:35:22 +0800 Subject: [PATCH 08/13] feature(pu): add moe option of feedforward in transformer backbone --- lzero/model/unizero_world_models/moe.py | 34 +++++++++++++++++++ .../model/unizero_world_models/transformer.py | 34 ++++++++++++++----- .../world_model_multitask.py | 12 +++---- lzero/policy/unizero_multitask.py | 5 ++- .../config/atari_unizero_multitask_config.py | 20 +++++++---- 5 files changed, 80 insertions(+), 25 deletions(-) create mode 100644 lzero/model/unizero_world_models/moe.py diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py new file mode 100644 index 000000000..af2f8dac2 --- /dev/null +++ b/lzero/model/unizero_world_models/moe.py @@ -0,0 +1,34 @@ +import dataclasses +from typing import List + +import torch +import torch.nn.functional as F +from simple_parsing.helpers import Serializable +from torch import nn + + +@dataclasses.dataclass +class MoeArgs(Serializable): + num_experts: int + num_experts_per_tok: int + + +class MoeLayer(nn.Module): + def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1): + super().__init__() + assert len(experts) > 0 + self.experts = nn.ModuleList(experts) + self.gate = gate + self.num_experts_per_tok = num_experts_per_tok + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + gate_logits = self.gate(inputs) + weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok) + weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) + results = torch.zeros_like(inputs) + for i, expert in enumerate(self.experts): + # batch_idx, 
nth_expert = torch.where(selected_experts == i) + # results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(inputs[batch_idx]) + batch_idx, token_idx, nth_expert = torch.where(selected_experts == i) + results[batch_idx, token_idx] += weights[batch_idx, token_idx, nth_expert][:, None] * expert(inputs[batch_idx, token_idx]) + return results \ No newline at end of file diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index ee43056f0..0ad288fe2 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -13,7 +13,7 @@ from torch.nn import functional as F from .kv_caching import KeysValues - +from .moe import MoeLayer @dataclass class TransformerConfig: @@ -121,12 +121,28 @@ def __init__(self, config: TransformerConfig) -> None: self.ln1 = nn.LayerNorm(config.embed_dim) self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = SelfAttention(config) - self.mlp = nn.Sequential( - nn.Linear(config.embed_dim, 4 * config.embed_dim), - nn.GELU(approximate='tanh'), - nn.Linear(4 * config.embed_dim, config.embed_dim), - nn.Dropout(config.resid_pdrop), - ) + if config.moe_in_transformer: + self.mlp = nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) + self.feed_forward = MoeLayer( + experts=[self.mlp for _ in range(config.num_experts_of_moe_in_transformer)], + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + print("="*20) + print('use moe in feed_forward of transformer') + print("="*20) + else: + self.feed_forward = nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: @@ -144,10 +160,10 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths) if self.gru_gating: x = self.gate1(x, x_attn) - x = self.gate2(x, self.mlp(self.ln2(x))) + x = self.gate2(x, self.feed_forward(self.ln2(x))) else: x = x + x_attn - x = x + self.mlp(self.ln2(x)) + x = x + self.feed_forward(self.ln2(x)) return x diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index a997d3ebf..e4cd24b44 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -53,7 +53,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.head_rewards_multi_task = nn.ModuleList() self.head_observations_multi_task = nn.ModuleList() - self.num_experts_in_softmoe = config.num_experts_in_softmoe + self.num_experts_in_softmoe_head = config.num_experts_in_softmoe_head # Move all modules to the specified device print(f"self.config.device: {self.config.device}") @@ -74,7 +74,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") - if self.num_experts_in_softmoe == -1: + if 
self.num_experts_in_softmoe_head == -1: print('We use normal head') # TODO: Normal Head for task_id in range(self.task_num): # TODO @@ -94,7 +94,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.sim_norm) # NOTE: we add a sim_norm to the head for observations self.head_observations_multi_task.append(self.head_observations) else: - print(f'We use softmoe head, self.num_experts_in_softmoe is {self.num_experts_in_softmoe}') + print(f'We use softmoe head, self.num_experts_in_softmoe_head is {self.num_experts_in_softmoe_head}') # Dictionary to store SoftMoE instances self.soft_moe_instances = {} @@ -193,13 +193,13 @@ def get_soft_moe(self, name): # self.soft_moe_instances[name] = SoftMoE( # dim=self.embed_dim, # seq_len=20, # TODO - # num_experts=self.num_experts_in_softmoe, + # num_experts=self.num_experts_in_softmoe_head, # ) from soft_moe_pytorch import DynamicSlotsSoftMoE as SoftMoE if name not in self.soft_moe_instances: self.soft_moe_instances[name] = SoftMoE( dim=self.embed_dim, - num_experts=self.num_experts_in_softmoe, + num_experts=self.num_experts_in_softmoe_head, geglu = True ) return self.soft_moe_instances[name] @@ -412,7 +412,7 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu # Generate logits # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - # one head or soft_moe + # TODO: one head or soft_moe # logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) # logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) # logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index d24c53e1e..67912d5c3 100644 --- a/lzero/policy/unizero_multitask.py +++ b/lzero/policy/unizero_multitask.py @@ -628,9 +628,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # self.grad_correct.update(curr_loss.detach()) # self.grad_correct.prev_loss = curr_loss.detach() - # ============= 不使用梯度矫正的情况 ============= + # ============= TODO: 不使用梯度矫正的情况 ============= lambd = torch.tensor([0. 
for i in range(self.task_num)], device=self._cfg.device) - weighted_total_loss.backward() # ========== for debugging ========== @@ -1097,7 +1096,7 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: 'latent_state_l2_norms', 'lambd', ] - + num_tasks = self.task_num # If the number of tasks is provided, extend the monitored variables list with task-specific variables if num_tasks is not None: for var in task_specific_vars: diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index c5eb62b01..0e0f23a8c 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -54,8 +54,11 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # collector_env_num=collector_env_num, # evaluator_env_num=evaluator_env_num, task_num=len(env_id_list), - # num_experts_in_softmoe=4, # NOTE - num_experts_in_softmoe=-1, # NOTE + # num_experts_in_softmoe_head=4, # NOTE + num_experts_in_softmoe_head=-1, # NOTE + # moe_in_transformer=True, + moe_in_transformer=False, + num_experts_of_moe_in_transformer=4, ), ), cuda=True, @@ -78,10 +81,13 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod configs = [] # exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head-softmoe4_1-encoder-{norm_type}_lsd768-nlayer4-nh8_seed{seed}/' - # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_1-head-softmoe4-dynamics_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716_debug/{len(env_id_list)}games_1-head-softmoe4-dynamics_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_8-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' - # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_MOCO_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' - exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_CAGrad_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_MoCo_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + + exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_1-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' for task_id, env_id in enumerate(env_id_list): config = create_config( @@ -127,7 +133,7 @@ def create_env_manager(): 'PongNoFrameskip-v4', 'MsPacmanNoFrameskip-v4', 'SeaquestNoFrameskip-v4', - 'BoxingNoFrameskip-v4', + 'BoxingNoFrameskip-v4' ] # env_id_list = [ @@ -158,7 +164,7 @@ def create_env_manager(): # norm_type = 'BN' - # ======== only for debug ======== + # ======== TODO: only for debug ======== # collector_env_num = 3 # n_episode = 3 # evaluator_env_num = 2 From d6be21a2adf0cab8b5df6586b4160d1e526318ca Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Wed, 17 Jul 2024 16:47:41 +0800 Subject: [PATCH 09/13] feature(pu): add value_priority in unizero_multitask --- lzero/entry/train_unizero_multitask.py | 39 +++++++++++++++ 
lzero/mcts/buffer/game_buffer.py | 47 ++++++++++--------- lzero/mcts/buffer/game_buffer_muzero.py | 1 + lzero/model/unizero_world_models/utils.py | 2 +- .../world_model_multitask.py | 17 ++++++- lzero/policy/unizero_multitask.py | 24 ++++++++-- requirements.txt | 1 + .../config/atari_unizero_multitask_config.py | 24 ++++++---- 8 files changed, 118 insertions(+), 37 deletions(-) diff --git a/lzero/entry/train_unizero_multitask.py b/lzero/entry/train_unizero_multitask.py index 3697d1a9f..bc48fb08e 100644 --- a/lzero/entry/train_unizero_multitask.py +++ b/lzero/entry/train_unizero_multitask.py @@ -4,6 +4,7 @@ from typing import Tuple, Optional, List import torch +import numpy as np from ding.config import compile_config from ding.envs import create_env_manager, get_vec_env_setting from ding.policy import create_policy @@ -126,6 +127,7 @@ def train_unizero_multitask( evaluators.append(evaluator) learner.call_hook('before_run') + value_priority_tasks = {} while True: # Precompute positional embedding matrices for collect/eval (not training) @@ -207,6 +209,43 @@ def train_unizero_multitask( if train_data_multi_task: log_vars = learner.train(train_data_multi_task, envstep_multi_task) + + if cfg.policy.use_priority: + for task_id, replay_buffer in enumerate(game_buffers): + # Update the priority for the task-specific replay buffer. + replay_buffer.update_priority(train_data_multi_task[task_id], log_vars[0][f'value_priority_task{task_id}']) + + # Retrieve the updated priorities for the current task. + current_priorities = log_vars[0][f'value_priority_task{task_id}'] + + # Calculate statistics: mean, running mean, standard deviation for the priorities. + mean_priority = np.mean(current_priorities) + std_priority = np.std(current_priorities) + + # Using exponential moving average for running mean (alpha is the smoothing factor). + alpha = 0.1 # You can adjust this smoothing factor as needed. + if f'running_mean_priority_task{task_id}' not in value_priority_tasks: + # Initialize running mean if it does not exist. + value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority + else: + # Update running mean. + value_priority_tasks[f'running_mean_priority_task{task_id}'] = ( + alpha * mean_priority + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}'] + ) + + # Calculate the normalized priority using the running mean. + running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}'] + normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6) + + # Store the normalized priorities back to the replay buffer (if needed). + # replay_buffer.update_priority(train_data_multi_task[task_id], normalized_priorities) + + # Log the statistics if the print_task_priority_logs flag is set. 
+ if cfg.policy.print_task_priority_logs: + print(f"Task {task_id} - Mean Priority: {mean_priority:.8f}, " + f"Running Mean Priority: {running_mean_priority:.8f}, " + f"Standard Deviation: {std_priority:.8f}") + if all(collector.envstep >= max_env_step for collector in collectors) or learner.train_iter >= max_train_iter: break diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index ec1babb6b..8ee143f03 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -102,47 +102,46 @@ def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]: """ pass - def _sample_orig_data(self, batch_size: int) -> Tuple: + def _sample_orig_data(self, batch_size: int, print_priority_logs: bool = False) -> Tuple: """ Overview: - sample orig_data that contains: - game_segment_list: a list of game segments - pos_in_game_segment_list: transition index in game (relative index) - batch_index_list: the index of start transition of sampled minibatch in replay buffer - weights_list: the weight concerning the priority - make_time: the time the batch is made (for correctly updating replay buffer when data is deleted) + Sample original data which includes: + - game_segment_list: A list of game segments. + - pos_in_game_segment_list: Transition index in the game (relative index). + - batch_index_list: The index of the start transition of the sampled mini-batch in the replay buffer. + - weights_list: The weight concerning the priority. + - make_time: The time the batch is made (for correctly updating the replay buffer when data is deleted). Arguments: - - batch_size (:obj:`int`): batch size - - beta: float the parameter in PER for calculating the priority + - batch_size (:obj:`int`): The size of the batch. + - print_priority_logs (:obj:`bool`): Whether to print logs related to priority statistics, defaults to False. 
""" - assert self._beta > 0 + assert self._beta > 0, "Beta should be greater than 0" num_of_transitions = self.get_num_of_transitions() - if self._cfg.use_priority is False: + if not self._cfg.use_priority: + # If priority is not used, set all priorities to 1 self.game_pos_priorities = np.ones_like(self.game_pos_priorities) # +1e-6 for numerical stability probs = self.game_pos_priorities ** self._alpha + 1e-6 probs /= probs.sum() - # sample according to transition index - # TODO(pu): replace=True - # print(f"num transitions is {num_of_transitions}") - # print(f"length of probs is {len(probs)}") + # Sample according to transition index with replacement batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False) - - if self._cfg.reanalyze_outdated is True: - # NOTE: used in reanalyze part + + if self._cfg.reanalyze_outdated: + # Sort the batch indices if reanalyze is enabled batch_index_list.sort() - + + # Calculate weights for the sampled transitions weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta) - weights_list /= weights_list.max() + weights_list /= weights_list.max() # Normalize weights game_segment_list = [] pos_in_game_segment_list = [] for idx in batch_index_list: game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx] - game_segment_idx -= self.base_idx + game_segment_idx -= self.base_idx # Adjust index based on base index game_segment = self.game_segment_buffer[game_segment_idx] game_segment_list.append(game_segment) @@ -151,6 +150,12 @@ def _sample_orig_data(self, batch_size: int) -> Tuple: make_time = [time.time() for _ in range(len(batch_index_list))] orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time) + + if print_priority_logs: + print(f"Sampled batch indices: {batch_index_list}") + print(f"Sampled priorities: {self.game_pos_priorities[batch_index_list]}") + print(f"Sampled weights: {weights_list}") + return orig_data def _sample_orig_reanalyze_data(self, batch_size: int) -> Tuple: diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 72876b9a5..d67890bf1 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -745,6 +745,7 @@ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) - NOTE: train_data = [current_batch, target_batch] current_batch = [obs_list, action_list, improved_policy_list(only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list] + target_batch = [batch_rewards, batch_target_values, batch_target_policies] """ indices = train_data[0][-3] metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities} diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index e5bb3eaf0..b04375ded 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -185,7 +185,7 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg self.loss_total += self.perceptual_loss_weight * v self.intermediate_losses = { - k: v if isinstance(v, dict) else (v if isinstance(v, float) else v.item()) + k: v if isinstance(v, dict) or isinstance(v, np.ndarray) else (v if isinstance(v, float) else v.item()) for k, v in kwargs.items() } diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index e4cd24b44..93ab87619 100644 --- 
a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -1015,8 +1015,9 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # Uncomment the lines below for visual analysis # original_images, reconstructed_images = batch['observations'], reconstructed_images # target_policy = batch['target_policy'] - # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( - # batch['observations'].shape[0], batch['observations'].shape[1], 1) + # ==== for value priority ==== + target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + batch['observations'].shape[0], batch['observations'].shape[1], 1) # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( # batch['observations'].shape[0], batch['observations'].shape[1], 1) # ========== for visualization ========== @@ -1142,6 +1143,17 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') + + # ============ for value priority ============ + # transform the scaled value or its categorical representation to its original value, + # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. + original_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape( + batch['observations'].shape[0], batch['observations'].shape[1], 1) + # calculate the new priorities for each transition. + from torch.nn import L1Loss + value_priority = L1Loss(reduction='none')(original_value[:,0], target_predict_value[:, 0]) # TODO: mix of mean and sum + value_priority = value_priority.data.cpu().numpy() + 1e-6 + # ============ for value priority ============ # Compute timesteps timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) @@ -1212,6 +1224,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar dormant_ratio_encoder=dormant_ratio_encoder, dormant_ratio_world_model=dormant_ratio_world_model, latent_state_l2_norms=latent_state_l2_norms, + value_priority=value_priority, ) def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index 67912d5c3..a8f6f66c2 100644 --- a/lzero/policy/unizero_multitask.py +++ b/lzero/policy/unizero_multitask.py @@ -35,7 +35,10 @@ def generate_task_loss_dict(multi_task_losses, task_name_template): task_loss_dict = {} for task_idx, task_loss in enumerate(multi_task_losses): task_name = task_name_template.format(task_idx) - task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + try: + task_loss_dict[task_name] = task_loss.item() if hasattr(task_loss, 'item') else task_loss + except Exception as e: + task_loss_dict[task_name] = task_loss return task_loss_dict @@ -505,6 +508,9 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in average_target_policy_entropy_multi_task = [] + value_priority_multi_task = [] + value_priority_mean_multi_task = [] + losses_list = [] # 用于存储每个任务的损失 for task_id, data_one_task in enumerate(data): @@ -597,7 +603,7 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, 
Union[float, in latent_recon_loss = intermediate_losses['latent_recon_loss'] perceptual_loss = intermediate_losses['perceptual_loss'] latent_state_l2_norms = intermediate_losses['latent_state_l2_norms'] - + value_priority = intermediate_losses['value_priority'] obs_loss_multi_task.append(obs_loss) reward_loss_multi_task.append(reward_loss) @@ -609,12 +615,14 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in latent_recon_loss_multi_task.append(latent_recon_loss) perceptual_loss_multi_task.append(perceptual_loss) latent_state_l2_norms_multi_task.append(latent_state_l2_norms) + value_priority_multi_task.append(value_priority) + value_priority_mean_multi_task.append(value_priority.mean().item()) + # Core learn model update step self._optimizer_world_model.zero_grad() - # TODO MoCo - # 使用 MoCo 和 CAGrad 来计算梯度和权重 + # TODO 使用 MoCo 和 CAGrad 来计算梯度和权重 # ============= for CAGrad and MoCo ============= # lambd = self.grad_correct.backward(losses=losses_list, **self._cfg.grad_correct_params) @@ -691,7 +699,12 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in **generate_task_loss_dict(reward_loss_multi_task, 'reward_loss_task{}'), **generate_task_loss_dict(value_loss_multi_task, 'value_loss_task{}'), **generate_task_loss_dict(average_target_policy_entropy_multi_task, 'target_policy_entropy_task{}'), - **generate_task_loss_dict(lambd, 'lambd_task{}'), + **generate_task_loss_dict(lambd, 'lambd_task{}'), + # ============================================================== + # priority related + # ============================================================== + **generate_task_loss_dict(value_priority_multi_task, 'value_priority_task{}'), + **generate_task_loss_dict(value_priority_mean_multi_task, 'value_priority_mean_task{}'), } # 合并两个字典 @@ -1095,6 +1108,7 @@ def _monitor_vars_learn(self, num_tasks=2) -> List[str]: 'perceptual_loss', 'latent_state_l2_norms', 'lambd', + 'value_priority_mean', ] num_tasks = self.task_num # If the number of tasks is provided, extend the monitored variables list with task-specific variables diff --git a/requirements.txt b/requirements.txt index e3534039e..46b28b157 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,3 +9,4 @@ pycolab pytest pooltool-billiards>=0.3.1 line_profiler +simple_parsing diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index 0e0f23a8c..343948965 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -15,6 +15,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu manager=dict(shared_memory=False, ), full_action_space=True, # ===== only for debug ===== + # collect_max_episode_steps=int(30), + # eval_max_episode_steps=int(30), # collect_max_episode_steps=int(50), # eval_max_episode_steps=int(50), # collect_max_episode_steps=int(500), @@ -57,14 +59,20 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # num_experts_in_softmoe_head=4, # NOTE num_experts_in_softmoe_head=-1, # NOTE # moe_in_transformer=True, - moe_in_transformer=False, - num_experts_of_moe_in_transformer=4, + moe_in_transformer=False, # NOTE + # num_experts_of_moe_in_transformer=4, + num_experts_of_moe_in_transformer=1, ), ), + use_priority=False, + print_task_priority_logs=False, + # use_priority=True, + # print_task_priority_logs=True, cuda=True, model_path=None, num_unroll_steps=num_unroll_steps, - 
update_per_collect=None, + # update_per_collect=None, + update_per_collect=1000, replay_ratio=0.25, batch_size=batch_size, optim_type='AdamW', @@ -82,11 +90,11 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod # exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head-softmoe4_1-encoder-{norm_type}_lsd768-nlayer4-nh8_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716_debug/{len(env_id_list)}games_1-head-softmoe4-dynamics_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' - # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_8-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_value-priority_seed{seed}/' + exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_CAGrad_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_MoCo_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' - - exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_trans-ffw-moe1_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_1-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' for task_id, env_id in enumerate(env_id_list): @@ -161,14 +169,14 @@ def create_env_manager(): num_unroll_steps = 10 infer_context_length = 4 norm_type = 'LN' - # norm_type = 'BN' + # norm_type = 'BN' # bad performance now # ======== TODO: only for debug ======== # collector_env_num = 3 # n_episode = 3 # evaluator_env_num = 2 - # num_simulations = 5 + # num_simulations = 2 # batch_size = [4, 4, 4, 4] configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed) From fde51ccc4c809a058003e524749df5ffc64d174a Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Wed, 17 Jul 2024 17:14:20 +0800 Subject: [PATCH 10/13] polish(pu): polish value_priority in unizero_multitask --- lzero/model/unizero_world_models/utils.py | 2 +- .../world_model_multitask.py | 17 +++---------- lzero/policy/unizero_multitask.py | 12 ++++++++- .../config/atari_unizero_multitask_config.py | 25 ++++++++++--------- 4 files changed, 28 insertions(+), 28 deletions(-) diff --git a/lzero/model/unizero_world_models/utils.py b/lzero/model/unizero_world_models/utils.py index b04375ded..6e39345e9 100644 --- a/lzero/model/unizero_world_models/utils.py +++ b/lzero/model/unizero_world_models/utils.py @@ -185,7 +185,7 @@ def __init__(self, latent_recon_loss_weight=0, perceptual_loss_weight=0, **kwarg self.loss_total += self.perceptual_loss_weight * v self.intermediate_losses = { - k: v if isinstance(v, dict) or isinstance(v, np.ndarray) else (v if isinstance(v, float) else v.item()) + k: v if isinstance(v, dict) or isinstance(v, np.ndarray) or isinstance(v, torch.Tensor) else (v if isinstance(v, float) else v.item()) for k, v in kwargs.items() } 
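Note: the two hunks that follow move the value-priority computation out of the world model's compute_loss and into the policy's _forward_learn; the world model now returns logits_value and the policy derives the priority from it. The quantity itself is unchanged: an L1 distance, on the first unroll step, between the inverse-transformed predicted value and the target value, plus 1e-6 for numerical stability. A minimal sketch under the assumptions visible in the surrounding code (101-bin categorical value logits and an inverse_scalar_transform_handle); the helper name compute_value_priority is illustrative only.

import torch

def compute_value_priority(logits_value: torch.Tensor,
                           target_value: torch.Tensor,
                           inverse_scalar_transform_handle,
                           batch_size: int,
                           unroll_steps: int):
    # logits_value: (batch_size * unroll_steps, 101) categorical value logits.
    # target_value: (batch_size, unroll_steps) scalar value targets.
    original_value = inverse_scalar_transform_handle(
        logits_value.reshape(-1, 101)
    ).reshape(batch_size, unroll_steps, 1)
    # L1 distance on the first unroll step only, matching the hunks below.
    value_priority = torch.nn.L1Loss(reduction='none')(
        original_value.squeeze(-1)[:, 0], target_value[:, 0]
    )
    # The small epsilon keeps priorities strictly positive for PER sampling.
    return value_priority.detach().cpu().numpy() + 1e-6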
diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index 93ab87619..31de71ef6 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -1016,8 +1016,8 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # original_images, reconstructed_images = batch['observations'], reconstructed_images # target_policy = batch['target_policy'] # ==== for value priority ==== - target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( - batch['observations'].shape[0], batch['observations'].shape[1], 1) + # target_predict_value = inverse_scalar_transform_handle(batch['target_value'].reshape(-1, 101)).reshape( + # batch['observations'].shape[0], batch['observations'].shape[1], 1) # true_rewards = inverse_scalar_transform_handle(batch['rewards'].reshape(-1, 101)).reshape( # batch['observations'].shape[0], batch['observations'].shape[1], 1) # ========== for visualization ========== @@ -1143,17 +1143,6 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar loss_policy, orig_policy_loss, policy_entropy = self.compute_cross_entropy_loss(outputs, labels_policy, batch, element='policy') loss_value = self.compute_cross_entropy_loss(outputs, labels_value, batch, element='value') - - # ============ for value priority ============ - # transform the scaled value or its categorical representation to its original value, - # i.e. h^(-1)(.) function in paper https://arxiv.org/pdf/1805.11593.pdf. - original_value = inverse_scalar_transform_handle(outputs.logits_value.reshape(-1, 101)).reshape( - batch['observations'].shape[0], batch['observations'].shape[1], 1) - # calculate the new priorities for each transition. 
- from torch.nn import L1Loss - value_priority = L1Loss(reduction='none')(original_value[:,0], target_predict_value[:, 0]) # TODO: mix of mean and sum - value_priority = value_priority.data.cpu().numpy() + 1e-6 - # ============ for value priority ============ # Compute timesteps timesteps = torch.arange(batch['actions'].shape[1], device=batch['actions'].device) @@ -1224,7 +1213,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar dormant_ratio_encoder=dormant_ratio_encoder, dormant_ratio_world_model=dormant_ratio_world_model, latent_state_l2_norms=latent_state_l2_norms, - value_priority=value_priority, + logits_value=outputs.logits_value, ) def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index a8f6f66c2..b7cc67737 100644 --- a/lzero/policy/unizero_multitask.py +++ b/lzero/policy/unizero_multitask.py @@ -603,7 +603,17 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in latent_recon_loss = intermediate_losses['latent_recon_loss'] perceptual_loss = intermediate_losses['perceptual_loss'] latent_state_l2_norms = intermediate_losses['latent_state_l2_norms'] - value_priority = intermediate_losses['value_priority'] + # value_priority = intermediate_losses['value_priority'] + logits_value = intermediate_losses['logits_value'] + + # ============ for value priority ============ + # transform the categorical representation of the scaled value to its original value + original_value = self.inverse_scalar_transform_handle(logits_value.reshape(-1, 101)).reshape( + batch_for_gpt['observations'].shape[0], batch_for_gpt['observations'].shape[1], 1) + # calculate the new priorities for each transition. 
+ value_priority = torch.nn.L1Loss(reduction='none')(original_value.squeeze(-1)[:,0], target_value[:, 0]) # TODO: mix of mean and sum + value_priority = value_priority.data.cpu().numpy() + 1e-6 + # ============ for value priority ============ obs_loss_multi_task.append(obs_loss) reward_loss_multi_task.append(reward_loss) diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index 343948965..4597e87b3 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -15,8 +15,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu manager=dict(shared_memory=False, ), full_action_space=True, # ===== only for debug ===== - # collect_max_episode_steps=int(30), - # eval_max_episode_steps=int(30), + collect_max_episode_steps=int(30), + eval_max_episode_steps=int(30), # collect_max_episode_steps=int(50), # eval_max_episode_steps=int(50), # collect_max_episode_steps=int(500), @@ -64,10 +64,10 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu num_experts_of_moe_in_transformer=1, ), ), - use_priority=False, - print_task_priority_logs=False, - # use_priority=True, - # print_task_priority_logs=True, + # use_priority=False, + # print_task_priority_logs=False, + use_priority=True, + print_task_priority_logs=True, cuda=True, model_path=None, num_unroll_steps=num_unroll_steps, @@ -87,11 +87,12 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed): configs = [] + # TODO # exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head-softmoe4_1-encoder-{norm_type}_lsd768-nlayer4-nh8_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716_debug/{len(env_id_list)}games_1-head-softmoe4-dynamics_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_value-priority_seed{seed}/' - exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' + exp_name_prefix = f'data_unizero_mt_0716_debug/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_CAGrad_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_MoCo_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_trans-ffw-moe1_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' @@ -173,11 +174,11 @@ def create_env_manager(): # ======== TODO: only for debug ======== - # collector_env_num = 3 - # n_episode = 3 - # evaluator_env_num = 2 - # num_simulations = 2 - # batch_size = [4, 4, 4, 4] + collector_env_num = 3 + n_episode = 3 + evaluator_env_num = 2 + num_simulations = 2 + batch_size = [4, 4, 4, 4] configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, 
norm_type, seed) From b460d2f687f92bbed3324e4e668bf200c4154231 Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Fri, 19 Jul 2024 03:12:36 +0800 Subject: [PATCH 11/13] sync code --- lzero/model/unizero_world_models/moe.py | 4 + lzero/model/unizero_world_models/test_moe.py | 104 +++++++++++++++++ .../model/unizero_world_models/test_moe_v2.py | 107 ++++++++++++++++++ .../model/unizero_world_models/transformer.py | 2 +- .../config/atari_unizero_multitask_config.py | 49 ++++---- 5 files changed, 242 insertions(+), 24 deletions(-) create mode 100644 lzero/model/unizero_world_models/test_moe.py create mode 100644 lzero/model/unizero_world_models/test_moe_v2.py diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py index af2f8dac2..68ccfe7cd 100644 --- a/lzero/model/unizero_world_models/moe.py +++ b/lzero/model/unizero_world_models/moe.py @@ -22,6 +22,10 @@ def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_to self.num_experts_per_tok = num_experts_per_tok def forward(self, inputs: torch.Tensor) -> torch.Tensor: + # if len(self.experts) == 1: + # # 只有一个专家时,直接使用该专家 + # return self.experts[0](inputs) + gate_logits = self.gate(inputs) weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok) weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) diff --git a/lzero/model/unizero_world_models/test_moe.py b/lzero/model/unizero_world_models/test_moe.py new file mode 100644 index 000000000..e842f6552 --- /dev/null +++ b/lzero/model/unizero_world_models/test_moe.py @@ -0,0 +1,104 @@ +import dataclasses +from typing import List + +import torch +import torch.nn.functional as F +from simple_parsing.helpers import Serializable +from torch import nn + +# 定义MoeArgs数据类,用于存储MoE的配置参数 +@dataclasses.dataclass +class MoeArgs(Serializable): + num_experts: int + num_experts_per_tok: int + +# 定义Mixture of Experts(MoE)层 +class MoeLayer(nn.Module): + def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1): + super().__init__() + assert len(experts) > 0 + self.experts = nn.ModuleList(experts) + self.gate = gate + self.num_experts_per_tok = num_experts_per_tok + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + gate_logits = self.gate(inputs) + weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok) + weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) + results = torch.zeros_like(inputs) + for i, expert in enumerate(self.experts): + batch_idx, token_idx, nth_expert = torch.where(selected_experts == i) + results[batch_idx, token_idx] += weights[batch_idx, token_idx, nth_expert][:, None] * expert(inputs[batch_idx, token_idx]) + return results + +# 定义一个简单的Transformer块 +class TransformerBlock(nn.Module): + def __init__(self, config): + super().__init__() + if config.moe_in_transformer: + self.mlp = nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) + self.feed_forward = MoeLayer( + experts=[self.mlp for _ in range(config.num_experts_of_moe_in_transformer)], + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + print("="*20) + print('使用MoE在Transformer的feed_forward中') + print("="*20) + else: + self.feed_forward = nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, 
config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) + + def forward(self, x): + return self.feed_forward(x) + +# 定义配置类 +class Config: + def __init__(self, embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer): + self.embed_dim = embed_dim + self.resid_pdrop = resid_pdrop + self.num_experts_of_moe_in_transformer = num_experts_of_moe_in_transformer + self.moe_in_transformer = moe_in_transformer + +# 测试代码 +def test_transformer_block(): + # 初始化配置 + embed_dim = 64 + resid_pdrop = 0.1 + num_experts_of_moe_in_transformer = 1 + moe_in_transformer_values = [True, False] + + # 创建输入数据 + inputs = torch.randn(10, 5, embed_dim) # (batch_size, seq_len, embed_dim) + + # 对于moe_in_transformer为True和False分别进行测试 + for moe_in_transformer in moe_in_transformer_values: + config = Config(embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer) + transformer_block = TransformerBlock(config) + + outputs = transformer_block(inputs) + print(f"moe_in_transformer={moe_in_transformer}: outputs={outputs}") + + if moe_in_transformer: + outputs_true = outputs + else: + outputs_false = outputs + + # 计算输出的差异 + mse_difference = None + if outputs_true is not None and outputs_false is not None: + mse_difference = F.mse_loss(outputs_true, outputs_false).item() + + print(f"输出差异的均方误差(MSE): {mse_difference}") + +if __name__ == "__main__": + test_transformer_block() \ No newline at end of file diff --git a/lzero/model/unizero_world_models/test_moe_v2.py b/lzero/model/unizero_world_models/test_moe_v2.py new file mode 100644 index 000000000..6ab93cc16 --- /dev/null +++ b/lzero/model/unizero_world_models/test_moe_v2.py @@ -0,0 +1,107 @@ +import dataclasses +from typing import List + +import torch +import torch.nn.functional as F +from simple_parsing.helpers import Serializable +from torch import nn + +# 定义MoeArgs数据类,用于存储MoE的配置参数 +@dataclasses.dataclass +class MoeArgs(Serializable): + num_experts: int + num_experts_per_tok: int + +# 定义Mixture of Experts(MoE)层 +class MoeLayer(nn.Module): + def __init__(self, experts: List[nn.Module], gate: nn.Module, num_experts_per_tok=1): + super().__init__() + assert len(experts) > 0 + self.experts = nn.ModuleList(experts) + self.gate = gate + self.num_experts_per_tok = num_experts_per_tok + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + if len(self.experts) == 1: + # 只有一个专家时,直接使用该专家 + return self.experts[0](inputs) + + gate_logits = self.gate(inputs) + weights, selected_experts = torch.topk(gate_logits, self.num_experts_per_tok) + weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype) + results = torch.zeros_like(inputs) + for i, expert in enumerate(self.experts): + batch_idx, token_idx, nth_expert = torch.where(selected_experts == i) + results[batch_idx, token_idx] += weights[batch_idx, token_idx, nth_expert][:, None] * expert(inputs[batch_idx, token_idx]) + return results + +# 定义一个简单的Transformer块 +class TransformerBlock(nn.Module): + def __init__(self, config): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) + + if config.moe_in_transformer: + self.feed_forward = MoeLayer( + experts=[self.mlp for _ in range(config.num_experts_of_moe_in_transformer)], + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + print("="*20) + print('使用MoE在Transformer的feed_forward中') + print("="*20) + 
else: + self.feed_forward = self.mlp + + def forward(self, x): + return self.feed_forward(x) + +# 定义配置类 +class Config: + def __init__(self, embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer): + self.embed_dim = embed_dim + self.resid_pdrop = resid_pdrop + self.num_experts_of_moe_in_transformer = num_experts_of_moe_in_transformer + self.moe_in_transformer = moe_in_transformer + +# 测试代码 +def test_transformer_block(): + # 初始化配置 + embed_dim = 64 + resid_pdrop = 0.1 + num_experts_of_moe_in_transformer = 1 + + # 创建输入数据 + inputs = torch.randn(10, 5, embed_dim) # (batch_size, seq_len, embed_dim) + + # 初始化两个输出变量 + outputs_true = None + outputs_false = None + + # 对于moe_in_transformer为True和False分别进行测试 + for moe_in_transformer in [True, False]: + config = Config(embed_dim, resid_pdrop, num_experts_of_moe_in_transformer, moe_in_transformer) + transformer_block = TransformerBlock(config) + + outputs = transformer_block(inputs) + print(f"moe_in_transformer={moe_in_transformer}: outputs={outputs}") + + if moe_in_transformer: + outputs_true = outputs + else: + outputs_false = outputs + + # 计算输出的差异 + mse_difference = None + if outputs_true is not None and outputs_false is not None: + mse_difference = F.mse_loss(outputs_true, outputs_false).item() + + print(f"输出差异的均方误差(MSE): {mse_difference}") + +if __name__ == "__main__": + test_transformer_block() \ No newline at end of file diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 0ad288fe2..bcfff389e 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -134,7 +134,7 @@ def __init__(self, config: TransformerConfig) -> None: num_experts_per_tok=1, ) print("="*20) - print('use moe in feed_forward of transformer') + print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') print("="*20) else: self.feed_forward = nn.Sequential( diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index 4597e87b3..fc77f71c9 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -15,8 +15,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu manager=dict(shared_memory=False, ), full_action_space=True, # ===== only for debug ===== - collect_max_episode_steps=int(30), - eval_max_episode_steps=int(30), + # collect_max_episode_steps=int(30), + # eval_max_episode_steps=int(30), # collect_max_episode_steps=int(50), # eval_max_episode_steps=int(50), # collect_max_episode_steps=int(500), @@ -52,22 +52,23 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu num_heads=8, embed_dim=768, obs_type='image', - env_num=max(collector_env_num, evaluator_env_num), + # env_num=max(collector_env_num, evaluator_env_num), + env_num=8, # TODO: the max of all tasks # collector_env_num=collector_env_num, # evaluator_env_num=evaluator_env_num, task_num=len(env_id_list), # num_experts_in_softmoe_head=4, # NOTE num_experts_in_softmoe_head=-1, # NOTE - # moe_in_transformer=True, - moe_in_transformer=False, # NOTE + moe_in_transformer=True, + # moe_in_transformer=False, # NOTE # num_experts_of_moe_in_transformer=4, - num_experts_of_moe_in_transformer=1, + num_experts_of_moe_in_transformer=2, ), ), - # use_priority=False, - # print_task_priority_logs=False, - use_priority=True, - print_task_priority_logs=True, + 
use_priority=False, + print_task_priority_logs=False, + # use_priority=True, # TODO + # print_task_priority_logs=True, cuda=True, model_path=None, num_unroll_steps=num_unroll_steps, @@ -92,22 +93,24 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod # exp_name_prefix = f'data_unizero_mt_0716_debug/{len(env_id_list)}games_1-head-softmoe4-dynamics_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_value-priority_seed{seed}/' - exp_name_prefix = f'data_unizero_mt_0716_debug/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716_debug/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_CAGrad_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_MoCo_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' - # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_trans-ffw-moe1_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe1-same_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' + exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe2_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_1-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' for task_id, env_id in enumerate(env_id_list): config = create_config( env_id, action_space_size, - # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing - # evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, - # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, - collector_env_num, - evaluator_env_num, - n_episode, + collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing + evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + # collector_env_num, + # evaluator_env_num, + # n_episode, num_simulations, reanalyze_ratio, batch_size, @@ -174,11 +177,11 @@ def create_env_manager(): # ======== TODO: only for debug ======== - collector_env_num = 3 - n_episode = 3 - evaluator_env_num = 2 - num_simulations = 2 - batch_size = [4, 4, 4, 4] + # collector_env_num = 3 + # n_episode = 3 + # evaluator_env_num = 2 + # num_simulations = 2 + # batch_size = [4, 4, 4, 4] configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed) From 5117459807897367b3c473cd31f34dbc38a583ac Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Fri, 19 Jul 2024 16:37:41 +0800 Subject: [PATCH 12/13] fix(pu): fix moe in feedforward layer of transformer and 
polish configs --- .../model/unizero_world_models/transformer.py | 19 ++++++++------ .../config/atari_unizero_multitask_config.py | 25 +++++++++++-------- 2 files changed, 26 insertions(+), 18 deletions(-) diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index bcfff389e..d509aa2f5 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -122,17 +122,22 @@ def __init__(self, config: TransformerConfig) -> None: self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = SelfAttention(config) if config.moe_in_transformer: - self.mlp = nn.Sequential( - nn.Linear(config.embed_dim, 4 * config.embed_dim), - nn.GELU(approximate='tanh'), - nn.Linear(4 * config.embed_dim, config.embed_dim), - nn.Dropout(config.resid_pdrop), - ) + # 创建多个独立的 MLP 实例 + self.experts = nn.ModuleList([ + nn.Sequential( + nn.Linear(config.embed_dim, 4 * config.embed_dim), + nn.GELU(approximate='tanh'), + nn.Linear(4 * config.embed_dim, config.embed_dim), + nn.Dropout(config.resid_pdrop), + ) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + self.feed_forward = MoeLayer( - experts=[self.mlp for _ in range(config.num_experts_of_moe_in_transformer)], + experts=self.experts, gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), num_experts_per_tok=1, ) + print("="*20) print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') print("="*20) diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index fc77f71c9..66f7d148b 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -48,7 +48,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # device='cpu', # 'cuda', device='cuda', # 'cuda', action_space_size=action_space_size, - num_layers=4, # NOTE + # num_layers=4, # NOTE + num_layers=2, # NOTE num_heads=8, embed_dim=768, obs_type='image', @@ -59,21 +60,22 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu task_num=len(env_id_list), # num_experts_in_softmoe_head=4, # NOTE num_experts_in_softmoe_head=-1, # NOTE - moe_in_transformer=True, - # moe_in_transformer=False, # NOTE - # num_experts_of_moe_in_transformer=4, - num_experts_of_moe_in_transformer=2, + # moe_in_transformer=True, + moe_in_transformer=False, # NOTE + num_experts_of_moe_in_transformer=4, + # num_experts_of_moe_in_transformer=2, ), ), - use_priority=False, + # use_priority=False, + # print_task_priority_logs=False, + use_priority=True, # TODO print_task_priority_logs=False, - # use_priority=True, # TODO - # print_task_priority_logs=True, cuda=True, model_path=None, num_unroll_steps=num_unroll_steps, # update_per_collect=None, - update_per_collect=1000, + # update_per_collect=1000, + update_per_collect=500, replay_ratio=0.25, batch_size=batch_size, optim_type='AdamW', @@ -97,7 +99,8 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_CAGrad_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_MoCo_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = 
f'data_unizero_mt_0716/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe1-same_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' - exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe2_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0719/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer2-nh8_max-bs1500_upc1000_seed{seed}/' + exp_name_prefix = f'data_unizero_mt_0719/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_lsd768-nlayer2-nh8_max-bs1500_upc500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_1-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' @@ -173,7 +176,7 @@ def create_env_manager(): num_unroll_steps = 10 infer_context_length = 4 norm_type = 'LN' - # norm_type = 'BN' # bad performance now + # # norm_type = 'BN' # bad performance now # ======== TODO: only for debug ======== From 2495d6099a61ea8e904bf130f7ea2e2e64469a7e Mon Sep 17 00:00:00 2001 From: dyyoungg Date: Tue, 23 Jul 2024 15:10:58 +0800 Subject: [PATCH 13/13] feature(pu): add mistralai moe in transformer feedforward and head of unizero --- lzero/entry/train_unizero_multitask.py | 2 + lzero/mcts/buffer/game_buffer_unizero.py | 6 + lzero/mcts/tree_search/mcts_ctree.py | 7 +- lzero/model/unizero_model_multitask.py | 4 + lzero/model/unizero_world_models/moe.py | 11 + .../model/unizero_world_models/transformer.py | 25 +- .../world_model_multitask.py | 139 +++++++++-- lzero/policy/unizero_multitask.py | 9 +- .../atari_unizero_multitask_26games_config.py | 224 ++++++++++++++++++ .../config/atari_unizero_multitask_config.py | 49 ++-- 10 files changed, 431 insertions(+), 45 deletions(-) create mode 100644 zoo/atari/config/atari_unizero_multitask_26games_config.py diff --git a/lzero/entry/train_unizero_multitask.py b/lzero/entry/train_unizero_multitask.py index bc48fb08e..f8fa1ee23 100644 --- a/lzero/entry/train_unizero_multitask.py +++ b/lzero/entry/train_unizero_multitask.py @@ -18,7 +18,9 @@ from lzero.worker import MuZeroCollector as Collector, MuZeroEvaluator as Evaluator from lzero.mcts import UniZeroGameBuffer as GameBuffer +from line_profiler import line_profiler +#@profile def train_unizero_multitask( input_cfg_list: List[Tuple[int, Tuple[dict, dict]]], seed: int = 0, diff --git a/lzero/mcts/buffer/game_buffer_unizero.py b/lzero/mcts/buffer/game_buffer_unizero.py index 2803b0213..ff13e325c 100644 --- a/lzero/mcts/buffer/game_buffer_unizero.py +++ b/lzero/mcts/buffer/game_buffer_unizero.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy +from line_profiler import line_profiler @BUFFER_REGISTRY.register('game_buffer_unizero') @@ -56,6 +57,7 @@ def __init__(self, cfg: dict): self.task_id = None print("No task_id found in configuration. 
Task ID is set to None.") + #@profile def sample( self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"] ) -> List[Any]: @@ -103,6 +105,7 @@ def sample( train_data = [current_batch, target_batch] return train_data + #@profile def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: """ Overview: @@ -198,6 +201,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]: context = reward_value_context, policy_re_context, policy_non_re_context, current_batch return context + #@profile def _prepare_policy_reanalyzed_context( self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str] ) -> List[Any]: @@ -251,6 +255,7 @@ def _prepare_policy_reanalyzed_context( ] return policy_re_context + #@profile def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, action_batch) -> np.ndarray: """ Overview: @@ -374,6 +379,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: return batch_target_policies_re + #@profile def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[ Any, Any]: """ diff --git a/lzero/mcts/tree_search/mcts_ctree.py b/lzero/mcts/tree_search/mcts_ctree.py index f19410f3e..ad4698923 100644 --- a/lzero/mcts/tree_search/mcts_ctree.py +++ b/lzero/mcts/tree_search/mcts_ctree.py @@ -15,6 +15,7 @@ from lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as gmz_ctree +from line_profiler import line_profiler class UniZeroMCTSCtree(object): """ @@ -71,7 +72,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, List[Any]], task_id=None @@ -225,7 +226,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree return ctree.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int, List[Any]] @@ -494,7 +495,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e """ return tree_muzero.Roots(active_collect_env_num, legal_actions) - # @profile + # #@profile def search( self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], world_model_latent_history_roots: List[Any], to_play_batch: Union[int, List[Any]], ready_env_id=None, diff --git a/lzero/model/unizero_model_multitask.py b/lzero/model/unizero_model_multitask.py index 257bcb56c..6f601c372 100644 --- a/lzero/model/unizero_model_multitask.py +++ b/lzero/model/unizero_model_multitask.py @@ -10,11 +10,13 @@ from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model_multitask import WorldModelMT +from line_profiler import line_profiler # use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document. 
@MODEL_REGISTRY.register('UniZeroMTModel') class UniZeroMTModel(nn.Module): + #@profile def __init__( self, observation_shape: SequenceType = (4, 64, 64), @@ -162,6 +164,7 @@ def __init__( print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network') print('==' * 20) + #@profile def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput: """ Overview: @@ -198,6 +201,7 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_ latent_state, ) + #@profile def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0, latent_state_index_in_search_path=[], task_id=None) -> MZNetworkOutput: """ diff --git a/lzero/model/unizero_world_models/moe.py b/lzero/model/unizero_world_models/moe.py index 68ccfe7cd..159afd69e 100644 --- a/lzero/model/unizero_world_models/moe.py +++ b/lzero/model/unizero_world_models/moe.py @@ -6,6 +6,17 @@ from simple_parsing.helpers import Serializable from torch import nn +# Modified from https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer.py#L108 +class MultiplicationFeedForward(nn.Module): + def __init__(self, config): + super().__init__() + + self.w1 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) + self.w2 = nn.Linear(4 * config.embed_dim, config.embed_dim, bias=False) + self.w3 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore @dataclasses.dataclass class MoeArgs(Serializable): diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index d509aa2f5..fa96a0be4 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -13,7 +13,9 @@ from torch.nn import functional as F from .kv_caching import KeysValues -from .moe import MoeLayer +from .moe import MoeLayer, MultiplicationFeedForward +from line_profiler import line_profiler + @dataclass class TransformerConfig: @@ -69,6 +71,7 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: device = self.ln_f.weight.device # Assumption: All submodules are on the same device return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) + #@profile def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ @@ -91,6 +94,8 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues return x + + class Block(nn.Module): """ Transformer block class. 
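For reference, the MultiplicationFeedForward added in moe.py above is a Mistral-style gated (SwiGLU) feed-forward, and the transformer Block in the next hunk wraps several such experts in a top-1-routed MoeLayer. Below is a minimal, self-contained sketch of how these pieces compose; the SimpleNamespace config and Top1Router are simplified stand-ins for the repo's TransformerConfig and MoeLayer (the real MoeLayer additionally soft-weights the selected expert), not the actual implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F
from types import SimpleNamespace

# Minimal stand-in for TransformerConfig; only embed_dim is needed here (assumption).
config = SimpleNamespace(embed_dim=768)

class MultiplicationFeedForward(nn.Module):
    """Gated (SwiGLU-style) feed-forward, as added in moe.py above."""
    def __init__(self, config):
        super().__init__()
        self.w1 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False)
        self.w2 = nn.Linear(4 * config.embed_dim, config.embed_dim, bias=False)
        self.w3 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Element-wise product of a SiLU-gated branch and a plain linear branch.
        return self.w2(F.silu(self.w1(x)) * self.w3(x))

class Top1Router(nn.Module):
    """Simplified top-1 routing over independent experts (what MoeLayer does with num_experts_per_tok=1)."""
    def __init__(self, experts, embed_dim):
        super().__init__()
        self.experts = nn.ModuleList(experts)
        self.gate = nn.Linear(embed_dim, len(experts), bias=False)

    def forward(self, x):
        # x: (batch, seq, embed_dim); route each token to exactly one expert.
        expert_idx = self.gate(x).argmax(dim=-1)            # (batch, seq)
        out = torch.zeros_like(x)
        for i, expert in enumerate(self.experts):
            mask = expert_idx == i
            if mask.any():
                out[mask] = expert(x[mask])
        return out

layer = Top1Router([MultiplicationFeedForward(config) for _ in range(4)], config.embed_dim)
tokens = torch.randn(2, 10, config.embed_dim)
assert layer(tokens).shape == tokens.shape

Using independent expert instances, rather than repeating one shared MLP as in the earlier revision, is exactly what the transformer.py change in PATCH 12 above introduces.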
@@ -122,7 +127,7 @@ def __init__(self, config: TransformerConfig) -> None: self.ln2 = nn.LayerNorm(config.embed_dim) self.attn = SelfAttention(config) if config.moe_in_transformer: - # 创建多个独立的 MLP 实例 + # 创Create multiple independent MLP instances self.experts = nn.ModuleList([ nn.Sequential( nn.Linear(config.embed_dim, 4 * config.embed_dim), @@ -141,6 +146,21 @@ def __init__(self, config: TransformerConfig) -> None: print("="*20) print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') print("="*20) + elif config.multiplication_moe_in_transformer: + # Create multiple FeedForward instances for multiplication-based MoE + self.experts = nn.ModuleList([ + MultiplicationFeedForward(config) for _ in range(config.num_experts_of_moe_in_transformer) + ]) + + self.feed_forward = MoeLayer( + experts=self.experts, + gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False), + num_experts_per_tok=1, + ) + + print("="*20) + print(f'use multiplication moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}') + print("="*20) else: self.feed_forward = nn.Sequential( nn.Linear(config.embed_dim, 4 * config.embed_dim), @@ -209,6 +229,7 @@ def __init__(self, config: TransformerConfig) -> None: causal_mask = torch.tril(torch.ones(config.max_tokens, config.max_tokens)) self.register_buffer('mask', causal_mask) + #@profile def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ diff --git a/lzero/model/unizero_world_models/world_model_multitask.py b/lzero/model/unizero_world_models/world_model_multitask.py index 31de71ef6..a9d6b9061 100644 --- a/lzero/model/unizero_world_models/world_model_multitask.py +++ b/lzero/model/unizero_world_models/world_model_multitask.py @@ -18,10 +18,13 @@ from .transformer import Transformer, TransformerConfig from .utils import LossWithIntermediateLosses, init_weights, to_device_for_kvcache from .utils import WorldModelOutput, quantize_state +from .moe import MoeLayer, MultiplicationFeedForward logging.getLogger().setLevel(logging.DEBUG) +from line_profiler import line_profiler + class WorldModelMT(nn.Module): """ Overview: @@ -32,6 +35,8 @@ class WorldModelMT(nn.Module): - a transformer, which processes the input sequences, - and heads, which generate the logits for observations, rewards, policy, and value. 
""" + + #@profile def __init__(self, config: TransformerConfig, tokenizer) -> None: """ Overview: @@ -53,7 +58,10 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.head_rewards_multi_task = nn.ModuleList() self.head_observations_multi_task = nn.ModuleList() - self.num_experts_in_softmoe_head = config.num_experts_in_softmoe_head + self.num_experts_in_moe_head = config.num_experts_in_moe_head + self.use_normal_head = config.use_normal_head + self.use_moe_head = config.use_moe_head + self.use_softmoe_head = config.use_softmoe_head # Move all modules to the specified device print(f"self.config.device: {self.config.device}") @@ -74,7 +82,9 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.act_embedding_table = nn.Embedding(config.action_space_size, config.embed_dim, device=self.device) print(f"self.act_embedding_table.weight.device: {self.act_embedding_table.weight.device}") - if self.num_experts_in_softmoe_head == -1: + # if self.num_experts_in_moe_head == -1: + assert self.num_experts_in_moe_head > 0 + if self.use_normal_head: print('We use normal head') # TODO: Normal Head for task_id in range(self.task_num): # TODO @@ -93,8 +103,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.obs_per_embdding_dim, self.sim_norm) # NOTE: we add a sim_norm to the head for observations self.head_observations_multi_task.append(self.head_observations) - else: - print(f'We use softmoe head, self.num_experts_in_softmoe_head is {self.num_experts_in_softmoe_head}') + elif self.use_softmoe_head: + print(f'We use softmoe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') # Dictionary to store SoftMoE instances self.soft_moe_instances = {} @@ -105,6 +115,18 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: self.head_value_multi_task.append(self.head_value) self.head_rewards_multi_task.append(self.head_rewards) self.head_observations_multi_task.append(self.head_observations) + elif self.use_moe_head: + print(f'We use moe head, self.num_experts_in_moe_head is {self.num_experts_in_moe_head}') + # Dictionary to store moe instances + self.moe_instances = {} + + # Create moe head modules + self.create_head_modules_moe() + + self.head_policy_multi_task.append(self.head_policy) + self.head_value_multi_task.append(self.head_value) + self.head_rewards_multi_task.append(self.head_rewards) + self.head_observations_multi_task.append(self.head_observations) # Apply weight initialization, the order is important @@ -172,6 +194,66 @@ def _create_head(self, block_mask: torch.Tensor, output_dim: int, norm_layer=Non head_module=nn.Sequential(*modules) ) + def _create_head_moe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, moe=None) -> Head: + """Create moe head modules for the transformer.""" + modules = [ + moe, + nn.Linear(self.config.embed_dim, output_dim) + ] + if norm_layer: + modules.append(norm_layer) + return Head( + max_blocks=self.config.max_blocks, + block_mask=block_mask, + head_module=nn.Sequential(*modules) + ) + def get_moe(self, name): + """Get or create a MoE instance""" + if name not in self.moe_instances: + # Create multiple FeedForward instances for multiplication-based MoE + self.experts = nn.ModuleList([ + MultiplicationFeedForward(self.config) for _ in range(self.config.num_experts_of_moe_in_transformer) + ]) + + self.moe_instances[name] = MoeLayer( + experts=self.experts, + gate=nn.Linear(self.config.embed_dim, self.config.num_experts_of_moe_in_transformer, bias=False), + 
num_experts_per_tok=1, + ) + + return self.moe_instances[name] + + def create_head_modules_moe(self): + """Create all softmoe head modules""" + # Rewards head + self.head_rewards = self._create_head_moe( + self.act_tokens_pattern, + self.support_size, + moe=self.get_moe("rewards_moe") + ) + + # Observations head + self.head_observations = self._create_head_moe( + self.all_but_last_latent_state_pattern, + self.obs_per_embdding_dim, + norm_layer=self.sim_norm, # NOTE + moe=self.get_moe("observations_moe") + ) + + # Policy head + self.head_policy = self._create_head_moe( + self.value_policy_tokens_pattern, + self.action_space_size, + moe=self.get_moe("policy_moe") + ) + + # Value head + self.head_value = self._create_head_moe( + self.value_policy_tokens_pattern, + self.support_size, + moe=self.get_moe("value_moe") + ) + def _create_head_softmoe(self, block_mask: torch.Tensor, output_dim: int, norm_layer=None, soft_moe=None) -> Head: """Create softmoe head modules for the transformer.""" modules = [ @@ -193,13 +275,13 @@ def get_soft_moe(self, name): # self.soft_moe_instances[name] = SoftMoE( # dim=self.embed_dim, # seq_len=20, # TODO - # num_experts=self.num_experts_in_softmoe_head, + # num_experts=self.num_experts_in_moe_head, # ) from soft_moe_pytorch import DynamicSlotsSoftMoE as SoftMoE if name not in self.soft_moe_instances: self.soft_moe_instances[name] = SoftMoE( dim=self.embed_dim, - num_experts=self.num_experts_in_softmoe_head, + num_experts=self.num_experts_in_moe_head, geglu = True ) return self.soft_moe_instances[name] @@ -281,6 +363,7 @@ def _initialize_statistics(self) -> None: self.root_hit_cnt = 0 self.root_total_query_cnt = 0 + #@profile def _initialize_transformer_keys_values(self) -> None: """Initialize keys and values for the transformer.""" self.keys_values_wm_single_env = self.transformer.generate_empty_keys_values(n=1, @@ -288,6 +371,7 @@ def _initialize_transformer_keys_values(self) -> None: self.keys_values_wm = self.transformer.generate_empty_keys_values(n=self.env_num, max_tokens=self.context_length) + #@profile def precompute_pos_emb_diff_kv(self): """ Precompute positional embedding differences for key and value. """ if self.context_length <= 2: @@ -325,6 +409,7 @@ def precompute_pos_emb_diff_kv(self): self.pos_emb_diff_k.append(layer_pos_emb_diff_k) self.pos_emb_diff_v.append(layer_pos_emb_diff_v) + #@profile def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: """ Helper function to get positional embedding for a given layer and attention type. 
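For reference, the get_moe / _create_head_moe pattern above caches one MoE instance per head name, so each of the rewards/observations/policy/value heads gets its own routed expert trunk followed by a plain Linear projection. Below is a minimal sketch of that caching pattern; TinyExpert and the dimensions are placeholders standing in for a MoeLayer over MultiplicationFeedForward experts, not the repo's actual modules.

import torch
import torch.nn as nn

embed_dim, support_size = 768, 601   # illustrative sizes, not the repo's config values

class TinyExpert(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))
    def forward(self, x):
        return self.net(x)

class HeadFactory:
    """Mirrors the get_moe / _create_head_moe pattern: one cached MoE trunk per head name."""
    def __init__(self):
        self.moe_instances = {}

    def get_moe(self, name):
        if name not in self.moe_instances:
            # In the repo this would be a MoeLayer over MultiplicationFeedForward experts;
            # a single TinyExpert stands in here to keep the sketch short.
            self.moe_instances[name] = TinyExpert(embed_dim)
        return self.moe_instances[name]

    def create_head(self, name, output_dim):
        return nn.Sequential(self.get_moe(name), nn.Linear(embed_dim, output_dim))

factory = HeadFactory()
head_value = factory.create_head("value_moe", support_size)
head_value_again = factory.create_head("value_moe", support_size)
# Reusing the same name reuses the cached MoE trunk; only the final Linear differs.
assert head_value[0] is head_value_again[0]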
@@ -346,6 +431,7 @@ def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads ).transpose(1, 2).detach() + #@profile def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], past_keys_values: Optional[torch.Tensor] = None, kvcache_independent: bool = False, is_init_infer: bool = True, @@ -412,21 +498,24 @@ def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tu # Generate logits # 1,...,0,1 https://github.com/eloialonso/iris/issues/19 - # TODO: one head or soft_moe - # logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) - # logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) - # logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) - # logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) - - # TODO: N head - logits_observations = self.head_observations_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) - logits_rewards = self.head_rewards_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) - logits_policy = self.head_policy_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) - logits_value = self.head_value_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + # TODO: one head or moe head + if self.use_moe_head: + logits_observations = self.head_observations(x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards(x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps) + else: + # TODO: N head + logits_observations = self.head_observations_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_rewards = self.head_rewards_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_policy = self.head_policy_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) + logits_value = self.head_value_multi_task[task_id](x, num_steps=num_steps, prev_steps=prev_steps) # logits_ends is None return WorldModelOutput(x, logits_observations, logits_rewards, None, logits_policy, logits_value) + + #@profile def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_independent, is_init_infer, valid_context_lengths): """ @@ -455,6 +544,7 @@ def _add_position_embeddings(self, embeddings, prev_steps, num_steps, kvcache_in valid_context_lengths + torch.arange(num_steps, device=self.device)).unsqueeze(1) return embeddings + position_embeddings + #@profile def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): """ Process combined observation embeddings and action tokens. @@ -485,6 +575,7 @@ def _process_obs_act_combined(self, obs_embeddings_or_act_tokens, prev_steps): return obs_act_embeddings + self.pos_emb(prev_steps + torch.arange(num_steps, device=self.device)), num_steps + #@profile def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, valid_context_lengths): """ Pass sequences through the transformer. 
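For reference, the forward pass above now picks its output heads in one of two ways: when use_moe_head is set, the single shared (MoE) heads serve every task; otherwise each task_id indexes its own head in the per-task ModuleList. A minimal sketch of that dispatch, with plain Linear layers standing in for the repo's Head modules and illustrative dimensions:

import torch
import torch.nn as nn

embed_dim, action_space_size, task_num = 768, 18, 4
use_moe_head = False

shared_policy_head = nn.Linear(embed_dim, action_space_size)
per_task_policy_heads = nn.ModuleList([nn.Linear(embed_dim, action_space_size) for _ in range(task_num)])

def policy_logits(x: torch.Tensor, task_id: int) -> torch.Tensor:
    if use_moe_head:
        return shared_policy_head(x)            # one (MoE) head shared across all tasks
    return per_task_policy_heads[task_id](x)    # task-specific head

x = torch.randn(8, embed_dim)                   # transformer features for one timestep
print(policy_logits(x, task_id=2).shape)        # torch.Size([8, 18])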
@@ -505,6 +596,7 @@ def _transformer_pass(self, sequences, past_keys_values, kvcache_independent, va else: return self.transformer(sequences, past_keys_values, valid_context_lengths=valid_context_lengths) + #@profile @torch.no_grad() def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor, task_id=0) -> torch.FloatTensor: """ @@ -538,6 +630,8 @@ def reset_from_initial_observations(self, obs_act_dict: torch.FloatTensor, task_ return outputs_wm, self.latent_state + + #@profile @torch.no_grad() def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: torch.LongTensor, buffer_action=None, @@ -644,6 +738,8 @@ def refresh_kvs_with_initial_latent_state_for_init_infer(self, latent_state: tor return outputs_wm + + #@profile @torch.no_grad() def forward_initial_inference(self, obs_act_dict, task_id=0): """ @@ -661,6 +757,7 @@ def forward_initial_inference(self, obs_act_dict, task_id=0): return (outputs_wm.output_sequence, latent_state, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value) + #@profile @torch.no_grad() def forward_recurrent_inference(self, state_action_history, simulation_index=0, latent_state_index_in_search_path=[], task_id=0): @@ -795,6 +892,7 @@ def trim_and_pad_kv_cache(self, is_init_infer=True) -> list: return self.keys_values_wm_size_list + #@profile def update_cache_context(self, latent_state, is_init_infer=True, simulation_index=0, latent_state_index_in_search_path=[], valid_context_lengths=None): """ @@ -930,6 +1028,7 @@ def update_cache_context(self, latent_state, is_init_infer=True, simulation_inde # Store the latest key-value cache for recurrent inference self.past_kv_cache_recurrent_infer[cache_key] = copy.deepcopy(to_device_for_kvcache(self.keys_values_wm_single_env, 'cpu')) + #@profile def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, simulation_index: int = 0, task_id=0) -> list: """ @@ -978,6 +1077,7 @@ def retrieve_or_generate_kvcache(self, latent_state: list, ready_env_num: int, return self.keys_values_wm_size_list + #@profile def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar_transform_handle=None, task_id=0, **kwargs: Any) -> LossWithIntermediateLosses: # Encode observations into latent state representations obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch['observations'], task_id=task_id) @@ -1216,6 +1316,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar logits_value=outputs.logits_value, ) + #@profile def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): # Assume outputs is an object with logits attributes like 'rewards', 'policy', and 'value'. # labels is a target tensor for comparison. batch is a dictionary with a mask indicating valid timesteps. 
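For reference, compute_cross_entropy_loss above compares categorical predictions against target distributions while masking out padded timesteps of each unrolled sequence. A minimal sketch of that masked cross-entropy follows; the shapes are illustrative, and averaging over the valid steps is an assumption of this sketch rather than necessarily the repo's exact reduction.

import torch
import torch.nn.functional as F

batch, unroll, support = 4, 5, 601
logits = torch.randn(batch * unroll, support)
labels = F.softmax(torch.randn(batch * unroll, support), dim=-1)   # target distributions (e.g. scalar-to-support)
mask_padding = torch.ones(batch, unroll, dtype=torch.bool)
mask_padding[:, -1] = False                                        # e.g. padded final steps

loss_per_step = -(labels * F.log_softmax(logits, dim=-1)).sum(-1)  # (batch*unroll,)
loss_per_step = loss_per_step * mask_padding.reshape(-1)           # zero out invalid timesteps
loss = loss_per_step.sum() / mask_padding.sum()                    # average over valid steps only
print(loss.item())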
@@ -1242,6 +1343,7 @@ def compute_cross_entropy_loss(self, outputs, labels, batch, element='rewards'): return loss + #@profile def compute_policy_entropy_loss(self, logits, mask): # Compute entropy of the policy probs = torch.softmax(logits, dim=1) @@ -1251,6 +1353,7 @@ def compute_policy_entropy_loss(self, logits, mask): entropy_loss = (entropy * mask) return entropy_loss + #@profile def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torch.Tensor, ends: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: assert torch.all(ends.sum(dim=1) <= 1) # Each sequence sample should have at most one 'done' flag @@ -1268,6 +1371,7 @@ def compute_labels_world_model(self, obs_embeddings: torch.Tensor, rewards: torc return labels_observations, labels_rewards.reshape(-1, self.support_size), labels_ends.reshape(-1) + #@profile def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, target_policy: torch.Tensor, mask_padding: torch.BoolTensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Compute labels for value and policy predictions. """ @@ -1283,6 +1387,7 @@ def compute_labels_world_model_value_policy(self, target_value: torch.Tensor, ta return labels_policy.reshape(-1, self.action_space_size), labels_value.reshape(-1, self.support_size) + #@profile def clear_caches(self): """ Clears the caches of the world model. diff --git a/lzero/policy/unizero_multitask.py b/lzero/policy/unizero_multitask.py index b7cc67737..d5bae15cb 100644 --- a/lzero/policy/unizero_multitask.py +++ b/lzero/policy/unizero_multitask.py @@ -16,6 +16,7 @@ DiscreteSupport, to_torch_float_tensor, mz_network_output_unpack, select_action, prepare_obs from lzero.policy.unizero import UniZeroPolicy from .utils import configure_optimizers_nanogpt +from line_profiler import line_profiler sys.path.append('/Users/puyuan/code/LibMTL/') from LibMTL.weighting.MoCo_unizero import MoCo as GradCorrect @@ -475,7 +476,7 @@ def _init_learn(self) -> None: # self.curr_min_loss = torch.tensor([0. 
for i in range(self.task_num)], device=self._cfg.device) # self.grad_correct.prev_loss = self.curr_min_loss - # @profile + #@profile def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]: """ Overview: @@ -753,7 +754,7 @@ def _init_collect(self) -> None: self.last_batch_obs = torch.zeros([self.collector_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) self.last_batch_action = [-1 for i in range(self.collector_env_num)] - # @profile + #@profile def _forward_collect( self, data: torch.Tensor, @@ -887,6 +888,7 @@ def _init_eval(self) -> None: self.last_batch_obs = torch.zeros([self.evaluator_env_num, self._cfg.model.observation_shape]).to(self._cfg.device) self.last_batch_action = [-1 for _ in range(self.evaluator_env_num)] + #@profile def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1, ready_env_id: np.array = None, task_id=None) -> Dict: """ @@ -971,6 +973,7 @@ def _forward_eval(self, data: torch.Tensor, action_mask: list, to_play: int = -1 return output + #@profile def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True) -> None: """ Overview: @@ -1020,6 +1023,7 @@ def _reset_collect(self, env_id: int = None, current_steps: int = 0, reset_init_ # TODO: check its correctness self._reset_target_model() + #@profile def _reset_target_model(self) -> None: """ Overview: @@ -1040,6 +1044,7 @@ def _reset_target_model(self) -> None: torch.cuda.empty_cache() print('collector: target_model past_kv_cache.clear()') + #@profile def _reset_eval(self, env_id: int = None, current_steps: int = 0, reset_init_data: bool = True) -> None: """ Overview: diff --git a/zoo/atari/config/atari_unizero_multitask_26games_config.py b/zoo/atari/config/atari_unizero_multitask_26games_config.py new file mode 100644 index 000000000..91d632fdf --- /dev/null +++ b/zoo/atari/config/atari_unizero_multitask_26games_config.py @@ -0,0 +1,224 @@ +from easydict import EasyDict +from copy import deepcopy +# from zoo.atari.config.atari_env_action_space_map import atari_env_action_space_map + +def create_config(env_id, action_space_size, collector_env_num, evaluator_env_num, n_episode, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type): + return EasyDict(dict( + env=dict( + stop_value=int(1e6), + env_id=env_id, + observation_shape=(3, 64, 64), + gray_scale=False, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ), + full_action_space=True, + # ===== only for debug ===== + # collect_max_episode_steps=int(30), + # eval_max_episode_steps=int(30), + # collect_max_episode_steps=int(50), + # eval_max_episode_steps=int(50), + # collect_max_episode_steps=int(500), + # eval_max_episode_steps=int(500), + ), + policy=dict( + learn=dict(learner=dict(hook=dict(save_ckpt_after_iter=200000,),),), # default is 10000 + grad_correct_params=dict( + # for MoCo + MoCo_beta=0.5, + MoCo_beta_sigma=0.5, + MoCo_gamma=0.1, + MoCo_gamma_sigma=0.5, + MoCo_rho=0, + # for CAGrad + calpha=0.5, + rescale=1, + ), + task_num=len(env_id_list), + task_id=0, + model=dict( + observation_shape=(3, 64, 64), + action_space_size=action_space_size, + norm_type=norm_type, + world_model_cfg=dict( + max_blocks=num_unroll_steps, + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + # device='cpu', # 'cuda', + device='cuda', # 'cuda', + action_space_size=action_space_size, + # num_layers=4, 
# NOTE + num_layers=2, # NOTE + num_heads=8, + embed_dim=768, + obs_type='image', + # env_num=max(collector_env_num, evaluator_env_num), + env_num=8, # TODO: the max of all tasks + # collector_env_num=collector_env_num, + # evaluator_env_num=evaluator_env_num, + task_num=len(env_id_list), + use_normal_head=True, + # use_normal_head=False, + use_softmoe_head=False, + # use_moe_head=True, + use_moe_head=False, + num_experts_in_moe_head=4, # NOTE + # moe_in_transformer=True, + moe_in_transformer=False, # NOTE + # multiplication_moe_in_transformer=True, + multiplication_moe_in_transformer=False, # NOTE + num_experts_of_moe_in_transformer=4, + # num_experts_of_moe_in_transformer=2, + ), + ), + # use_priority=False, + # print_task_priority_logs=False, + use_priority=True, # TODO + print_task_priority_logs=False, + cuda=True, + model_path=None, + num_unroll_steps=num_unroll_steps, + # update_per_collect=None, + update_per_collect=1000, + # update_per_collect=500, + replay_ratio=0.25, + batch_size=batch_size, + optim_type='AdamW', + num_simulations=num_simulations, + reanalyze_ratio=reanalyze_ratio, + n_episode=n_episode, + replay_buffer_size=int(1e6), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + )) + +def generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed): + configs = [] + # TODO + # exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head-softmoe4_1-encoder-{norm_type}_lsd768-nlayer4-nh8_seed{seed}/' + + # exp_name_prefix = f'data_unizero_mt_0716_debug/{len(env_id_list)}games_1-head-softmoe4-dynamics_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_value-priority_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716_debug/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_CAGrad_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_MoCo_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe1-same_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0719/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer2-nh8_max-bs1500_upc1000_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_1-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0722/{len(env_id_list)}games_1-encoder-{norm_type}_trans-ffw-moeV2-expert4_4-head_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0722/{len(env_id_list)}games_1-encoder-{norm_type}_1-head-moeV2-expert4_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/' + exp_name_prefix = f'data_unizero_mt_0722/{len(env_id_list)}games_1-encoder-{norm_type}_26-head_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/' + + for task_id, env_id in enumerate(env_id_list): + config = create_config( + env_id, + action_space_size, + # 
collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing + # evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + collector_env_num, + evaluator_env_num, + n_episode, + num_simulations, + reanalyze_ratio, + batch_size, + num_unroll_steps, + infer_context_length, + norm_type + ) + config.policy.task_id = task_id + config.exp_name = exp_name_prefix + f"{env_id.split('NoFrameskip')[0]}_unizero-mt_seed{seed}" + + configs.append([task_id, [config, create_env_manager()]]) + return configs + + +def create_env_manager(): + return EasyDict(dict( + env=dict( + type='atari_lightzero', + import_names=['zoo.atari.envs.atari_lightzero_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero_multitask', + import_names=['lzero.policy.unizero_multitask'], + ), + )) + +if __name__ == "__main__": + from lzero.entry import train_unizero_multitask + # TODO + env_id_list = [ + 'PongNoFrameskip-v4', + 'MsPacmanNoFrameskip-v4', + 'SeaquestNoFrameskip-v4', + 'BoxingNoFrameskip-v4', # + 'AlienNoFrameskip-v4', + 'AmidarNoFrameskip-v4', + 'AssaultNoFrameskip-v4', + 'AsterixNoFrameskip-v4', + 'BankHeistNoFrameskip-v4', + 'BattleZoneNoFrameskip-v4', + 'ChopperCommandNoFrameskip-v4', + 'CrazyClimberNoFrameskip-v4', + 'DemonAttackNoFrameskip-v4', + 'FreewayNoFrameskip-v4', + 'FrostbiteNoFrameskip-v4', + 'GopherNoFrameskip-v4', + 'HeroNoFrameskip-v4', + 'JamesbondNoFrameskip-v4', + 'KangarooNoFrameskip-v4', + 'KrullNoFrameskip-v4', + 'KungFuMasterNoFrameskip-v4', + 'PrivateEyeNoFrameskip-v4', + 'RoadRunnerNoFrameskip-v4', + 'UpNDownNoFrameskip-v4', + 'QbertNoFrameskip-v4', + 'BreakoutNoFrameskip-v4', + ] + + # env_id_list = [ + # 'PongNoFrameskip-v4', + # 'MsPacmanNoFrameskip-v4', + # 'SeaquestNoFrameskip-v4', + # 'BoxingNoFrameskip-v4', + # 'AlienNoFrameskip-v4', + # 'CrazyClimberNoFrameskip-v4', + # 'BreakoutNoFrameskip-v4', + # 'QbertNoFrameskip-v4', + # ] + + action_space_size = 18 # Full action space + seed = 0 + collector_env_num = 8 + n_episode = 8 + evaluator_env_num = 3 + num_simulations = 50 + max_env_step = int(1e6) + reanalyze_ratio = 0. 
+ # batch_size = [32, 32, 32, 32] + # max_batch_size = 2000 + max_batch_size = 1000 + batch_size = [int(max_batch_size/len(env_id_list)) for i in range(len(env_id_list))] + num_unroll_steps = 10 + infer_context_length = 4 + norm_type = 'LN' + # # norm_type = 'BN' # bad performance now + + + # ======== TODO: only for debug ======== + # collector_env_num = 3 + # n_episode = 3 + # evaluator_env_num = 2 + # num_simulations = 2 + # batch_size = [4, 4, 4, 4] + + configs = generate_configs(env_id_list, action_space_size, collector_env_num, n_episode, evaluator_env_num, num_simulations, reanalyze_ratio, batch_size, num_unroll_steps, infer_context_length, norm_type, seed) + + # Uncomment the desired training run + # train_unizero_multitask(configs[:1], seed=seed, max_env_step=max_env_step) # Pong + # train_unizero_multitask(configs[:2], seed=seed, max_env_step=max_env_step) # Pong, MsPacman + train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) # Pong, MsPacman, Seaquest, Boxing \ No newline at end of file diff --git a/zoo/atari/config/atari_unizero_multitask_config.py b/zoo/atari/config/atari_unizero_multitask_config.py index 66f7d148b..73a167239 100644 --- a/zoo/atari/config/atari_unizero_multitask_config.py +++ b/zoo/atari/config/atari_unizero_multitask_config.py @@ -58,10 +58,16 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu # collector_env_num=collector_env_num, # evaluator_env_num=evaluator_env_num, task_num=len(env_id_list), - # num_experts_in_softmoe_head=4, # NOTE - num_experts_in_softmoe_head=-1, # NOTE + use_normal_head=True, + # use_normal_head=False, + use_softmoe_head=False, + # use_moe_head=True, + use_moe_head=False, + num_experts_in_moe_head=4, # NOTE # moe_in_transformer=True, moe_in_transformer=False, # NOTE + # multiplication_moe_in_transformer=True, + multiplication_moe_in_transformer=False, # NOTE num_experts_of_moe_in_transformer=4, # num_experts_of_moe_in_transformer=2, ), @@ -74,8 +80,8 @@ def create_config(env_id, action_space_size, collector_env_num, evaluator_env_nu model_path=None, num_unroll_steps=num_unroll_steps, # update_per_collect=None, - # update_per_collect=1000, - update_per_collect=500, + update_per_collect=1000, + # update_per_collect=500, replay_ratio=0.25, batch_size=batch_size, optim_type='AdamW', @@ -92,28 +98,23 @@ def generate_configs(env_id_list, action_space_size, collector_env_num, n_episod configs = [] # TODO # exp_name_prefix = f'data_unizero_mt_0711/{len(env_id_list)}games_{"-".join(env_id_list)}_1-head-softmoe4_1-encoder-{norm_type}_lsd768-nlayer4-nh8_seed{seed}/' - - # exp_name_prefix = f'data_unizero_mt_0716_debug/{len(env_id_list)}games_1-head-softmoe4-dynamics_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' - # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_value-priority_seed{seed}/' - # exp_name_prefix = f'data_unizero_mt_0716_debug/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' - # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_CAGrad_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_4-head_1-encoder-{norm_type}_MoCo_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' - # exp_name_prefix = 
f'data_unizero_mt_0716/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe1-same_lsd768-nlayer4-nh8_max-bs1500_upc1000_seed{seed}/' # exp_name_prefix = f'data_unizero_mt_0719/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer2-nh8_max-bs1500_upc1000_seed{seed}/' - exp_name_prefix = f'data_unizero_mt_0719/{len(env_id_list)}games_pong-boxing-envnum2_4-head_1-encoder-{norm_type}_lsd768-nlayer2-nh8_max-bs1500_upc500_seed{seed}/' - # exp_name_prefix = f'data_unizero_mt_0716/{len(env_id_list)}games_1-head_1-encoder-{norm_type}_trans-ffw-moe4_lsd768-nlayer4-nh8_max-bs1500_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0722_debug/{len(env_id_list)}games_1-encoder-{norm_type}_trans-ffw-moeV2-expert4_4-head_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/' + # exp_name_prefix = f'data_unizero_mt_0722_profile/lineprofile_{len(env_id_list)}games_1-encoder-{norm_type}_4-head_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/' + exp_name_prefix = f'data_unizero_mt_0722/{len(env_id_list)}games_1-encoder-{norm_type}_4-head_lsd768-nlayer2-nh8_max-bs2000_upc1000_seed{seed}/' for task_id, env_id in enumerate(env_id_list): config = create_config( env_id, action_space_size, - collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing - evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, - n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, - # collector_env_num, - # evaluator_env_num, - # n_episode, + # collector_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, # TODO: different collector_env_num for Pong and Boxing + # evaluator_env_num if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + # n_episode if env_id not in ['PongNoFrameskip-v4', 'BoxingNoFrameskip-v4'] else 2, + collector_env_num, + evaluator_env_num, + n_episode, num_simulations, reanalyze_ratio, batch_size, @@ -171,8 +172,8 @@ def create_env_manager(): max_env_step = int(1e6) reanalyze_ratio = 0. # batch_size = [32, 32, 32, 32] - max_batch_size = 1500 - batch_size = [int(1500/len(env_id_list)) for i in range(len(env_id_list))] + max_batch_size = 2000 + batch_size = [int(max_batch_size/len(env_id_list)) for i in range(len(env_id_list))] num_unroll_steps = 10 infer_context_length = 4 norm_type = 'LN' @@ -191,4 +192,10 @@ def create_env_manager(): # Uncomment the desired training run # train_unizero_multitask(configs[:1], seed=seed, max_env_step=max_env_step) # Pong # train_unizero_multitask(configs[:2], seed=seed, max_env_step=max_env_step) # Pong, MsPacman - train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) # Pong, MsPacman, Seaquest, Boxing \ No newline at end of file + train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) # Pong, MsPacman, Seaquest, Boxing + + # only for cprofile + # def run(max_env_step: int): + # train_unizero_multitask(configs, seed=seed, max_env_step=max_env_step) # Pong, MsPacman, Seaquest, Boxing + # import cProfile + # cProfile.run(f"run({20000})", filename="unizero_mt_4games_cprofile_20k_envstep", sort="cumulative") \ No newline at end of file
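For reference, this patch series leaves two profiling paths in place: the commented-out cProfile block above, which dumps statistics to the named file, and the #@profile decorators scattered through the world model, buffer, and policy, which are meant to be uncommented and run under line_profiler (typically via kernprof -l -v <entry script>) to get per-line timings of the decorated functions. A short sketch of inspecting the cProfile dump, assuming the filename from the commented-out block above is kept:

# Inspect the cProfile dump written by cProfile.run(..., filename=...) above.
import pstats

stats = pstats.Stats("unizero_mt_4games_cprofile_20k_envstep")
stats.strip_dirs().sort_stats("cumulative").print_stats(20)   # top 20 entries by cumulative time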