diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html
similarity index 100%
rename from docs/source/_templates/layout.html
rename to docs/source/_templates/layout.html
diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py
index cd7ff7605..edd67d6f8 100644
--- a/lzero/entry/train_unizero.py
+++ b/lzero/entry/train_unizero.py
@@ -21,6 +21,8 @@ from lzero.worker import MuZeroEvaluator as Evaluator
 from lzero.worker import MuZeroCollector as Collector
 from .utils import random_collect
+import torch.distributed as dist
+from ding.utils import set_pkg_seed, get_rank, get_world_size
 
 
 def train_unizero(
@@ -52,33 +54,43 @@ def train_unizero(
     cfg, create_cfg = input_cfg
 
-    # Ensure the specified policy type is supported
-    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'"
+    logging.info("===== Starting UniZero training =====")
+
+    # Check that the specified policy type is supported
+    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero only supports the following algorithms: 'unizero', 'sampled_unizero'"
+    logging.info(f"Policy type in use: {create_cfg.policy.type}")
 
-    # Import the correct GameBuffer class based on the policy type
+    # Import the corresponding GameBuffer class according to the policy type
     game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
-
     GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]), game_buffer_classes[create_cfg.policy.type])
 
-    # Set device based on CUDA availability
+    # Set the device according to GPU availability
     cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
-    logging.info(f'cfg.policy.device: {cfg.policy.device}')
+    logging.info(f"Device set to: {cfg.policy.device}")
 
-    # Compile the configuration
+    # Compile the configuration
+    logging.info("Compiling configuration...")
     cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+    logging.info("Configuration compiled!")
 
-    # Create main components: env, policy
+    # Create the environment managers
+    logging.info("Creating environment managers...")
     env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
     collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
     evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+    logging.info("Environment managers created!")
 
+    # Initialize the environments and random seeds
+    logging.info("Initializing environments and random seeds...")
     collector_env.seed(cfg.seed)
     evaluator_env.seed(cfg.seed, dynamic_seed=False)
     set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())
+    logging.info("Environments and random seeds initialized!")
 
+    # Initialize wandb if it is enabled
     if cfg.policy.use_wandb:
-        # Initialize wandb
+        logging.info("Initializing wandb...")
         wandb.init(
             project="LightZero",
             config=cfg,
@@ -86,72 +98,99 @@ def train_unizero(
             monitor_gym=False,
             save_code=True,
         )
+        logging.info("wandb initialized!")
 
+    # Create the policy
+    logging.info("Creating policy...")
     policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+    logging.info("Policy created!")
 
-    # Load pretrained model if specified
+    # Load the pretrained model if a model path is specified
     if model_path is not None:
-        logging.info(f'Loading model from {model_path} begin...')
+        logging.info(f"Loading pretrained model from {model_path}...")
         policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
-        logging.info(f'Loading model from {model_path} end!')
+        logging.info("Pretrained model loaded!")
 
-    # Create worker components: learner, collector, evaluator, replay buffer, commander
+    # Create the core training components
+    logging.info("Creating core training components...")
     tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
     learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
-
-    # MCTS+RL algorithms related core code
-    policy_config = cfg.policy
-    replay_buffer = GameBuffer(policy_config)
+    replay_buffer = GameBuffer(cfg.policy)
     collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
-                          policy_config=policy_config)
+                          policy_config=cfg.policy)
     evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
                           stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
-                          tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)
+                          tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=cfg.policy)
+    logging.info("Core training components created!")
 
-    # Learner's before_run hook
+    # Learner's before_run hook
+    logging.info("Executing learner's before_run hook...")
     learner.call_hook('before_run')
-    if policy_config.use_wandb:
+    logging.info("Learner's before_run hook finished!")
+
+    if cfg.policy.use_wandb:
         policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
 
-    # Collect random data before training
+    # Collect random data before training
    if cfg.policy.random_collect_episode_num > 0:
+        logging.info("Collecting random data...")
         random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+        logging.info("Random data collection finished!")
 
     batch_size = policy._cfg.batch_size
 
+    if cfg.policy.multi_gpu:
+        # Get the current world_size and rank
+        world_size = get_world_size()
+        rank = get_rank()
+    else:
+        world_size = 1
+        rank = 0
+
     while True:
-        # Log buffer memory usage
+        # torch.cuda.empty_cache()
+
+        # Log the replay buffer's memory usage
+        # logging.info(f"Train iter {learner.train_iter}: logging replay buffer memory usage...")
         log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+        # logging.info(f"Train iter {learner.train_iter}: memory usage logged!")
 
-        # Set temperature for visit count distributions
+        # Set the temperature for visit count distributions
         collect_kwargs = {
             'temperature': visit_count_temperature(
-                policy_config.manual_temperature_decay,
-                policy_config.fixed_temperature_value,
-                policy_config.threshold_training_steps_for_final_temperature,
+                cfg.policy.manual_temperature_decay,
+                cfg.policy.fixed_temperature_value,
+                cfg.policy.threshold_training_steps_for_final_temperature,
                 trained_steps=learner.train_iter
             ),
-            'epsilon': 0.0  # Default epsilon value
+            'epsilon': 0.0  # Default epsilon value
         }
+        # logging.info(f"Train iter {learner.train_iter}: temperature set to {collect_kwargs['temperature']}")
 
-        # Configure epsilon for epsilon-greedy exploration
-        if policy_config.eps.eps_greedy_exploration_in_collect:
+        # Configure epsilon-greedy exploration
+        if cfg.policy.eps.eps_greedy_exploration_in_collect:
             epsilon_greedy_fn = get_epsilon_greedy_fn(
-                start=policy_config.eps.start,
-                end=policy_config.eps.end,
-                decay=policy_config.eps.decay,
-                type_=policy_config.eps.type
+                start=cfg.policy.eps.start,
+                end=cfg.policy.eps.end,
+                decay=cfg.policy.eps.decay,
+                type_=cfg.policy.eps.type
             )
             collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+            # logging.info(f"Train iter {learner.train_iter}: epsilon set to {collect_kwargs['epsilon']}")
 
-        # Evaluate policy performance
+        # Evaluate the policy's performance
         if evaluator.should_eval(learner.train_iter):
+            logging.info(f"Train iter {learner.train_iter}: starting evaluation...")
             stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            logging.info(f"Train iter {learner.train_iter}: evaluation finished, stop: {stop}, reward: {reward}")
             if stop:
+                logging.info("Stopping criterion met, training finished!")
                 break
 
-        # Collect new data
+        # Collect new data
+        # logging.info(f"Rank {rank}, train iter {learner.train_iter}: collecting new data...")
         new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+        logging.info(f"Rank {rank}, train iter {learner.train_iter}: new data collected!")
 
         # Determine updates per collection
         update_per_collect = cfg.policy.update_per_collect
@@ -162,44 +201,60 @@ def train_unizero(
             collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
             update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
 
-        # Update replay buffer
+        # Update the replay buffer
+        # logging.info(f"Train iter {learner.train_iter}: updating replay buffer...")
         replay_buffer.push_game_segments(new_data)
         replay_buffer.remove_oldest_data_to_fit()
+        # logging.info(f"Train iter {learner.train_iter}: replay buffer updated!")
+
+        if world_size > 1:
+            # Synchronize all ranks before training
+            try:
+                dist.barrier()
+                # logging.info(f'Rank {rank}: passed the pre-training synchronization barrier')
+            except Exception as e:
+                logging.error(f'Rank {rank}: synchronization barrier failed with error: {e}')
+                break
 
-        # Train the policy if sufficient data is available
+        # Check whether there is enough data for training
         if collector.envstep > cfg.policy.train_start_after_envsteps:
             if cfg.policy.sample_type == 'episode':
                 data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
             else:
                 data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
+
             if not data_sufficient:
-                logging.warning(
-                    f'The data in replay_buffer is not sufficient to sample a mini-batch: '
-                    f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
-                )
+                # NOTE: In DDP training, some ranks may not have enough data in their replay buffer and would skip the training phase, causing communication timeouts; make sure all ranks enter the training phase together.
+                logging.warning(f"Rank {rank}: train iter {learner.train_iter}: replay buffer data insufficient, continuing to collect...")
                 continue
 
+            logging.info(f"Rank {rank}, train iter {learner.train_iter}: starting training!")
+
+            # Run multiple training updates
             for i in range(update_per_collect):
                 train_data = replay_buffer.sample(batch_size, policy)
                 if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
-                    # Clear caches and precompute positional embedding matrices
-                    policy.recompute_pos_emb_diff_and_clear_cache()  # TODO
-
-                if policy_config.use_wandb:
+                    policy.recompute_pos_emb_diff_and_clear_cache()
+
+                if cfg.policy.use_wandb:
                     policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
 
                 train_data.append({'train_which_component': 'transformer'})
                 log_vars = learner.train(train_data, collector.envstep)
-
                 if cfg.policy.use_priority:
                     replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+            logging.info(f"Rank {rank}, train iter {learner.train_iter}: training finished!")
 
         policy.recompute_pos_emb_diff_and_clear_cache()
 
-        # Check stopping criteria
+        # Check stopping criteria
         if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            logging.info("Stopping criterion met, training finished!")
             break
 
     learner.call_hook('after_run')
-    wandb.finish()
-    return policy
+    if cfg.policy.use_wandb:
+        wandb.finish()
+    logging.info("===== Training finished =====")
+    return policy
\ No newline at end of file
diff --git a/lzero/entry/train_unizero_bkp.py b/lzero/entry/train_unizero_bkp.py
new file mode 100644
index 000000000..75ee9bcfb
--- /dev/null
+++ b/lzero/entry/train_unizero_bkp.py
@@ -0,0 +1,206 @@
+import logging
+import os
+from functools import partial
+from typing import Tuple, Optional
+
+import torch
+import wandb
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import set_pkg_seed, get_rank
+from 
ding.worker import BaseLearner +from tensorboardX import SummaryWriter +from torch.utils.tensorboard import SummaryWriter + +from lzero.entry.utils import log_buffer_memory_usage +from lzero.policy import visit_count_temperature +from lzero.policy.random_policy import LightZeroRandomPolicy +from lzero.worker import MuZeroEvaluator as Evaluator +from lzero.worker import MuZeroCollector as Collector +from .utils import random_collect + + +def train_unizero( + input_cfg: Tuple[dict, dict], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + model_path: Optional[str] = None, + max_train_iter: Optional[int] = int(1e10), + max_env_step: Optional[int] = int(1e10), +) -> 'Policy': + """ + Overview: + The train entry for UniZero, proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models. + UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms, + particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667. + Arguments: + - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type. + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + - model_path (:obj:`Optional[str]`): The pretrained model path, which should + point to the ckpt file of the pretrained model, and an absolute path is recommended. + In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``. + - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training. + - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ """ + + cfg, create_cfg = input_cfg + + # Ensure the specified policy type is supported + assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'" + + # Import the correct GameBuffer class based on the policy type + game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'} + + GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]), + game_buffer_classes[create_cfg.policy.type]) + + # Set device based on CUDA availability + cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu' + logging.info(f'cfg.policy.device: {cfg.policy.device}') + + # Compile the configuration + cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True) + + # Create main components: env, policy + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + + collector_env.seed(cfg.seed) + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available()) + + if cfg.policy.use_wandb: + # Initialize wandb + wandb.init( + project="LightZero", + config=cfg, + sync_tensorboard=False, + monitor_gym=False, + save_code=True, + ) + + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval']) + + # Load pretrained model if specified + if model_path is not None: + logging.info(f'Loading model from {model_path} begin...') + policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device)) + logging.info(f'Loading model from {model_path} end!') + + # Create worker components: learner, collector, evaluator, replay buffer, commander + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + + # MCTS+RL algorithms related core code + policy_config = cfg.policy + replay_buffer = GameBuffer(policy_config) + collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name, + policy_config=policy_config) + evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode, + stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode, + tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config) + + # Learner's before_run hook + learner.call_hook('before_run') + if policy_config.use_wandb: + policy.set_train_iter_env_step(learner.train_iter, collector.envstep) + + # Collect random data before training + if cfg.policy.random_collect_episode_num > 0: + random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer) + + batch_size = policy._cfg.batch_size + + while True: + # Log buffer memory usage + log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger) + + # Set temperature for visit count distributions + collect_kwargs = { + 'temperature': visit_count_temperature( + policy_config.manual_temperature_decay, + policy_config.fixed_temperature_value, + policy_config.threshold_training_steps_for_final_temperature, + 
trained_steps=learner.train_iter + ), + 'epsilon': 0.0 # Default epsilon value + } + + # Configure epsilon for epsilon-greedy exploration + if policy_config.eps.eps_greedy_exploration_in_collect: + epsilon_greedy_fn = get_epsilon_greedy_fn( + start=policy_config.eps.start, + end=policy_config.eps.end, + decay=policy_config.eps.decay, + type_=policy_config.eps.type + ) + collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep) + + # Evaluate policy performance + # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter): + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep) + if stop: + break + + # Collect new data + new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs) + + # Determine updates per collection + update_per_collect = cfg.policy.update_per_collect + if update_per_collect is None: + # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio. + # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game. + # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps. + collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0]) + update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio) + + # Update replay buffer + replay_buffer.push_game_segments(new_data) + replay_buffer.remove_oldest_data_to_fit() + + # Train the policy if sufficient data is available + if collector.envstep > cfg.policy.train_start_after_envsteps: + if cfg.policy.sample_type == 'episode': + data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size + else: + data_sufficient = replay_buffer.get_num_of_transitions() > batch_size + if not data_sufficient: + logging.warning( + f'The data in replay_buffer is not sufficient to sample a mini-batch: ' + f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....' 
+ ) + continue + + for i in range(update_per_collect): + train_data = replay_buffer.sample(batch_size, policy) + if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0: + # Clear caches and precompute positional embedding matrices + policy.recompute_pos_emb_diff_and_clear_cache() # TODO + + if policy_config.use_wandb: + policy.set_train_iter_env_step(learner.train_iter, collector.envstep) + + train_data.append({'train_which_component': 'transformer'}) + log_vars = learner.train(train_data, collector.envstep) + + if cfg.policy.use_priority: + replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig']) + + policy.recompute_pos_emb_diff_and_clear_cache() + + # Check stopping criteria + if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter: + break + + learner.call_hook('after_run') + wandb.finish() + return policy diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py index 2678066e9..db548049a 100644 --- a/lzero/mcts/buffer/game_buffer.py +++ b/lzero/mcts/buffer/game_buffer.py @@ -401,8 +401,13 @@ def _preprocess_to_play_and_action_mask( unroll_steps + 1] ) if len(action_mask_tmp) < unroll_steps + 1: + # action_mask_tmp += [ + # list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) + # for _ in range(unroll_steps + 1 - len(action_mask_tmp)) + # ] + # TODO: padded data action_mask_tmp += [ - list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) + list(np.zeros(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(unroll_steps + 1 - len(action_mask_tmp)) ] action_mask.append(action_mask_tmp) diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py index 2ba8180de..aa923c56b 100644 --- a/lzero/mcts/buffer/game_buffer_muzero.py +++ b/lzero/mcts/buffer/game_buffer_muzero.py @@ -711,7 +711,9 @@ def _compute_target_policy_non_reanalyzed( ] else: legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)] - + # print(f'='*20) + # print(f'buffer_muzero: action_mask:{action_mask}') + with torch.no_grad(): policy_index = 0 # 0 -> Invalid target policy for padding outside of game segments, @@ -730,11 +732,17 @@ def _compute_target_policy_non_reanalyzed( # for atari/classic_control/box2d environments that only have one player. target_policies.append(distributions) else: - # for board games that have two players. + # for board games that have two players or envs that have varied action space. 
policy_tmp = [0 for _ in range(policy_shape)]
                 for index, legal_action in enumerate(legal_actions[policy_index]):
                     # only the action in ``legal_action`` the policy logits is nonzero
-                    policy_tmp[legal_action] = distributions[index]
+                    # policy_tmp[legal_action] = distributions[index]
+                    try:
+                        policy_tmp[legal_action] = distributions[index]
+                    except Exception as e:
+                        print('='*20)
+                        print(f'Exception:{e}, distributions:{distributions}, legal_action:{legal_actions[policy_index]}')
+                        # TODO: This happens because the end of a sampled sequence may be padding, and the padded action_mask is filled with np.zeros(self._cfg.model.action_space_size, dtype=np.int8).
                 target_policies.append(policy_tmp)
         else:
             # NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0
diff --git a/lzero/model/common.py b/lzero/model/common.py
index 22afa95fe..f2ed32798 100644
--- a/lzero/model/common.py
+++ b/lzero/model/common.py
@@ -274,6 +274,86 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return output
 
 
+"""
+Uses google-bert/bert-base-uncased; the model input is token ids.
+"""
+class HFLanguageRepresentationNetwork(nn.Module):
+    def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8):
+        """
+        Initialize the language representation network.
+
+        Arguments:
+        - url (str): The name or path of the pretrained Hugging Face model, defaults to 'google-bert/bert-base-uncased'.
+        - embedding_size (int): The dimension of the output embedding, defaults to 768.
+        """
+        super().__init__()
+        from transformers import AutoModel
+        # Load the pretrained Hugging Face model
+        self.model = AutoModel.from_pretrained(url)
+
+        # Set the embedding dimension; if the target dimension is not 768, add a linear layer to project it up or down.
+        self.embedding_size = embedding_size
+        if self.embedding_size != 768:
+            self.embed_head = nn.Linear(768, self.embedding_size)
+
+        self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+    def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
+        """
+        Forward pass that obtains the language representation of the input sequence.
+
+        Arguments:
+        - x (torch.Tensor): The input tensor, usually the token indices of a sequence, with shape [batch_size, seq_len].
+        - no_grad (bool): Whether to run in no-grad mode, defaults to True.
+
+        Returns:
+        - torch.Tensor: The processed language embedding, with shape [batch_size, embedding_size].
+        """
+        if no_grad:
+            # Disable gradient computation in no_grad mode to save GPU memory
+            with torch.no_grad():
+                x = x.long()  # Ensure the input tensor is of long (integer) type
+                outputs = self.model(x)  # Get the model outputs
+
+                # The model's last_hidden_state has shape [batch_size, seq_len, hidden_size].
+                # We usually take the vector corresponding to the [CLS] token, i.e., outputs.last_hidden_state[:, 0, :].
+                cls_embedding = outputs.last_hidden_state[:, 0, :]
+
+            # Apply a linear transform if the target embedding_size is not 768
+            if self.embedding_size == 768:
+                # NOTE: very important for training stability.
+                cls_embedding = self.sim_norm(cls_embedding)
+
+                return cls_embedding
+            else:
+                cls_embedding = self.embed_head(cls_embedding)
+
+                # NOTE: very important for training stability.
+                cls_embedding = self.sim_norm(cls_embedding)
+
+                return cls_embedding
+        else:
+            # With no_grad disabled, gradients are computed
+            x = x.long()  # Ensure the input tensor is of long (integer) type
+            outputs = self.model(x)
+            cls_embedding = outputs.last_hidden_state[:, 0, :]
+
+            # Apply a linear transform if the target embedding_size is not 768
+            if self.embedding_size == 768:
+                # NOTE: very important for training stability.
+                cls_embedding = self.sim_norm(cls_embedding)
+
+                return cls_embedding
+            else:
+                cls_embedding = self.embed_head(cls_embedding)
+
+                # NOTE: very important for training stability.
+ cls_embedding = self.sim_norm(cls_embedding) + + return cls_embedding + + + class RepresentationNetworkUniZero(nn.Module): def __init__( diff --git a/lzero/model/common_noserve_input_id.py b/lzero/model/common_noserve_input_id.py new file mode 100644 index 000000000..f2ed32798 --- /dev/null +++ b/lzero/model/common_noserve_input_id.py @@ -0,0 +1,1211 @@ +""" +Overview: + In this Python file, we provide a collection of reusable model templates designed to streamline the development + process for various custom algorithms. By utilizing these pre-built model templates, users can quickly adapt and + customize their custom algorithms, ensuring efficient and effective development. + BTW, users can refer to the unittest of these model templates to learn how to use them. +""" +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from ding.torch_utils import MLP, ResBlock +from ding.utils import SequenceType +from ditk import logging + + +# use dataclass to make the output of network more convenient to use +@dataclass +class MZRNNNetworkOutput: + # output format of the MuZeroRNN model + value: torch.Tensor + value_prefix: torch.Tensor + policy_logits: torch.Tensor + latent_state: torch.Tensor + predict_next_latent_state: torch.Tensor + reward_hidden_state: Tuple[torch.Tensor] + + +@dataclass +class EZNetworkOutput: + # output format of the EfficientZero model + value: torch.Tensor + value_prefix: torch.Tensor + policy_logits: torch.Tensor + latent_state: torch.Tensor + reward_hidden_state: Tuple[torch.Tensor] + + +@dataclass +class MZNetworkOutput: + # output format of the MuZero model + value: torch.Tensor + reward: torch.Tensor + policy_logits: torch.Tensor + latent_state: torch.Tensor + + +class SimNorm(nn.Module): + + def __init__(self, simnorm_dim: int) -> None: + """ + Overview: + Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616. + Arguments: + - simnorm_dim (:obj:`int`): The dimension for simplicial normalization. + """ + super().__init__() + self.dim = simnorm_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass of the SimNorm layer. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor to normalize. + Returns: + - x (:obj:`torch.Tensor`): The normalized tensor. + """ + shp = x.shape + # Ensure that there is at least one simplex to normalize across. + if shp[1] != 0: + x = x.view(*shp[:-1], -1, self.dim) + x = F.softmax(x, dim=-1) + return x.view(*shp) + else: + return x + + def __repr__(self) -> str: + """ + Overview: + String representation of the SimNorm layer. + Returns: + - output (:obj:`str`): The string representation. + """ + return f"SimNorm(dim={self.dim})" + + +def AvgL1Norm(x, eps=1e-8): + """ + Overview: + Normalize the input tensor by the L1 norm. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor to normalize. + - eps (:obj:`float`): The epsilon value to prevent division by zero. + Returns: + - :obj:`torch.Tensor`: The normalized tensor. + """ + return x / x.abs().mean(-1, keepdim=True).clamp(min=eps) + + +class FeatureAndGradientHook: + + def __init__(self): + """ + Overview: + Class to capture features and gradients at SimNorm. 
+ """ + self.features_before = [] + self.features_after = [] + self.grads_before = [] + self.grads_after = [] + + def setup_hooks(self, model): + # Hooks to capture features and gradients at SimNorm + self.forward_handler = model.sim_norm.register_forward_hook(self.forward_hook) + self.backward_handler = model.sim_norm.register_full_backward_hook(self.backward_hook) + + def forward_hook(self, module, input, output): + with torch.no_grad(): + self.features_before.append(input[0]) + self.features_after.append(output) + + def backward_hook(self, module, grad_input, grad_output): + with torch.no_grad(): + self.grads_before.append(grad_input[0] if grad_input[0] is not None else None) + self.grads_after.append(grad_output[0] if grad_output[0] is not None else None) + + def analyze(self): + # Calculate L2 norms of features + l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_before])) + l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_after])) + + # Calculate norms of gradients + grad_norm_before = torch.mean( + torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before if g is not None])) + grad_norm_after = torch.mean( + torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after if g is not None])) + + # Clear stored data and delete tensors to free memory + self.clear_data() + + # Optionally clear CUDA cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after + + def clear_data(self): + del self.features_before[:] + del self.features_after[:] + del self.grads_before[:] + del self.grads_after[:] + + def remove_hooks(self): + self.forward_handler.remove() + self.backward_handler.remove() + + +class DownSample(nn.Module): + + def __init__(self, observation_shape: SequenceType, out_channels: int, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + num_resblocks: int = 1, + ) -> None: + """ + Overview: + Define downSample convolution network. Encode the observation into hidden state. + This network is often used in video games like Atari. In board games like go and chess, + we don't need this module. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[12, 96, 96] + for video games like atari, RGB 3 channel times stack 4 frames. + - out_channels (:obj:`int`): The output channels of output hidden state. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ + Use the inplace operation to speed up. + - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'. + - num_resblocks (:obj:`int`): The number of residual blocks. Defaults to 1. 
+ """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + self.observation_shape = observation_shape + self.conv1 = nn.Conv2d( + observation_shape[0], + out_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False, # disable bias for better convergence + ) + if norm_type == 'BN': + self.norm1 = nn.BatchNorm2d(out_channels // 2) + elif norm_type == 'LN': + self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2], + eps=1e-5) + + self.resblocks1 = nn.ModuleList( + [ + ResBlock( + in_channels=out_channels // 2, + activation=activation, + norm_type=norm_type, + res_type='basic', + bias=False + ) for _ in range(num_resblocks) + ] + ) + self.downsample_block = ResBlock( + in_channels=out_channels // 2, + out_channels=out_channels, + activation=activation, + norm_type=norm_type, + res_type='downsample', + bias=False + ) + self.resblocks2 = nn.ModuleList( + [ + ResBlock( + in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_resblocks) + ] + ) + self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + self.resblocks3 = nn.ModuleList( + [ + ResBlock( + in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(1) + ] + ) + self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ + H is height. + - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ + output width, H_ is output height. + """ + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + + for block in self.resblocks1: + x = block(x) + x = self.downsample_block(x) + for block in self.resblocks2: + x = block(x) + x = self.pooling1(x) + for block in self.resblocks3: + x = block(x) + + # 64, 84, 96 are the most common observation shapes in Atari games. + if self.observation_shape[1] == 64: + output = x + elif self.observation_shape[1] == 84: + x = self.pooling2(x) + output = x + elif self.observation_shape[1] == 96: + x = self.pooling2(x) + output = x + else: + raise NotImplementedError(f"DownSample for observation shape {self.observation_shape} is not implemented now. 
" + f"You should transform the observation shape to 64 or 96 in the env.") + + return output + + +""" +使用 google-bert/bert-base-uncased , 模型的输入为id +""" +class HFLanguageRepresentationNetwork(nn.Module): + def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8): + """ + 初始化语言表示网络 + + 参数: + - url (str): 预训练 Hugging Face 模型的地址,默认为 'google-bert/bert-base-uncased'。 + - embedding_size (int): 输出嵌入的维度大小,默认为 768。 + """ + super().__init__() + from transformers import AutoModel + # 加载 Hugging Face 预训练模型 + self.model = AutoModel.from_pretrained(url) + + # 设置嵌入维度,如果目标维度不是 768,则添加一个线性变换层用于降维或升维 + self.embedding_size = embedding_size + if self.embedding_size != 768: + self.embed_head = nn.Linear(768, self.embedding_size) + + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor: + """ + 前向传播,获取输入序列的语言表示。 + + 参数: + - x (torch.Tensor): 输入的张量,通常是序列的 token 索引,形状为 [batch_size, seq_len]。 + - no_grad (bool): 是否在无梯度模式下运行,默认为 True。 + + 返回: + - torch.Tensor: 经过处理的语言嵌入向量,形状为 [batch_size, embedding_size]。 + """ + if no_grad: + # 在 no_grad 模式下禁用梯度计算以节省显存 + with torch.no_grad(): + x = x.long() # 确保输入张量为长整型 + outputs = self.model(x) # 获取模型的输出 + + # 模型输出的 last_hidden_state 形状为 [batch_size, seq_len, hidden_size] + # 我们通常取 [CLS] 标记对应的向量,即 outputs.last_hidden_state[:, 0, :] + cls_embedding = outputs.last_hidden_state[:, 0, :] + + # 如果目标的 embedding_size 不是 768,则应用线性变换 + if self.embedding_size == 768: + # NOTE: very important for training stability. + cls_embedding = self.sim_norm(cls_embedding) + + return cls_embedding + else: + cls_embedding = self.embed_head(cls_embedding) + + # NOTE: very important for training stability. + cls_embedding = self.sim_norm(cls_embedding) + + return cls_embedding + else: + # 非 no_grad 模式下,启用梯度计算 + x = x.long() # 确保输入张量为长整型 + outputs = self.model(x) + cls_embedding = outputs.last_hidden_state[:, 0, :] + + # 如果目标的 embedding_size 不是 768,则应用线性变换 + if self.embedding_size == 768: + # NOTE: very important for training stability. + cls_embedding = self.sim_norm(cls_embedding) + + return cls_embedding + else: + cls_embedding = self.embed_head(cls_embedding) + + # NOTE: very important for training stability. + cls_embedding = self.sim_norm(cls_embedding) + + return cls_embedding + + + +class RepresentationNetworkUniZero(nn.Module): + + def __init__( + self, + observation_shape: SequenceType = (3, 64, 64), + num_res_blocks: int = 1, + num_channels: int = 64, + downsample: bool = True, + activation: nn.Module = nn.GELU(approximate='tanh'), + norm_type: str = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + ) -> None: + """ + Overview: + Representation network used in UniZero. Encode the 2D image obs into latent state. + Currently, the network only supports obs images with both a width and height of 64. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - num_channels (:obj:`int`): The channel of output hidden state. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). 
\ + Use the inplace operation to speed up. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - embedding_dim (:obj:`int`): The dimension of the latent state. + - group_size (:obj:`int`): The dimension for simplicial normalization. + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + logging.info(f"Using norm type: {norm_type}") + logging.info(f"Using activation type: {activation}") + + self.observation_shape = observation_shape + self.downsample = downsample + if self.downsample: + self.downsample_net = DownSample( + observation_shape, + num_channels, + activation=activation, + norm_type=norm_type, + ) + else: + self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) + + if norm_type == 'BN': + self.norm = nn.BatchNorm2d(num_channels) + elif norm_type == 'LN': + if downsample: + self.norm = nn.LayerNorm( + [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], + eps=1e-5) + else: + self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + self.activation = activation + self.embedding_dim = embedding_dim + + if self.observation_shape[1] == 64: + self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) + + elif self.observation_shape[1] in [84, 96]: + self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) + + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ + H is height. + - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ + output width, H_ is output height. + """ + if self.downsample: + x = self.downsample_net(x) + else: + x = self.conv(x) + x = self.norm(x) + x = self.activation(x) + for block in self.resblocks: + x = block(x) + + # Important: Transform the output feature plane to the latent state. + # For example, for an Atari feature plane of shape (64, 8, 8), + # flattening results in a size of 4096, which is then transformed to 768. + x = self.last_linear(x.view(x.size(0), -1)) + + x = x.view(-1, self.embedding_dim) + + # NOTE: very important for training stability. + x = self.sim_norm(x) + + return x + + +class RepresentationNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType = (4, 96, 96), + num_res_blocks: int = 1, + num_channels: int = 64, + downsample: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: str = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + use_sim_norm: bool = False, + ) -> None: + """ + Overview: + Representation network used in MuZero and derived algorithms. Encode the 2D image obs into latent state. + Currently, the network only supports obs images with both a width and height of 96. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[4, 96, 96] + for video games like atari, 1 gray channel times stack 4 frames. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - num_channels (:obj:`int`): The channel of output hidden state. 
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ + Use the inplace operation to speed up. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - embedding_dim (:obj:`int`): The dimension of the output hidden state. + - group_size (:obj:`int`): The size of group in the SimNorm layer. + - use_sim_norm (:obj:`bool`): Whether to use SimNorm layer, defaults to False. + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + self.downsample = downsample + if self.downsample: + self.downsample_net = DownSample( + observation_shape, + num_channels, + activation=activation, + norm_type=norm_type, + ) + else: + self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) + + if norm_type == 'BN': + self.norm = nn.BatchNorm2d(num_channels) + elif norm_type == 'LN': + if downsample: + self.norm = nn.LayerNorm( + [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], + eps=1e-5) + else: + self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + self.activation = activation + + self.use_sim_norm = use_sim_norm + + if self.use_sim_norm: + self.embedding_dim = embedding_dim + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ + H is height. + - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ + output width, H_ is output height. + """ + if self.downsample: + x = self.downsample_net(x) + else: + x = self.conv(x) + x = self.norm(x) + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + + if self.use_sim_norm: + # NOTE: very important. + # for atari 64,8,8 = 4096 -> 768 + x = self.sim_norm(x) + + return x + + +class RepresentationNetworkMLP(nn.Module): + + def __init__( + self, + observation_shape: int, + hidden_channels: int = 64, + layer_num: int = 2, + activation: nn.Module = nn.GELU(approximate='tanh'), + norm_type: Optional[str] = 'BN', + group_size: int = 8, + ) -> torch.Tensor: + """ + Overview: + Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ + with Multi-Layer Perceptron (MLP). + Arguments: + - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - hidden_channels (:obj:`int`): The channel of output hidden state. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ + Use the inplace operation to speed up. 
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + """ + super().__init__() + self.fc_representation = MLP( + in_channels=observation_shape, + hidden_channels=hidden_channels, + out_channels=hidden_channels, + layer_num=layer_num, + activation=activation, + norm_type=norm_type, + # don't use activation and norm in the last layer of representation network is important for convergence. + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=True, + ) + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. + - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. + """ + x = self.fc_representation(x) + # TODO + x = self.sim_norm(x) + return x + + +class LatentDecoder(nn.Module): + + def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh')): + """ + Overview: + Decoder network used in UniZero. Decode the latent state into 2D image obs. + Arguments: + - embedding_dim (:obj:`int`): The dimension of the latent state. + - output_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel times stack 4 frames. + - num_channels (:obj:`int`): The channel of output hidden state. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). + """ + super().__init__() + self.embedding_dim = embedding_dim + self.output_shape = output_shape # (C, H, W) + self.num_channels = num_channels + self.activation = activation + + # Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256 + # We will reverse the process of the representation network + self.initial_size = ( + num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder + self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size)) + + # Upsampling blocks + self.conv_blocks = nn.ModuleList([ + # Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4) + nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1), + self.activation, + nn.BatchNorm2d(num_channels // 2), + # Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2) + nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, + output_padding=1), + self.activation, + nn.BatchNorm2d(num_channels // 4), + # Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W) + nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, + output_padding=1), + ]) + # TODO: last layer use sigmoid? 
+ + def forward(self, embeddings: torch.Tensor) -> torch.Tensor: + # Map embeddings back to the image space + x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8) + x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8) + + # Apply conv blocks + for block in self.conv_blocks: + x = block(x) # Upsample progressively + + # The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2]) + return x + + +class LatentEncoderForMemoryEnv(nn.Module): + + def __init__( + self, + image_shape=(3, 5, 5), + embedding_size=100, + channels=[16, 32, 64], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation: nn.Module = nn.GELU(approximate='tanh'), + normalize_pixel=False, + group_size: int = 8, + **kwargs, + ): + """ + Overview: + Encoder network used in UniZero in MemoryEnv. Encode the 2D image obs into latent state. + Arguments: + - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel times stack 4 frames. + - embedding_size (:obj:`int`): The dimension of the latent state. + - channels (:obj:`List[int]`): The channel of output hidden state. + - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. + - strides (:obj:`List[int]`): The stride of convolution layers. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). \ + Use the inplace operation to speed up. + - normalize_pixel (:obj:`bool`): Whether to normalize the pixel values to [0, 1], defaults to False. + - group_size (:obj:`int`): The dimension for simplicial normalization + """ + super(LatentEncoderForMemoryEnv, self).__init__() + self.shape = image_shape + self.channels = [image_shape[0]] + list(channels) + + layers = [] + for i in range(len(self.channels) - 1): + layers.append( + nn.Conv2d( + self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], + padding=kernel_sizes[i] // 2 # keep the same size of feature map + ) + ) + layers.append(nn.BatchNorm2d(self.channels[i + 1])) + layers.append(activation) + + layers.append(nn.AdaptiveAvgPool2d(1)) + + self.cnn = nn.Sequential(*layers) + self.linear = nn.Sequential( + nn.Linear(self.channels[-1], embedding_size, bias=False), + ) + init.kaiming_normal_(self.linear[0].weight, mode='fan_out', nonlinearity='relu') + + self.normalize_pixel = normalize_pixel + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, image): + if self.normalize_pixel: + image = image / 255.0 + x = self.cnn(image.float()) # (B, C, 1, 1) + x = torch.flatten(x, start_dim=1) # (B, C) + x = self.linear(x) # (B, embedding_size) + x = self.sim_norm(x) + return x + + +class LatentDecoderForMemoryEnv(nn.Module): + + def __init__( + self, + image_shape=(3, 5, 5), + embedding_size=256, + channels=[64, 32, 16], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), + **kwargs, + ): + """ + Overview: + Decoder network used in UniZero in MemoryEnv. Decode the latent state into 2D image obs. + Arguments: + - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel times stack 4 frames. + - embedding_size (:obj:`int`): The dimension of the latent state. + - channels (:obj:`List[int]`): The channel of output hidden state. + - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. 
+ - strides (:obj:`List[int]`): The stride of convolution layers. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.LeakyReLU(). \ + Use the inplace operation to speed up. + """ + super(LatentDecoderForMemoryEnv, self).__init__() + self.shape = image_shape + self.channels = list(channels) + [image_shape[0]] + + self.linear = nn.Linear(embedding_size, channels[0] * image_shape[1] * image_shape[2]) + + layers = [] + for i in range(len(self.channels) - 1): + layers.append( + nn.ConvTranspose2d( + self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], + padding=kernel_sizes[i] // 2, output_padding=strides[i] - 1 + ) + ) + if i < len(self.channels) - 2: + layers.append(nn.BatchNorm2d(self.channels[i + 1])) + layers.append(activation) + else: + layers.append(nn.Sigmoid()) + + self.deconv = nn.Sequential(*layers) + + def forward(self, embedding): + x = self.linear(embedding) + x = x.view(-1, self.channels[0], self.shape[1], self.shape[2]) + x = self.deconv(x) # (B, C, H, W) + return x + + +class VectorDecoderForMemoryEnv(nn.Module): + + def __init__( + self, + embedding_dim: int, + output_shape: SequenceType, + hidden_channels: int = 64, + layer_num: int = 2, + activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), # TODO + norm_type: Optional[str] = 'BN', + ) -> torch.Tensor: + """ + Overview: + Decoder network used in UniZero in MemoryEnv. Decode the latent state into vector obs. + Arguments: + - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - hidden_channels (:obj:`int`): The channel of output hidden state. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(). \ + Use the inplace operation to speed up. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + """ + super().__init__() + self.fc_representation = MLP( + in_channels=embedding_dim, + hidden_channels=hidden_channels, + out_channels=output_shape, + layer_num=layer_num, + activation=activation, + norm_type=norm_type, + # don't use activation and norm in the last layer of representation network is important for convergence. + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. + - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. 
+ """ + x = self.fc_representation(x) + return x + + +class PredictionNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType, + action_space_size: int, + num_res_blocks: int, + num_channels: int, + value_head_channels: int, + policy_head_channels: int, + fc_value_layers: int, + fc_policy_layers: int, + output_support_size: int, + flatten_output_size_for_value_head: int, + flatten_output_size_for_policy_head: int, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ) -> None: + """ + Overview: + The definition of policy and value prediction network, which is used to predict value and policy by the + given latent state. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of hidden states. + - value_head_channels (:obj:`int`): The channels of value head. + - policy_head_channels (:obj:`int`): The channels of policy head. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ + - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the value head. + - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the policy head. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
+ """ + super(PredictionNetwork, self).__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + + self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) + self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) + + if observation_shape[1] == 96: + latent_shape = (observation_shape[1] / 16, observation_shape[2] / 16) + elif observation_shape[1] == 64: + latent_shape = (observation_shape[1] / 8, observation_shape[2] / 8) + + if norm_type == 'BN': + self.norm_value = nn.BatchNorm2d(value_head_channels) + self.norm_policy = nn.BatchNorm2d(policy_head_channels) + elif norm_type == 'LN': + if downsample: + self.norm_value = nn.LayerNorm( + [value_head_channels, *latent_shape], + eps=1e-5) + self.norm_policy = nn.LayerNorm([policy_head_channels, *latent_shape], eps=1e-5) + else: + self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], + eps=1e-5) + self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], + eps=1e-5) + + self.flatten_output_size_for_value_head = flatten_output_size_for_value_head + self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head + + self.activation = activation + + self.fc_value = MLP( + in_channels=self.flatten_output_size_for_value_head, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.fc_policy = MLP( + in_channels=self.flatten_output_size_for_policy_head, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
+ """ + for res_block in self.resblocks: + latent_state = res_block(latent_state) + + value = self.conv1x1_value(latent_state) + value = self.norm_value(value) + value = self.activation(value) + + policy = self.conv1x1_policy(latent_state) + policy = self.norm_policy(policy) + policy = self.activation(policy) + + value = value.reshape(-1, self.flatten_output_size_for_value_head) + policy = policy.reshape(-1, self.flatten_output_size_for_policy_head) + + value = self.fc_value(value) + policy = self.fc_policy(policy) + return policy, value + + +class PredictionNetworkMLP(nn.Module): + + def __init__( + self, + action_space_size, + num_channels, + common_layer_num: int = 2, + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ): + """ + Overview: + The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), + which is used to predict value and policy by the given latent state. + Arguments: + - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ + space, it is the number of discrete actions. + - num_channels (:obj:`int`): The channels of latent states. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + """ + super().__init__() + self.num_channels = num_channels + + # ******* common backbone ****** + self.fc_prediction_common = MLP( + in_channels=self.num_channels, + hidden_channels=self.num_channels, + out_channels=self.num_channels, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # ******* value and policy head ****** + self.fc_value_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.fc_policy_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor): + """ + Overview: + Forward computation of the prediction network. 
+ Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). + """ + x_prediction_common = self.fc_prediction_common(latent_state) + + value = self.fc_value_head(x_prediction_common) + policy = self.fc_policy_head(x_prediction_common) + return policy, value + + +class PredictionHiddenNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType, + action_space_size: int, + num_res_blocks: int, + num_channels: int, + value_head_channels: int, + policy_head_channels: int, + fc_value_layers: int, + fc_policy_layers: int, + output_support_size: int, + flatten_output_size_for_value_head: int, + flatten_output_size_for_policy_head: int, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + gru_hidden_size: int = 512, + ) -> None: + """ + Overview: + The definition of policy and value prediction network, which is used to predict value and policy by the + given latent state. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of hidden states. + - value_head_channels (:obj:`int`): The channels of value head. + - policy_head_channels (:obj:`int`): The channels of policy head. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ + - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the value head. + - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the policy head. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
+ """ + super(PredictionHiddenNetwork, self).__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + self.observation_shape = observation_shape + self.gru_hidden_size = gru_hidden_size + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + + self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) + self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) + + if norm_type == 'BN': + self.norm_value = nn.BatchNorm2d(value_head_channels) + self.norm_policy = nn.BatchNorm2d(policy_head_channels) + elif norm_type == 'LN': + if downsample: + self.norm_value = nn.LayerNorm( + [value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], + eps=1e-5) + self.norm_policy = nn.LayerNorm([policy_head_channels, math.ceil(observation_shape[-2] / 16), + math.ceil(observation_shape[-1] / 16)], eps=1e-5) + else: + self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], + eps=1e-5) + self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], + eps=1e-5) + + self.flatten_output_size_for_value_head = flatten_output_size_for_value_head + self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head + + self.activation = activation + + self.fc_value = MLP( + in_channels=self.flatten_output_size_for_value_head + self.gru_hidden_size, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.fc_policy = MLP( + in_channels=self.flatten_output_size_for_policy_head + self.gru_hidden_size, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor, world_model_latent_history: torch.Tensor) -> Tuple[ + torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
+ """ + for res_block in self.resblocks: + latent_state = res_block(latent_state) + + value = self.conv1x1_value(latent_state) + value = self.norm_value(value) + value = self.activation(value) + + policy = self.conv1x1_policy(latent_state) + policy = self.norm_policy(policy) + policy = self.activation(policy) + + latent_state_value = value.reshape(-1, self.flatten_output_size_for_value_head) + latent_state_policy = policy.reshape(-1, self.flatten_output_size_for_policy_head) + + # TODO: world_model_latent_history.squeeze(0) shape: (num_layers * num_directions, batch_size, hidden_size) -> ( batch_size, hidden_size) + latent_history_value = torch.cat([latent_state_value, world_model_latent_history.squeeze(0)], dim=1) + latent_history_policy = torch.cat([latent_state_policy, world_model_latent_history.squeeze(0)], dim=1) + + value = self.fc_value(latent_history_value) + policy = self.fc_policy(latent_history_policy) + return policy, value + + diff --git a/lzero/model/common_serve_input_id.py b/lzero/model/common_serve_input_id.py new file mode 100644 index 000000000..c77d9cb3b --- /dev/null +++ b/lzero/model/common_serve_input_id.py @@ -0,0 +1,1247 @@ +""" +Overview: + In this Python file, we provide a collection of reusable model templates designed to streamline the development + process for various custom algorithms. By utilizing these pre-built model templates, users can quickly adapt and + customize their custom algorithms, ensuring efficient and effective development. + BTW, users can refer to the unittest of these model templates to learn how to use them. +""" +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from ding.torch_utils import MLP, ResBlock +from ding.utils import SequenceType +from ditk import logging +from openai import OpenAI +from transformers import AutoTokenizer + +# use dataclass to make the output of network more convenient to use +@dataclass +class MZRNNNetworkOutput: + # output format of the MuZeroRNN model + value: torch.Tensor + value_prefix: torch.Tensor + policy_logits: torch.Tensor + latent_state: torch.Tensor + predict_next_latent_state: torch.Tensor + reward_hidden_state: Tuple[torch.Tensor] + + +@dataclass +class EZNetworkOutput: + # output format of the EfficientZero model + value: torch.Tensor + value_prefix: torch.Tensor + policy_logits: torch.Tensor + latent_state: torch.Tensor + reward_hidden_state: Tuple[torch.Tensor] + + +@dataclass +class MZNetworkOutput: + # output format of the MuZero model + value: torch.Tensor + reward: torch.Tensor + policy_logits: torch.Tensor + latent_state: torch.Tensor + + +class SimNorm(nn.Module): + + def __init__(self, simnorm_dim: int) -> None: + """ + Overview: + Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616. + Arguments: + - simnorm_dim (:obj:`int`): The dimension for simplicial normalization. + """ + super().__init__() + self.dim = simnorm_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass of the SimNorm layer. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor to normalize. + Returns: + - x (:obj:`torch.Tensor`): The normalized tensor. + """ + shp = x.shape + # Ensure that there is at least one simplex to normalize across. 
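+        # The view below groups the last dimension into chunks of size `self.dim`,
+        # i.e. (B, D) -> (B, D / dim, dim); the softmax then turns each chunk into a
+        # point on a probability simplex before the tensor is flattened back to (B, D).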
+ if shp[1] != 0: + x = x.view(*shp[:-1], -1, self.dim) + x = F.softmax(x, dim=-1) + return x.view(*shp) + else: + return x + + def __repr__(self) -> str: + """ + Overview: + String representation of the SimNorm layer. + Returns: + - output (:obj:`str`): The string representation. + """ + return f"SimNorm(dim={self.dim})" + + +def AvgL1Norm(x, eps=1e-8): + """ + Overview: + Normalize the input tensor by the L1 norm. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor to normalize. + - eps (:obj:`float`): The epsilon value to prevent division by zero. + Returns: + - :obj:`torch.Tensor`: The normalized tensor. + """ + return x / x.abs().mean(-1, keepdim=True).clamp(min=eps) + + +class FeatureAndGradientHook: + + def __init__(self): + """ + Overview: + Class to capture features and gradients at SimNorm. + """ + self.features_before = [] + self.features_after = [] + self.grads_before = [] + self.grads_after = [] + + def setup_hooks(self, model): + # Hooks to capture features and gradients at SimNorm + self.forward_handler = model.sim_norm.register_forward_hook(self.forward_hook) + self.backward_handler = model.sim_norm.register_full_backward_hook(self.backward_hook) + + def forward_hook(self, module, input, output): + with torch.no_grad(): + self.features_before.append(input[0]) + self.features_after.append(output) + + def backward_hook(self, module, grad_input, grad_output): + with torch.no_grad(): + self.grads_before.append(grad_input[0] if grad_input[0] is not None else None) + self.grads_after.append(grad_output[0] if grad_output[0] is not None else None) + + def analyze(self): + # Calculate L2 norms of features + l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_before])) + l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_after])) + + # Calculate norms of gradients + grad_norm_before = torch.mean( + torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before if g is not None])) + grad_norm_after = torch.mean( + torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after if g is not None])) + + # Clear stored data and delete tensors to free memory + self.clear_data() + + # Optionally clear CUDA cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after + + def clear_data(self): + del self.features_before[:] + del self.features_after[:] + del self.grads_before[:] + del self.grads_after[:] + + def remove_hooks(self): + self.forward_handler.remove() + self.backward_handler.remove() + + +class DownSample(nn.Module): + + def __init__(self, observation_shape: SequenceType, out_channels: int, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + num_resblocks: int = 1, + ) -> None: + """ + Overview: + Define downSample convolution network. Encode the observation into hidden state. + This network is often used in video games like Atari. In board games like go and chess, + we don't need this module. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[12, 96, 96] + for video games like atari, RGB 3 channel times stack 4 frames. + - out_channels (:obj:`int`): The output channels of output hidden state. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ + Use the inplace operation to speed up. 
+ - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'. + - num_resblocks (:obj:`int`): The number of residual blocks. Defaults to 1. + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + self.observation_shape = observation_shape + self.conv1 = nn.Conv2d( + observation_shape[0], + out_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False, # disable bias for better convergence + ) + if norm_type == 'BN': + self.norm1 = nn.BatchNorm2d(out_channels // 2) + elif norm_type == 'LN': + self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2], + eps=1e-5) + + self.resblocks1 = nn.ModuleList( + [ + ResBlock( + in_channels=out_channels // 2, + activation=activation, + norm_type=norm_type, + res_type='basic', + bias=False + ) for _ in range(num_resblocks) + ] + ) + self.downsample_block = ResBlock( + in_channels=out_channels // 2, + out_channels=out_channels, + activation=activation, + norm_type=norm_type, + res_type='downsample', + bias=False + ) + self.resblocks2 = nn.ModuleList( + [ + ResBlock( + in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_resblocks) + ] + ) + self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + self.resblocks3 = nn.ModuleList( + [ + ResBlock( + in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(1) + ] + ) + self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ + H is height. + - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ + output width, H_ is output height. + """ + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + + for block in self.resblocks1: + x = block(x) + x = self.downsample_block(x) + for block in self.resblocks2: + x = block(x) + x = self.pooling1(x) + for block in self.resblocks3: + x = block(x) + + # 64, 84, 96 are the most common observation shapes in Atari games. + if self.observation_shape[1] == 64: + output = x + elif self.observation_shape[1] == 84: + x = self.pooling2(x) + output = x + elif self.observation_shape[1] == 96: + x = self.pooling2(x) + output = x + else: + raise NotImplementedError(f"DownSample for observation shape {self.observation_shape} is not implemented now. 
" + f"You should transform the observation shape to 64 or 96 in the env.") + + return output +""" +使用vllm的BAAI/bge-base-en-v1.5 server, 模型的输入为id 需要先decode回string +""" +class HFLanguageRepresentationNetwork(nn.Module): + def __init__( + self, + url: str = 'BAAI/bge-base-en-v1.5', + embedding_size: int = 768, + group_size: int = 8, + api_base: str = "http://10.119.30.189:8081/v1", + api_key: str = "EMPTY" + ): + """ + 初始化语言表示网络,使用 vLLM 的 API 服务获取嵌入。 + + 参数: + - url (str): vLLM 服务的模型名称,默认为 'BAAI/bge-base-en-v1.5'。 + - embedding_size (int): 输出嵌入的维度大小,默认为 768。 + - group_size (int): SimNorm 的组大小,默认为 8。 + - api_base (str): vLLM API 服务器的基本 URL,默认为 "http://10.119.30.189:8081/v1"。 + - api_key (str): API 密钥,默认值为 "EMPTY"。 + """ + super().__init__() + self.url = url + self.embedding_size = embedding_size + self.api_base = api_base + self.api_key = api_key + + # 初始化 OpenAI 客户端以连接 vLLM 的 API 服务器 + self.client = OpenAI( + api_key=api_key, + base_url=api_base, + ) + + # 获取模型 ID + models = self.client.models.list() + self.model_id = models.data[0].id if models.data else url + + # 初始化线性变换层(如果需要) + if self.embedding_size != 768: + self.embed_head = nn.Linear(768, self.embedding_size) + else: + self.embed_head = None + + # 初始化 SimNorm + self.sim_norm = SimNorm(simnorm_dim=group_size) + + # 初始化分词器,用于将 token 索引解码为字符串 + self.tokenizer = AutoTokenizer.from_pretrained(url) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + 前向传播,获取输入序列的语言表示。 + + 参数: + - x (torch.Tensor): 输入的张量,形状为 [batch_size, seq_len],类型为 torch.long。 + + 返回: + - torch.Tensor: 经过处理的语言嵌入向量,形状为 [batch_size, embedding_size]。 + """ + with torch.no_grad(): + x = x.long() # 确保输入张量为长整型 + + # # 检查索引范围 + # min_idx = x.min().item() + # max_idx = x.max().item() + # # print(f"min_idx: {min_idx}, max_idx: {max_idx}") + # assert min_idx >= 0, "Negative token indices found." + # assert max_idx < self.tokenizer.vocab_size, f"Token index {max_idx} exceeds vocab size {self.tokenizer.vocab_size}." 
+
+            # Decode token indices back into strings.
+            # Assumes the [CLS] token of each sample is at position 0;
+            # adjust as needed for the actual tokenization scheme.
+            batch_size = x.size(0)
+            sentences = []
+            for i in range(batch_size):
+                # Decode to a string
+                tokens = x[i].tolist()
+                sentence = self.tokenizer.decode(tokens, skip_special_tokens=True)
+                sentences.append(sentence)
+
+            # Call the vLLM embeddings API
+            response = self.client.embeddings.create(
+                input=sentences,
+                model=self.model_id,
+            )
+
+            # Extract the embeddings and convert them to a tensor
+            embeddings = [data.embedding for data in response.data]  # List[List[float]]
+            embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32, device=x.device)  # [batch_size, 768]
+
+            # Project up or down to the target dimension if needed
+            if self.embed_head is not None:
+                embeddings_tensor = self.embed_head(embeddings_tensor)  # [batch_size, embedding_size]
+
+            # Apply SimNorm
+            embeddings_tensor = self.sim_norm(embeddings_tensor)  # [batch_size, embedding_size]
+
+            return embeddings_tensor
+
+    def __getstate__(self):
+        state = self.__dict__.copy()
+        # Remove objects that cannot be serialized
+        del state['client']
+        del state['tokenizer']
+        return state
+
+    def __setstate__(self, state):
+        self.__dict__.update(state)
+        # Re-initialize the objects that cannot be serialized
+        self.client = OpenAI(
+            api_key=self.api_key,
+            base_url=self.api_base,
+        )
+        self.tokenizer = AutoTokenizer.from_pretrained(self.url)
+
+
+class RepresentationNetworkUniZero(nn.Module):
+
+    def __init__(
+        self,
+        observation_shape: SequenceType = (3, 64, 64),
+        num_res_blocks: int = 1,
+        num_channels: int = 64,
+        downsample: bool = True,
+        activation: nn.Module = nn.GELU(approximate='tanh'),
+        norm_type: str = 'BN',
+        embedding_dim: int = 256,
+        group_size: int = 8,
+    ) -> None:
+        """
+        Overview:
+            Representation network used in UniZero. Encode the 2D image obs into latent state.
+            Currently, the network only supports obs images with both a width and height of 64.
+        Arguments:
+            - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+                for video games like atari, RGB 3 channel.
+            - num_res_blocks (:obj:`int`): The number of residual blocks.
+            - num_channels (:obj:`int`): The channel of output hidden state.
+            - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
+                defaults to True. This option is often used in video games like Atari. In board games like go, \
+                we don't need this module.
+            - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \
+                Use the inplace operation to speed up.
+            - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+            - embedding_dim (:obj:`int`): The dimension of the latent state.
+            - group_size (:obj:`int`): The dimension for simplicial normalization.
+ """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + logging.info(f"Using norm type: {norm_type}") + logging.info(f"Using activation type: {activation}") + + self.observation_shape = observation_shape + self.downsample = downsample + if self.downsample: + self.downsample_net = DownSample( + observation_shape, + num_channels, + activation=activation, + norm_type=norm_type, + ) + else: + self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) + + if norm_type == 'BN': + self.norm = nn.BatchNorm2d(num_channels) + elif norm_type == 'LN': + if downsample: + self.norm = nn.LayerNorm( + [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], + eps=1e-5) + else: + self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + self.activation = activation + self.embedding_dim = embedding_dim + + if self.observation_shape[1] == 64: + self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) + + elif self.observation_shape[1] in [84, 96]: + self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) + + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ + H is height. + - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ + output width, H_ is output height. + """ + if self.downsample: + x = self.downsample_net(x) + else: + x = self.conv(x) + x = self.norm(x) + x = self.activation(x) + for block in self.resblocks: + x = block(x) + + # Important: Transform the output feature plane to the latent state. + # For example, for an Atari feature plane of shape (64, 8, 8), + # flattening results in a size of 4096, which is then transformed to 768. + x = self.last_linear(x.view(x.size(0), -1)) + + x = x.view(-1, self.embedding_dim) + + # NOTE: very important for training stability. + x = self.sim_norm(x) + + return x + + +class RepresentationNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType = (4, 96, 96), + num_res_blocks: int = 1, + num_channels: int = 64, + downsample: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: str = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + use_sim_norm: bool = False, + ) -> None: + """ + Overview: + Representation network used in MuZero and derived algorithms. Encode the 2D image obs into latent state. + Currently, the network only supports obs images with both a width and height of 96. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[4, 96, 96] + for video games like atari, 1 gray channel times stack 4 frames. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - num_channels (:obj:`int`): The channel of output hidden state. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. 
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ + Use the inplace operation to speed up. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - embedding_dim (:obj:`int`): The dimension of the output hidden state. + - group_size (:obj:`int`): The size of group in the SimNorm layer. + - use_sim_norm (:obj:`bool`): Whether to use SimNorm layer, defaults to False. + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + self.downsample = downsample + if self.downsample: + self.downsample_net = DownSample( + observation_shape, + num_channels, + activation=activation, + norm_type=norm_type, + ) + else: + self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) + + if norm_type == 'BN': + self.norm = nn.BatchNorm2d(num_channels) + elif norm_type == 'LN': + if downsample: + self.norm = nn.LayerNorm( + [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], + eps=1e-5) + else: + self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + self.activation = activation + + self.use_sim_norm = use_sim_norm + + if self.use_sim_norm: + self.embedding_dim = embedding_dim + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ + H is height. + - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ + output width, H_ is output height. + """ + if self.downsample: + x = self.downsample_net(x) + else: + x = self.conv(x) + x = self.norm(x) + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + + if self.use_sim_norm: + # NOTE: very important. + # for atari 64,8,8 = 4096 -> 768 + x = self.sim_norm(x) + + return x + + +class RepresentationNetworkMLP(nn.Module): + + def __init__( + self, + observation_shape: int, + hidden_channels: int = 64, + layer_num: int = 2, + activation: nn.Module = nn.GELU(approximate='tanh'), + norm_type: Optional[str] = 'BN', + group_size: int = 8, + ) -> torch.Tensor: + """ + Overview: + Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ + with Multi-Layer Perceptron (MLP). + Arguments: + - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - hidden_channels (:obj:`int`): The channel of output hidden state. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ + Use the inplace operation to speed up. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
+ """ + super().__init__() + self.fc_representation = MLP( + in_channels=observation_shape, + hidden_channels=hidden_channels, + out_channels=hidden_channels, + layer_num=layer_num, + activation=activation, + norm_type=norm_type, + # don't use activation and norm in the last layer of representation network is important for convergence. + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=True, + ) + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. + - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. + """ + x = self.fc_representation(x) + # TODO + x = self.sim_norm(x) + return x + + +class LatentDecoder(nn.Module): + + def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh')): + """ + Overview: + Decoder network used in UniZero. Decode the latent state into 2D image obs. + Arguments: + - embedding_dim (:obj:`int`): The dimension of the latent state. + - output_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel times stack 4 frames. + - num_channels (:obj:`int`): The channel of output hidden state. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). + """ + super().__init__() + self.embedding_dim = embedding_dim + self.output_shape = output_shape # (C, H, W) + self.num_channels = num_channels + self.activation = activation + + # Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256 + # We will reverse the process of the representation network + self.initial_size = ( + num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder + self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size)) + + # Upsampling blocks + self.conv_blocks = nn.ModuleList([ + # Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4) + nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1), + self.activation, + nn.BatchNorm2d(num_channels // 2), + # Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2) + nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, + output_padding=1), + self.activation, + nn.BatchNorm2d(num_channels // 4), + # Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W) + nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, + output_padding=1), + ]) + # TODO: last layer use sigmoid? 
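+        # One option for the TODO above: appending nn.Sigmoid() (as LatentDecoderForMemoryEnv
+        # does below) would constrain the reconstructed observation to [0, 1], which is
+        # appropriate only when the target pixels are normalized to that range.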
+ + def forward(self, embeddings: torch.Tensor) -> torch.Tensor: + # Map embeddings back to the image space + x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8) + x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8) + + # Apply conv blocks + for block in self.conv_blocks: + x = block(x) # Upsample progressively + + # The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2]) + return x + + +class LatentEncoderForMemoryEnv(nn.Module): + + def __init__( + self, + image_shape=(3, 5, 5), + embedding_size=100, + channels=[16, 32, 64], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation: nn.Module = nn.GELU(approximate='tanh'), + normalize_pixel=False, + group_size: int = 8, + **kwargs, + ): + """ + Overview: + Encoder network used in UniZero in MemoryEnv. Encode the 2D image obs into latent state. + Arguments: + - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel times stack 4 frames. + - embedding_size (:obj:`int`): The dimension of the latent state. + - channels (:obj:`List[int]`): The channel of output hidden state. + - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. + - strides (:obj:`List[int]`): The stride of convolution layers. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). \ + Use the inplace operation to speed up. + - normalize_pixel (:obj:`bool`): Whether to normalize the pixel values to [0, 1], defaults to False. + - group_size (:obj:`int`): The dimension for simplicial normalization + """ + super(LatentEncoderForMemoryEnv, self).__init__() + self.shape = image_shape + self.channels = [image_shape[0]] + list(channels) + + layers = [] + for i in range(len(self.channels) - 1): + layers.append( + nn.Conv2d( + self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], + padding=kernel_sizes[i] // 2 # keep the same size of feature map + ) + ) + layers.append(nn.BatchNorm2d(self.channels[i + 1])) + layers.append(activation) + + layers.append(nn.AdaptiveAvgPool2d(1)) + + self.cnn = nn.Sequential(*layers) + self.linear = nn.Sequential( + nn.Linear(self.channels[-1], embedding_size, bias=False), + ) + init.kaiming_normal_(self.linear[0].weight, mode='fan_out', nonlinearity='relu') + + self.normalize_pixel = normalize_pixel + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, image): + if self.normalize_pixel: + image = image / 255.0 + x = self.cnn(image.float()) # (B, C, 1, 1) + x = torch.flatten(x, start_dim=1) # (B, C) + x = self.linear(x) # (B, embedding_size) + x = self.sim_norm(x) + return x + + +class LatentDecoderForMemoryEnv(nn.Module): + + def __init__( + self, + image_shape=(3, 5, 5), + embedding_size=256, + channels=[64, 32, 16], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), + **kwargs, + ): + """ + Overview: + Decoder network used in UniZero in MemoryEnv. Decode the latent state into 2D image obs. + Arguments: + - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel times stack 4 frames. + - embedding_size (:obj:`int`): The dimension of the latent state. + - channels (:obj:`List[int]`): The channel of output hidden state. + - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. 
+ - strides (:obj:`List[int]`): The stride of convolution layers. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.LeakyReLU(). \ + Use the inplace operation to speed up. + """ + super(LatentDecoderForMemoryEnv, self).__init__() + self.shape = image_shape + self.channels = list(channels) + [image_shape[0]] + + self.linear = nn.Linear(embedding_size, channels[0] * image_shape[1] * image_shape[2]) + + layers = [] + for i in range(len(self.channels) - 1): + layers.append( + nn.ConvTranspose2d( + self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], + padding=kernel_sizes[i] // 2, output_padding=strides[i] - 1 + ) + ) + if i < len(self.channels) - 2: + layers.append(nn.BatchNorm2d(self.channels[i + 1])) + layers.append(activation) + else: + layers.append(nn.Sigmoid()) + + self.deconv = nn.Sequential(*layers) + + def forward(self, embedding): + x = self.linear(embedding) + x = x.view(-1, self.channels[0], self.shape[1], self.shape[2]) + x = self.deconv(x) # (B, C, H, W) + return x + + +class VectorDecoderForMemoryEnv(nn.Module): + + def __init__( + self, + embedding_dim: int, + output_shape: SequenceType, + hidden_channels: int = 64, + layer_num: int = 2, + activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), # TODO + norm_type: Optional[str] = 'BN', + ) -> torch.Tensor: + """ + Overview: + Decoder network used in UniZero in MemoryEnv. Decode the latent state into vector obs. + Arguments: + - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - hidden_channels (:obj:`int`): The channel of output hidden state. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(). \ + Use the inplace operation to speed up. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + """ + super().__init__() + self.fc_representation = MLP( + in_channels=embedding_dim, + hidden_channels=hidden_channels, + out_channels=output_shape, + layer_num=layer_num, + activation=activation, + norm_type=norm_type, + # don't use activation and norm in the last layer of representation network is important for convergence. + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. + - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. 
+ """ + x = self.fc_representation(x) + return x + + +class PredictionNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType, + action_space_size: int, + num_res_blocks: int, + num_channels: int, + value_head_channels: int, + policy_head_channels: int, + fc_value_layers: int, + fc_policy_layers: int, + output_support_size: int, + flatten_output_size_for_value_head: int, + flatten_output_size_for_policy_head: int, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ) -> None: + """ + Overview: + The definition of policy and value prediction network, which is used to predict value and policy by the + given latent state. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of hidden states. + - value_head_channels (:obj:`int`): The channels of value head. + - policy_head_channels (:obj:`int`): The channels of policy head. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ + - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the value head. + - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the policy head. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
+        """
+        super(PredictionNetwork, self).__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']"
+
+        self.resblocks = nn.ModuleList(
+            [
+                ResBlock(
+                    in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+                ) for _ in range(num_res_blocks)
+            ]
+        )
+
+        self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
+        self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
+
+        # NOTE: use integer division here; nn.LayerNorm expects integer dimensions in its normalized_shape.
+        if observation_shape[1] == 96:
+            latent_shape = (observation_shape[1] // 16, observation_shape[2] // 16)
+        elif observation_shape[1] == 64:
+            latent_shape = (observation_shape[1] // 8, observation_shape[2] // 8)
+
+        if norm_type == 'BN':
+            self.norm_value = nn.BatchNorm2d(value_head_channels)
+            self.norm_policy = nn.BatchNorm2d(policy_head_channels)
+        elif norm_type == 'LN':
+            if downsample:
+                self.norm_value = nn.LayerNorm(
+                    [value_head_channels, *latent_shape],
+                    eps=1e-5)
+                self.norm_policy = nn.LayerNorm([policy_head_channels, *latent_shape], eps=1e-5)
+            else:
+                self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]],
+                                               eps=1e-5)
+                self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]],
+                                                eps=1e-5)
+
+        self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
+        self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
+
+        self.activation = activation
+
+        self.fc_value = MLP(
+            in_channels=self.flatten_output_size_for_value_head,
+            hidden_channels=fc_value_layers[0],
+            out_channels=output_support_size,
+            layer_num=len(fc_value_layers) + 1,
+            activation=self.activation,
+            norm_type=norm_type,
+            output_activation=False,
+            output_norm=False,
+            # last_linear_layer_init_zero=True is beneficial for convergence speed.
+            last_linear_layer_init_zero=last_linear_layer_init_zero
+        )
+        self.fc_policy = MLP(
+            in_channels=self.flatten_output_size_for_policy_head,
+            hidden_channels=fc_policy_layers[0],
+            out_channels=action_space_size,
+            layer_num=len(fc_policy_layers) + 1,
+            activation=self.activation,
+            norm_type=norm_type,
+            output_activation=False,
+            output_norm=False,
+            # last_linear_layer_init_zero=True is beneficial for convergence speed.
+            last_linear_layer_init_zero=last_linear_layer_init_zero
+        )
+
+    def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Overview:
+            Forward computation of the prediction network.
+        Arguments:
+            - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim).
+        Returns:
+            - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
+            - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
+ """ + for res_block in self.resblocks: + latent_state = res_block(latent_state) + + value = self.conv1x1_value(latent_state) + value = self.norm_value(value) + value = self.activation(value) + + policy = self.conv1x1_policy(latent_state) + policy = self.norm_policy(policy) + policy = self.activation(policy) + + value = value.reshape(-1, self.flatten_output_size_for_value_head) + policy = policy.reshape(-1, self.flatten_output_size_for_policy_head) + + value = self.fc_value(value) + policy = self.fc_policy(policy) + return policy, value + + +class PredictionNetworkMLP(nn.Module): + + def __init__( + self, + action_space_size, + num_channels, + common_layer_num: int = 2, + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ): + """ + Overview: + The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), + which is used to predict value and policy by the given latent state. + Arguments: + - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ + space, it is the number of discrete actions. + - num_channels (:obj:`int`): The channels of latent states. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + """ + super().__init__() + self.num_channels = num_channels + + # ******* common backbone ****** + self.fc_prediction_common = MLP( + in_channels=self.num_channels, + hidden_channels=self.num_channels, + out_channels=self.num_channels, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # ******* value and policy head ****** + self.fc_value_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.fc_policy_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor): + """ + Overview: + Forward computation of the prediction network. 
+ Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). + """ + x_prediction_common = self.fc_prediction_common(latent_state) + + value = self.fc_value_head(x_prediction_common) + policy = self.fc_policy_head(x_prediction_common) + return policy, value + + +class PredictionHiddenNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType, + action_space_size: int, + num_res_blocks: int, + num_channels: int, + value_head_channels: int, + policy_head_channels: int, + fc_value_layers: int, + fc_policy_layers: int, + output_support_size: int, + flatten_output_size_for_value_head: int, + flatten_output_size_for_policy_head: int, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + gru_hidden_size: int = 512, + ) -> None: + """ + Overview: + The definition of policy and value prediction network, which is used to predict value and policy by the + given latent state. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of hidden states. + - value_head_channels (:obj:`int`): The channels of value head. + - policy_head_channels (:obj:`int`): The channels of policy head. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ + - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the value head. + - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the policy head. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
+ """ + super(PredictionHiddenNetwork, self).__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + self.observation_shape = observation_shape + self.gru_hidden_size = gru_hidden_size + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + + self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) + self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) + + if norm_type == 'BN': + self.norm_value = nn.BatchNorm2d(value_head_channels) + self.norm_policy = nn.BatchNorm2d(policy_head_channels) + elif norm_type == 'LN': + if downsample: + self.norm_value = nn.LayerNorm( + [value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], + eps=1e-5) + self.norm_policy = nn.LayerNorm([policy_head_channels, math.ceil(observation_shape[-2] / 16), + math.ceil(observation_shape[-1] / 16)], eps=1e-5) + else: + self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], + eps=1e-5) + self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], + eps=1e-5) + + self.flatten_output_size_for_value_head = flatten_output_size_for_value_head + self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head + + self.activation = activation + + self.fc_value = MLP( + in_channels=self.flatten_output_size_for_value_head + self.gru_hidden_size, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.fc_policy = MLP( + in_channels=self.flatten_output_size_for_policy_head + self.gru_hidden_size, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor, world_model_latent_history: torch.Tensor) -> Tuple[ + torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
+ """ + for res_block in self.resblocks: + latent_state = res_block(latent_state) + + value = self.conv1x1_value(latent_state) + value = self.norm_value(value) + value = self.activation(value) + + policy = self.conv1x1_policy(latent_state) + policy = self.norm_policy(policy) + policy = self.activation(policy) + + latent_state_value = value.reshape(-1, self.flatten_output_size_for_value_head) + latent_state_policy = policy.reshape(-1, self.flatten_output_size_for_policy_head) + + # TODO: world_model_latent_history.squeeze(0) shape: (num_layers * num_directions, batch_size, hidden_size) -> ( batch_size, hidden_size) + latent_history_value = torch.cat([latent_state_value, world_model_latent_history.squeeze(0)], dim=1) + latent_history_policy = torch.cat([latent_state_policy, world_model_latent_history.squeeze(0)], dim=1) + + value = self.fc_value(latent_history_value) + policy = self.fc_policy(latent_history_policy) + return policy, value + + diff --git a/lzero/model/common_serve_input_str.py b/lzero/model/common_serve_input_str.py new file mode 100644 index 000000000..c55a8d8e5 --- /dev/null +++ b/lzero/model/common_serve_input_str.py @@ -0,0 +1,1235 @@ +""" +Overview: + In this Python file, we provide a collection of reusable model templates designed to streamline the development + process for various custom algorithms. By utilizing these pre-built model templates, users can quickly adapt and + customize their custom algorithms, ensuring efficient and effective development. + BTW, users can refer to the unittest of these model templates to learn how to use them. +""" +import math +from dataclasses import dataclass +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.nn.init as init +from ding.torch_utils import MLP, ResBlock +from ding.utils import SequenceType +from ditk import logging +from openai import OpenAI +from transformers import AutoTokenizer + + +# use dataclass to make the output of network more convenient to use +@dataclass +class MZRNNNetworkOutput: + # output format of the MuZeroRNN model + value: torch.Tensor + value_prefix: torch.Tensor + policy_logits: torch.Tensor + latent_state: torch.Tensor + predict_next_latent_state: torch.Tensor + reward_hidden_state: Tuple[torch.Tensor] + + +@dataclass +class EZNetworkOutput: + # output format of the EfficientZero model + value: torch.Tensor + value_prefix: torch.Tensor + policy_logits: torch.Tensor + latent_state: torch.Tensor + reward_hidden_state: Tuple[torch.Tensor] + + +@dataclass +class MZNetworkOutput: + # output format of the MuZero model + value: torch.Tensor + reward: torch.Tensor + policy_logits: torch.Tensor + latent_state: torch.Tensor + + +class SimNorm(nn.Module): + + def __init__(self, simnorm_dim: int) -> None: + """ + Overview: + Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616. + Arguments: + - simnorm_dim (:obj:`int`): The dimension for simplicial normalization. + """ + super().__init__() + self.dim = simnorm_dim + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Forward pass of the SimNorm layer. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor to normalize. + Returns: + - x (:obj:`torch.Tensor`): The normalized tensor. + """ + shp = x.shape + # Ensure that there is at least one simplex to normalize across. 
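+        # The view below groups the last dimension into chunks of size `self.dim`,
+        # i.e. (B, D) -> (B, D / dim, dim); the softmax then turns each chunk into a
+        # point on a probability simplex before the tensor is flattened back to (B, D).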
+ if shp[1] != 0: + x = x.view(*shp[:-1], -1, self.dim) + x = F.softmax(x, dim=-1) + return x.view(*shp) + else: + return x + + def __repr__(self) -> str: + """ + Overview: + String representation of the SimNorm layer. + Returns: + - output (:obj:`str`): The string representation. + """ + return f"SimNorm(dim={self.dim})" + + +def AvgL1Norm(x, eps=1e-8): + """ + Overview: + Normalize the input tensor by the L1 norm. + Arguments: + - x (:obj:`torch.Tensor`): The input tensor to normalize. + - eps (:obj:`float`): The epsilon value to prevent division by zero. + Returns: + - :obj:`torch.Tensor`: The normalized tensor. + """ + return x / x.abs().mean(-1, keepdim=True).clamp(min=eps) + + +class FeatureAndGradientHook: + + def __init__(self): + """ + Overview: + Class to capture features and gradients at SimNorm. + """ + self.features_before = [] + self.features_after = [] + self.grads_before = [] + self.grads_after = [] + + def setup_hooks(self, model): + # Hooks to capture features and gradients at SimNorm + self.forward_handler = model.sim_norm.register_forward_hook(self.forward_hook) + self.backward_handler = model.sim_norm.register_full_backward_hook(self.backward_hook) + + def forward_hook(self, module, input, output): + with torch.no_grad(): + self.features_before.append(input[0]) + self.features_after.append(output) + + def backward_hook(self, module, grad_input, grad_output): + with torch.no_grad(): + self.grads_before.append(grad_input[0] if grad_input[0] is not None else None) + self.grads_after.append(grad_output[0] if grad_output[0] is not None else None) + + def analyze(self): + # Calculate L2 norms of features + l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_before])) + l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_after])) + + # Calculate norms of gradients + grad_norm_before = torch.mean( + torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before if g is not None])) + grad_norm_after = torch.mean( + torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after if g is not None])) + + # Clear stored data and delete tensors to free memory + self.clear_data() + + # Optionally clear CUDA cache + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after + + def clear_data(self): + del self.features_before[:] + del self.features_after[:] + del self.grads_before[:] + del self.grads_after[:] + + def remove_hooks(self): + self.forward_handler.remove() + self.backward_handler.remove() + + +class DownSample(nn.Module): + + def __init__(self, observation_shape: SequenceType, out_channels: int, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + num_resblocks: int = 1, + ) -> None: + """ + Overview: + Define downSample convolution network. Encode the observation into hidden state. + This network is often used in video games like Atari. In board games like go and chess, + we don't need this module. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[12, 96, 96] + for video games like atari, RGB 3 channel times stack 4 frames. + - out_channels (:obj:`int`): The output channels of output hidden state. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ + Use the inplace operation to speed up. 
+ - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'. + - num_resblocks (:obj:`int`): The number of residual blocks. Defaults to 1. + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + self.observation_shape = observation_shape + self.conv1 = nn.Conv2d( + observation_shape[0], + out_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False, # disable bias for better convergence + ) + if norm_type == 'BN': + self.norm1 = nn.BatchNorm2d(out_channels // 2) + elif norm_type == 'LN': + self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2], + eps=1e-5) + + self.resblocks1 = nn.ModuleList( + [ + ResBlock( + in_channels=out_channels // 2, + activation=activation, + norm_type=norm_type, + res_type='basic', + bias=False + ) for _ in range(num_resblocks) + ] + ) + self.downsample_block = ResBlock( + in_channels=out_channels // 2, + out_channels=out_channels, + activation=activation, + norm_type=norm_type, + res_type='downsample', + bias=False + ) + self.resblocks2 = nn.ModuleList( + [ + ResBlock( + in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_resblocks) + ] + ) + self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + self.resblocks3 = nn.ModuleList( + [ + ResBlock( + in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(1) + ] + ) + self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) + self.activation = activation + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ + H is height. + - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ + output width, H_ is output height. + """ + x = self.conv1(x) + x = self.norm1(x) + x = self.activation(x) + + for block in self.resblocks1: + x = block(x) + x = self.downsample_block(x) + for block in self.resblocks2: + x = block(x) + x = self.pooling1(x) + for block in self.resblocks3: + x = block(x) + + # 64, 84, 96 are the most common observation shapes in Atari games. + if self.observation_shape[1] == 64: + output = x + elif self.observation_shape[1] == 84: + x = self.pooling2(x) + output = x + elif self.observation_shape[1] == 96: + x = self.pooling2(x) + output = x + else: + raise NotImplementedError(f"DownSample for observation shape {self.observation_shape} is not implemented now. 
" + f"You should transform the observation shape to 64 or 96 in the env.") + + return output + +""" +使用vllm的BAAI/bge-base-en-v1.5 server, 模型的输入为string +""" +class HFLanguageRepresentationNetwork(nn.Module): + def __init__( + self, + url: str = 'BAAI/bge-base-en-v1.5', + embedding_size: int = 768, + group_size: int = 8, + api_base: str = "http://10.119.30.189:8081/v1", + api_key: str = "EMPTY" + ): + """ + 初始化语言表示网络,使用 vLLM 的 API 服务获取嵌入。 + + 参数: + - url (str): vLLM 服务的模型名称,默认为 'BAAI/bge-base-en-v1.5'。 + - embedding_size (int): 输出嵌入的维度大小,默认为 768。 + - group_size (int): SimNorm 的组大小,默认为 8。 + - api_base (str): vLLM API 服务器的基本 URL,默认为 "http://10.119.30.189:8081/v1"。 + - api_key (str): API 密钥,默认值为 "EMPTY"。 + """ + super().__init__() + self.url = url + self.embedding_size = embedding_size + self.api_base = api_base + self.api_key = api_key + + # 初始化 OpenAI 客户端以连接 vLLM 的 API 服务器 + self.client = OpenAI( + api_key=api_key, + base_url=api_base, + ) + + # 获取模型 ID + models = self.client.models.list() + self.model_id = models.data[0].id if models.data else url + + # 初始化线性变换层(如果需要) + if self.embedding_size != 768: + self.embed_head = nn.Linear(768, self.embedding_size) + else: + self.embed_head = None + + # 初始化 SimNorm + self.sim_norm = SimNorm(simnorm_dim=group_size) + + # 初始化分词器,用于将字符串转为 token + self.tokenizer = AutoTokenizer.from_pretrained(url) + + def forward(self, x: List[str]) -> torch.Tensor: + """ + 前向传播,获取输入序列的语言表示。 + + 参数: + - x (List[str]): 输入的字符串列表,长度为 batch_size。 + + 返回: + - torch.Tensor: 经过处理的语言嵌入向量,形状为 [batch_size, embedding_size]。 + """ + with torch.no_grad(): + # 分词 + encoded_inputs = self.tokenizer( + x, truncation=True, padding="max_length", max_length=self.max_seq_len, return_tensors='pt' + ) + input_texts = encoded_inputs['input_ids'].tolist() + + # 调用 vLLM 的嵌入 API + response = self.client.embeddings.create( + input=x, + model=self.model_id, + ) + + # 提取嵌入并转换为张量 + embeddings = [data.embedding for data in response.data] # List[List[float]] + embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32, device=self.device) # [batch_size, 768] + + # 如果需要降维或升维 + if self.embed_head is not None: + embeddings_tensor = self.embed_head(embeddings_tensor) # [batch_size, embedding_size] + + # 应用 SimNorm + embeddings_tensor = self.sim_norm(embeddings_tensor) # [batch_size, embedding_size] + + return embeddings_tensor + + def __getstate__(self): + state = self.__dict__.copy() + # 移除无法序列化的对象 + del state['client'] + del state['tokenizer'] + return state + + def __setstate__(self, state): + self.__dict__.update(state) + # 重新初始化无法序列化的对象 + self.client = OpenAI( + api_key=self.api_key, + base_url=self.api_base, + ) + self.tokenizer = AutoTokenizer.from_pretrained(self.url) + + +class RepresentationNetworkUniZero(nn.Module): + + def __init__( + self, + observation_shape: SequenceType = (3, 64, 64), + num_res_blocks: int = 1, + num_channels: int = 64, + downsample: bool = True, + activation: nn.Module = nn.GELU(approximate='tanh'), + norm_type: str = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + ) -> None: + """ + Overview: + Representation network used in UniZero. Encode the 2D image obs into latent state. + Currently, the network only supports obs images with both a width and height of 64. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - num_channels (:obj:`int`): The channel of output hidden state. 
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ + Use the inplace operation to speed up. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - embedding_dim (:obj:`int`): The dimension of the latent state. + - group_size (:obj:`int`): The dimension for simplicial normalization. + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + logging.info(f"Using norm type: {norm_type}") + logging.info(f"Using activation type: {activation}") + + self.observation_shape = observation_shape + self.downsample = downsample + if self.downsample: + self.downsample_net = DownSample( + observation_shape, + num_channels, + activation=activation, + norm_type=norm_type, + ) + else: + self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) + + if norm_type == 'BN': + self.norm = nn.BatchNorm2d(num_channels) + elif norm_type == 'LN': + if downsample: + self.norm = nn.LayerNorm( + [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], + eps=1e-5) + else: + self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + self.activation = activation + self.embedding_dim = embedding_dim + + if self.observation_shape[1] == 64: + self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False) + + elif self.observation_shape[1] in [84, 96]: + self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False) + + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ + H is height. + - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ + output width, H_ is output height. + """ + if self.downsample: + x = self.downsample_net(x) + else: + x = self.conv(x) + x = self.norm(x) + x = self.activation(x) + for block in self.resblocks: + x = block(x) + + # Important: Transform the output feature plane to the latent state. + # For example, for an Atari feature plane of shape (64, 8, 8), + # flattening results in a size of 4096, which is then transformed to 768. + x = self.last_linear(x.view(x.size(0), -1)) + + x = x.view(-1, self.embedding_dim) + + # NOTE: very important for training stability. + x = self.sim_norm(x) + + return x + + +class RepresentationNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType = (4, 96, 96), + num_res_blocks: int = 1, + num_channels: int = 64, + downsample: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: str = 'BN', + embedding_dim: int = 256, + group_size: int = 8, + use_sim_norm: bool = False, + ) -> None: + """ + Overview: + Representation network used in MuZero and derived algorithms. Encode the 2D image obs into latent state. + Currently, the network only supports obs images with both a width and height of 96. 
+ Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[4, 96, 96] + for video games like atari, 1 gray channel times stack 4 frames. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - num_channels (:obj:`int`): The channel of output hidden state. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ + Use the inplace operation to speed up. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + - embedding_dim (:obj:`int`): The dimension of the output hidden state. + - group_size (:obj:`int`): The size of group in the SimNorm layer. + - use_sim_norm (:obj:`bool`): Whether to use SimNorm layer, defaults to False. + """ + super().__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + self.downsample = downsample + if self.downsample: + self.downsample_net = DownSample( + observation_shape, + num_channels, + activation=activation, + norm_type=norm_type, + ) + else: + self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False) + + if norm_type == 'BN': + self.norm = nn.BatchNorm2d(num_channels) + elif norm_type == 'LN': + if downsample: + self.norm = nn.LayerNorm( + [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], + eps=1e-5) + else: + self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5) + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + self.activation = activation + + self.use_sim_norm = use_sim_norm + + if self.use_sim_norm: + self.embedding_dim = embedding_dim + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \ + H is height. + - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \ + output width, H_ is output height. + """ + if self.downsample: + x = self.downsample_net(x) + else: + x = self.conv(x) + x = self.norm(x) + x = self.activation(x) + + for block in self.resblocks: + x = block(x) + + if self.use_sim_norm: + # NOTE: very important. + # for atari 64,8,8 = 4096 -> 768 + x = self.sim_norm(x) + + return x + + +class RepresentationNetworkMLP(nn.Module): + + def __init__( + self, + observation_shape: int, + hidden_channels: int = 64, + layer_num: int = 2, + activation: nn.Module = nn.GELU(approximate='tanh'), + norm_type: Optional[str] = 'BN', + group_size: int = 8, + ) -> torch.Tensor: + """ + Overview: + Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \ + with Multi-Layer Perceptron (MLP). + Arguments: + - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - hidden_channels (:obj:`int`): The channel of output hidden state. 
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \ + Use the inplace operation to speed up. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + """ + super().__init__() + self.fc_representation = MLP( + in_channels=observation_shape, + hidden_channels=hidden_channels, + out_channels=hidden_channels, + layer_num=layer_num, + activation=activation, + norm_type=norm_type, + # don't use activation and norm in the last layer of representation network is important for convergence. + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=True, + ) + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. + - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. + """ + x = self.fc_representation(x) + # TODO + x = self.sim_norm(x) + return x + + +class LatentDecoder(nn.Module): + + def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh')): + """ + Overview: + Decoder network used in UniZero. Decode the latent state into 2D image obs. + Arguments: + - embedding_dim (:obj:`int`): The dimension of the latent state. + - output_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel times stack 4 frames. + - num_channels (:obj:`int`): The channel of output hidden state. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). + """ + super().__init__() + self.embedding_dim = embedding_dim + self.output_shape = output_shape # (C, H, W) + self.num_channels = num_channels + self.activation = activation + + # Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256 + # We will reverse the process of the representation network + self.initial_size = ( + num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder + self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size)) + + # Upsampling blocks + self.conv_blocks = nn.ModuleList([ + # Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4) + nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1), + self.activation, + nn.BatchNorm2d(num_channels // 2), + # Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2) + nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1, + output_padding=1), + self.activation, + nn.BatchNorm2d(num_channels // 4), + # Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W) + nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1, + output_padding=1), + ]) + # TODO: last layer use sigmoid? 
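
[Editor's note] A minimal shape-check sketch for the LatentDecoder defined here (its forward method follows below). This is illustrative only and not part of the patch; it assumes the class can be imported from the new lzero/model/common_serve_input_str.py added in this diff, and the batch size of 2 and the (3, 64, 64) output shape are arbitrary example values. With these settings the fc layer maps the 256-dim embedding to a (64, 8, 8) plane, and the three stride-2 transposed convolutions upsample it back to the original resolution.

    import torch
    from lzero.model.common_serve_input_str import LatentDecoder

    decoder = LatentDecoder(embedding_dim=256, output_shape=(3, 64, 64), num_channels=64)
    embeddings = torch.randn(2, 256)   # a batch of 2 latent states
    obs = decoder(embeddings)          # (2, 64, 8, 8) -> (2, 32, 16, 16) -> (2, 16, 32, 32) -> (2, 3, 64, 64)
    assert obs.shape == (2, 3, 64, 64)
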
+ + def forward(self, embeddings: torch.Tensor) -> torch.Tensor: + # Map embeddings back to the image space + x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8) + x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8) + + # Apply conv blocks + for block in self.conv_blocks: + x = block(x) # Upsample progressively + + # The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2]) + return x + + +class LatentEncoderForMemoryEnv(nn.Module): + + def __init__( + self, + image_shape=(3, 5, 5), + embedding_size=100, + channels=[16, 32, 64], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation: nn.Module = nn.GELU(approximate='tanh'), + normalize_pixel=False, + group_size: int = 8, + **kwargs, + ): + """ + Overview: + Encoder network used in UniZero in MemoryEnv. Encode the 2D image obs into latent state. + Arguments: + - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel times stack 4 frames. + - embedding_size (:obj:`int`): The dimension of the latent state. + - channels (:obj:`List[int]`): The channel of output hidden state. + - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. + - strides (:obj:`List[int]`): The stride of convolution layers. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). \ + Use the inplace operation to speed up. + - normalize_pixel (:obj:`bool`): Whether to normalize the pixel values to [0, 1], defaults to False. + - group_size (:obj:`int`): The dimension for simplicial normalization + """ + super(LatentEncoderForMemoryEnv, self).__init__() + self.shape = image_shape + self.channels = [image_shape[0]] + list(channels) + + layers = [] + for i in range(len(self.channels) - 1): + layers.append( + nn.Conv2d( + self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], + padding=kernel_sizes[i] // 2 # keep the same size of feature map + ) + ) + layers.append(nn.BatchNorm2d(self.channels[i + 1])) + layers.append(activation) + + layers.append(nn.AdaptiveAvgPool2d(1)) + + self.cnn = nn.Sequential(*layers) + self.linear = nn.Sequential( + nn.Linear(self.channels[-1], embedding_size, bias=False), + ) + init.kaiming_normal_(self.linear[0].weight, mode='fan_out', nonlinearity='relu') + + self.normalize_pixel = normalize_pixel + self.sim_norm = SimNorm(simnorm_dim=group_size) + + def forward(self, image): + if self.normalize_pixel: + image = image / 255.0 + x = self.cnn(image.float()) # (B, C, 1, 1) + x = torch.flatten(x, start_dim=1) # (B, C) + x = self.linear(x) # (B, embedding_size) + x = self.sim_norm(x) + return x + + +class LatentDecoderForMemoryEnv(nn.Module): + + def __init__( + self, + image_shape=(3, 5, 5), + embedding_size=256, + channels=[64, 32, 16], + kernel_sizes=[3, 3, 3], + strides=[1, 1, 1], + activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), + **kwargs, + ): + """ + Overview: + Decoder network used in UniZero in MemoryEnv. Decode the latent state into 2D image obs. + Arguments: + - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64] + for video games like atari, RGB 3 channel times stack 4 frames. + - embedding_size (:obj:`int`): The dimension of the latent state. + - channels (:obj:`List[int]`): The channel of output hidden state. + - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers. 
+ - strides (:obj:`List[int]`): The stride of convolution layers. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.LeakyReLU(). \ + Use the inplace operation to speed up. + """ + super(LatentDecoderForMemoryEnv, self).__init__() + self.shape = image_shape + self.channels = list(channels) + [image_shape[0]] + + self.linear = nn.Linear(embedding_size, channels[0] * image_shape[1] * image_shape[2]) + + layers = [] + for i in range(len(self.channels) - 1): + layers.append( + nn.ConvTranspose2d( + self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i], + padding=kernel_sizes[i] // 2, output_padding=strides[i] - 1 + ) + ) + if i < len(self.channels) - 2: + layers.append(nn.BatchNorm2d(self.channels[i + 1])) + layers.append(activation) + else: + layers.append(nn.Sigmoid()) + + self.deconv = nn.Sequential(*layers) + + def forward(self, embedding): + x = self.linear(embedding) + x = x.view(-1, self.channels[0], self.shape[1], self.shape[2]) + x = self.deconv(x) # (B, C, H, W) + return x + + +class VectorDecoderForMemoryEnv(nn.Module): + + def __init__( + self, + embedding_dim: int, + output_shape: SequenceType, + hidden_channels: int = 64, + layer_num: int = 2, + activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), # TODO + norm_type: Optional[str] = 'BN', + ) -> torch.Tensor: + """ + Overview: + Decoder network used in UniZero in MemoryEnv. Decode the latent state into vector obs. + Arguments: + - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10. + - num_res_blocks (:obj:`int`): The number of residual blocks. + - hidden_channels (:obj:`int`): The channel of output hidden state. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \ + defaults to True. This option is often used in video games like Atari. In board games like go, \ + we don't need this module. + - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(). \ + Use the inplace operation to speed up. + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + """ + super().__init__() + self.fc_representation = MLP( + in_channels=embedding_dim, + hidden_channels=hidden_channels, + out_channels=output_shape, + layer_num=layer_num, + activation=activation, + norm_type=norm_type, + # don't use activation and norm in the last layer of representation network is important for convergence. + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=True, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Shapes: + - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation. + - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size. 
+ """ + x = self.fc_representation(x) + return x + + +class PredictionNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType, + action_space_size: int, + num_res_blocks: int, + num_channels: int, + value_head_channels: int, + policy_head_channels: int, + fc_value_layers: int, + fc_policy_layers: int, + output_support_size: int, + flatten_output_size_for_value_head: int, + flatten_output_size_for_policy_head: int, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ) -> None: + """ + Overview: + The definition of policy and value prediction network, which is used to predict value and policy by the + given latent state. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of hidden states. + - value_head_channels (:obj:`int`): The channels of value head. + - policy_head_channels (:obj:`int`): The channels of policy head. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ + - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the value head. + - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the policy head. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
+ """ + super(PredictionNetwork, self).__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + + self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) + self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) + + if observation_shape[1] == 96: + latent_shape = (observation_shape[1] / 16, observation_shape[2] / 16) + elif observation_shape[1] == 64: + latent_shape = (observation_shape[1] / 8, observation_shape[2] / 8) + + if norm_type == 'BN': + self.norm_value = nn.BatchNorm2d(value_head_channels) + self.norm_policy = nn.BatchNorm2d(policy_head_channels) + elif norm_type == 'LN': + if downsample: + self.norm_value = nn.LayerNorm( + [value_head_channels, *latent_shape], + eps=1e-5) + self.norm_policy = nn.LayerNorm([policy_head_channels, *latent_shape], eps=1e-5) + else: + self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], + eps=1e-5) + self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], + eps=1e-5) + + self.flatten_output_size_for_value_head = flatten_output_size_for_value_head + self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head + + self.activation = activation + + self.fc_value = MLP( + in_channels=self.flatten_output_size_for_value_head, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.fc_policy = MLP( + in_channels=self.flatten_output_size_for_policy_head, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
+ """ + for res_block in self.resblocks: + latent_state = res_block(latent_state) + + value = self.conv1x1_value(latent_state) + value = self.norm_value(value) + value = self.activation(value) + + policy = self.conv1x1_policy(latent_state) + policy = self.norm_policy(policy) + policy = self.activation(policy) + + value = value.reshape(-1, self.flatten_output_size_for_value_head) + policy = policy.reshape(-1, self.flatten_output_size_for_policy_head) + + value = self.fc_value(value) + policy = self.fc_policy(policy) + return policy, value + + +class PredictionNetworkMLP(nn.Module): + + def __init__( + self, + action_space_size, + num_channels, + common_layer_num: int = 2, + fc_value_layers: SequenceType = [32], + fc_policy_layers: SequenceType = [32], + output_support_size: int = 601, + last_linear_layer_init_zero: bool = True, + activation: Optional[nn.Module] = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + ): + """ + Overview: + The definition of policy and value prediction network with Multi-Layer Perceptron (MLP), + which is used to predict value and policy by the given latent state. + Arguments: + - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \ + space, it is the number of discrete actions. + - num_channels (:obj:`int`): The channels of latent states. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. + """ + super().__init__() + self.num_channels = num_channels + + # ******* common backbone ****** + self.fc_prediction_common = MLP( + in_channels=self.num_channels, + hidden_channels=self.num_channels, + out_channels=self.num_channels, + layer_num=common_layer_num, + activation=activation, + norm_type=norm_type, + output_activation=True, + output_norm=True, + # last_linear_layer_init_zero=False is important for convergence + last_linear_layer_init_zero=False, + ) + + # ******* value and policy head ****** + self.fc_value_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.fc_policy_head = MLP( + in_channels=self.num_channels, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor): + """ + Overview: + Forward computation of the prediction network. 
+ Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). + """ + x_prediction_common = self.fc_prediction_common(latent_state) + + value = self.fc_value_head(x_prediction_common) + policy = self.fc_policy_head(x_prediction_common) + return policy, value + + +class PredictionHiddenNetwork(nn.Module): + + def __init__( + self, + observation_shape: SequenceType, + action_space_size: int, + num_res_blocks: int, + num_channels: int, + value_head_channels: int, + policy_head_channels: int, + fc_value_layers: int, + fc_policy_layers: int, + output_support_size: int, + flatten_output_size_for_value_head: int, + flatten_output_size_for_policy_head: int, + downsample: bool = False, + last_linear_layer_init_zero: bool = True, + activation: nn.Module = nn.ReLU(inplace=True), + norm_type: Optional[str] = 'BN', + gru_hidden_size: int = 512, + ) -> None: + """ + Overview: + The definition of policy and value prediction network, which is used to predict value and policy by the + given latent state. + Arguments: + - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image. + - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space. + - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model. + - num_channels (:obj:`int`): The channels of hidden states. + - value_head_channels (:obj:`int`): The channels of value head. + - policy_head_channels (:obj:`int`): The channels of policy head. + - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head). + - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head). + - output_support_size (:obj:`int`): The size of categorical value output. + - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \ + - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the value head. + - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \ + of the policy head. + - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``. + - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \ + dynamics/prediction mlp, default sets it to True. + - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \ + operation to speedup, e.g. ReLU(inplace=True). + - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'. 
+ """ + super(PredictionHiddenNetwork, self).__init__() + assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']" + + self.observation_shape = observation_shape + self.gru_hidden_size = gru_hidden_size + self.resblocks = nn.ModuleList( + [ + ResBlock( + in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False + ) for _ in range(num_res_blocks) + ] + ) + + self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1) + self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1) + + if norm_type == 'BN': + self.norm_value = nn.BatchNorm2d(value_head_channels) + self.norm_policy = nn.BatchNorm2d(policy_head_channels) + elif norm_type == 'LN': + if downsample: + self.norm_value = nn.LayerNorm( + [value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)], + eps=1e-5) + self.norm_policy = nn.LayerNorm([policy_head_channels, math.ceil(observation_shape[-2] / 16), + math.ceil(observation_shape[-1] / 16)], eps=1e-5) + else: + self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]], + eps=1e-5) + self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]], + eps=1e-5) + + self.flatten_output_size_for_value_head = flatten_output_size_for_value_head + self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head + + self.activation = activation + + self.fc_value = MLP( + in_channels=self.flatten_output_size_for_value_head + self.gru_hidden_size, + hidden_channels=fc_value_layers[0], + out_channels=output_support_size, + layer_num=len(fc_value_layers) + 1, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + self.fc_policy = MLP( + in_channels=self.flatten_output_size_for_policy_head + self.gru_hidden_size, + hidden_channels=fc_policy_layers[0], + out_channels=action_space_size, + layer_num=len(fc_policy_layers) + 1, + activation=self.activation, + norm_type=norm_type, + output_activation=False, + output_norm=False, + # last_linear_layer_init_zero=True is beneficial for convergence speed. + last_linear_layer_init_zero=last_linear_layer_init_zero + ) + + def forward(self, latent_state: torch.Tensor, world_model_latent_history: torch.Tensor) -> Tuple[ + torch.Tensor, torch.Tensor]: + """ + Overview: + Forward computation of the prediction network. + Arguments: + - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim). + Returns: + - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size). + - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size). 
+ """ + for res_block in self.resblocks: + latent_state = res_block(latent_state) + + value = self.conv1x1_value(latent_state) + value = self.norm_value(value) + value = self.activation(value) + + policy = self.conv1x1_policy(latent_state) + policy = self.norm_policy(policy) + policy = self.activation(policy) + + latent_state_value = value.reshape(-1, self.flatten_output_size_for_value_head) + latent_state_policy = policy.reshape(-1, self.flatten_output_size_for_policy_head) + + # TODO: world_model_latent_history.squeeze(0) shape: (num_layers * num_directions, batch_size, hidden_size) -> ( batch_size, hidden_size) + latent_history_value = torch.cat([latent_state_value, world_model_latent_history.squeeze(0)], dim=1) + latent_history_policy = torch.cat([latent_state_policy, world_model_latent_history.squeeze(0)], dim=1) + + value = self.fc_value(latent_history_value) + policy = self.fc_policy(latent_history_policy) + return policy, value + + diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index e28322215..239e7efb3 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -6,7 +6,8 @@ from easydict import EasyDict from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ - VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \ + HFLanguageRepresentationNetwork from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model import WorldModel @@ -87,6 +88,15 @@ def __init__( print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') print('==' * 20) + elif world_model_cfg.obs_type == 'text': + self.representation_network = HFLanguageRepresentationNetwork(url=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False,) + self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) elif world_model_cfg.obs_type == 'image': self.representation_network = RepresentationNetworkUniZero( observation_shape, diff --git a/lzero/model/unizero_world_models/hf_transformer.py b/lzero/model/unizero_world_models/hf_transformer.py new file mode 100644 index 000000000..d0ef762be --- /dev/null +++ b/lzero/model/unizero_world_models/hf_transformer.py @@ -0,0 +1,110 @@ +from typing import Optional + +import torch +from transformers import LlamaForCausalLM +from transformers.cache_utils import DynamicCache + +from .kv_caching import KeysValues + + +def kv2dc(cache: KeysValues): + res = DynamicCache() + for kv_cache in cache: + k_tensor = kv_cache._k_cache.get() + v_tensor = kv_cache._v_cache.get() + res.key_cache.append(k_tensor) + res.value_cache.append(v_tensor) + return res + + +def update_kv(cache: KeysValues, new_cache: DynamicCache): + for i in 
range(len(new_cache.key_cache)):
+        # Index per layer: layer i of the Hugging Face DynamicCache maps to layer i of the KeysValues object.
+        cache[i].update(new_cache.key_cache[i], new_cache.value_cache[i])
+
+
+class HuggingfaceLlamaTransformer(LlamaForCausalLM):
+    @classmethod
+    def from_pretrained(cls, lzero_config, *args, **kwargs):
+        # Attach the LightZero config to the pretrained Llama model so that later calls can read it.
+        model = super(HuggingfaceLlamaTransformer, cls).from_pretrained(*args, **kwargs)
+        model.lzero_config = lzero_config
+        return model
+
+    def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
+        """
+        Generate a placeholder for keys and values.
+
+        Arguments:
+            - n (:obj:`int`): Batch size.
+            - max_tokens (:obj:`int`): Maximum number of tokens in the sequence.
+
+        Returns:
+            - KeysValues: An object containing empty keys and values.
+        """
+        device = self.lzero_config.device  # Assumption: All submodules are on the same device
+        return KeysValues(n, self.lzero_config.num_heads, max_tokens,
+                          self.lzero_config.embed_dim, self.lzero_config.num_layers,
+                          device, self.lzero_config.hidden_size)
+
+    def _get_positional_embedding(self, layer, attn_type, pos_emb) -> torch.Tensor:
+        """
+        Helper function to get positional embedding for a given layer and attention type.
+
+        Arguments:
+        - layer (:obj:`int`): Layer index.
+        - attn_type (:obj:`str`): Attention type, one of 'key', 'value' or 'query'.
+
+        Returns:
+        - torch.Tensor: The positional embedding tensor.
+        """
+        if attn_type == 'key':
+            module_name = 'k_proj'
+        elif attn_type == 'value':
+            module_name = 'v_proj'
+        elif attn_type == 'query':
+            module_name = 'q_proj'
+        else:
+            raise ValueError(f"Unsupported attn_type: {attn_type}")
+        attn_func = getattr(self.model.layers[layer].self_attn, module_name)
+        return attn_func(pos_emb.weight)
+
+    def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
+                valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
+        """
+        Forward pass of the Transformer model.
+
+        Arguments:
+            - sequences (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim).
+            - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation (default: None).
+            - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking (default: None).
+
+        Returns:
+            - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
+        """
+        assert past_keys_values is None or len(past_keys_values) == len(self.model.layers)
+        if past_keys_values is not None:
+            kv_cache = kv2dc(past_keys_values)
+            use_cache = True
+        else:
+            kv_cache = None
+            use_cache = False
+
+        B, T, _ = sequences.shape
+        # Hugging Face expects an attention mask of shape (B, T), not the shape of the embedded inputs.
+        if valid_context_lengths is not None:
+            attention_mask = (torch.arange(T, device=sequences.device).expand(B, T)
+                              >= (T - valid_context_lengths.unsqueeze(1))).long()
+        else:
+            attention_mask = torch.ones(B, T, dtype=torch.long, device=sequences.device)
+
+        output = self.model.forward(
+            attention_mask=attention_mask,
+            past_key_values=kv_cache,
+            inputs_embeds=sequences,
+            use_cache=use_cache
+        )
+
+        if past_keys_values is not None:
+            # Write the updated Hugging Face cache back into the LightZero KeysValues object.
+            update_kv(past_keys_values, kv_cache)
+        # self.model is the base LlamaModel, so return its hidden states with shape (B, T, embed_dim),
+        # matching the contract documented above and expected by the world model.
+        return output.last_hidden_state
diff --git a/lzero/model/unizero_world_models/kv_caching.py b/lzero/model/unizero_world_models/kv_caching.py
index 28b7b0ba2..7ede62197 100644
--- a/lzero/model/unizero_world_models/kv_caching.py
+++ b/lzero/model/unizero_world_models/kv_caching.py
@@ -1,13 +1,14 @@
 # Modified from https://github.com/eloialonso/iris/blob/main/src/models/kv_caching.py
-from typing import Tuple
+from typing import Tuple, Optional
 
 import numpy as np
 import torch
 
 
 class Cache:
-    def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
+    def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device,
+                 hidden_size: Optional[int]) -> None:
         """
         Overview:
             Cache for storing intermediate results in a transformer model.
@@ -20,7 +21,9 @@ def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim:
         """
         assert embed_dim % num_heads == 0
         self._num_samples, self._cache, self._size = num_samples, None, None
-        self._reset = lambda n: torch.empty(n, num_heads, max_tokens, embed_dim // num_heads, device=device)  # (B, nh, T, hs)
+        if hidden_size is None:
+            hidden_size = embed_dim // num_heads
+        self._reset = lambda n: torch.empty(n, num_heads, max_tokens, hidden_size, device=device)  # (B, nh, T, hs)
         self.reset()
 
     @property
@@ -77,7 +80,8 @@ def update(self, x: torch.Tensor, tokens: int) -> None:
 
 
 class KVCache:
-    def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
+    def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device,
+                 hidden_size: Optional[int]) -> None:
         """
         Overview:
             Cache for storing key and value tensors in a transformer model.
@@ -88,8 +92,8 @@ def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, devi
            - embed_dim (:obj:`int`): The dimension of the embeddings.
            - device (:obj:`torch.device`): The device on which to store the cache.
""" - self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device) - self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device) + self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device, hidden_size) + self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device, hidden_size) @property def shape(self) -> Tuple[int, int, int, int]: @@ -142,7 +146,8 @@ def update(self, k: torch.Tensor, v: torch.Tensor): class KeysValues: - def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None: + def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device, + hidden_size: Optional[int] = None) -> None: """ Overview: Class for managing multiple layers of key and value caches in a transformer model. @@ -154,7 +159,8 @@ def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_ - num_layers (:obj:`int`): The number of layers in the transformer model. - device (:obj:`torch.device`): The device on which to store the caches. """ - self._keys_values = tuple([KVCache(n, num_heads, max_tokens, embed_dim, device) for _ in range(num_layers)]) + self._keys_values = tuple([KVCache(n, num_heads, max_tokens, embed_dim, device, hidden_size) + for _ in range(num_layers)]) def __getitem__(self, index: int) -> KVCache: """ diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py index bd066ccec..1a237610e 100644 --- a/lzero/model/unizero_world_models/tokenizer.py +++ b/lzero/model/unizero_world_models/tokenizer.py @@ -64,7 +64,11 @@ def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor: Returns: torch.Tensor: Encoded embeddings of shape (B, 1, E). """ + # NOTE: only for Jerico env + # x = x.long() # 确保输入为长整型 shape = x.shape + # print(f"Max index in x: {x.max()}") + # print(f"Min index in x: {x.min()}") # Process input tensor based on its dimensionality if len(shape) == 2: # Case when input is 2D (B, E) @@ -72,7 +76,14 @@ def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor: obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 3: # Case when input is 3D (B, T, E) - x = x.contiguous().view(-1, shape[-1]) # Flatten the last two dimensions (B * T, E) + # print(f"x:{x}") + try: + x = x.contiguous().view(-1, shape[-1]) # Flatten the last two dimensions (B * T, E) + except Exception as e: + print(f"x:{x}") + print(' =='*20) + print(e) + obs_embeddings = self.encoder(x) obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e') elif len(shape) == 4: diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py index 62536c892..3fa0a3f65 100644 --- a/lzero/model/unizero_world_models/transformer.py +++ b/lzero/model/unizero_world_models/transformer.py @@ -69,6 +69,20 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues: device = self.ln_f.weight.device # Assumption: All submodules are on the same device return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device) + def _get_positional_embedding(self, layer, attn_type, pos_emb) -> torch.Tensor: + """ + Helper function to get positional embedding for a given layer and attention type. + + Arguments: + - layer (:obj:`int`): Layer index. + - attn_type (:obj:`str`): Attention type, either 'key' or 'value'. + + Returns: + - torch.Tensor: The positional embedding tensor. 
+ """ + attn_func = getattr(self.blocks[layer].attn, attn_type) + return attn_func(pos_emb.weight) + def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None, valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor: """ diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index 37d4cd3ec..0a0163b44 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -19,6 +19,7 @@ from .transformer import Transformer, TransformerConfig from .utils import LossWithIntermediateLosses, init_weights from .utils import WorldModelOutput, hash_state +from .hf_transformer import HuggingfaceLlamaTransformer logging.getLogger().setLevel(logging.DEBUG) @@ -45,7 +46,10 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: super().__init__() self.tokenizer = tokenizer self.config = config - self.transformer = Transformer(self.config) + if self.config.use_hf: + self.transformer = HuggingfaceLlamaTransformer.from_pretrained(self.config, self.config.pretrained_path) + else: + self.transformer = Transformer(self.config) if self.config.device == 'cpu': self.device = torch.device('cpu') @@ -61,7 +65,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None: # Initialize patterns for block masks self._initialize_patterns() - self.hidden_size = config.embed_dim // config.num_heads + self.hidden_size = config['hidden_size'] if "hidden_size" in config else config.embed_dim // config.num_heads + config['hidden_size'] = self.hidden_size # Position embedding self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device) @@ -403,15 +408,13 @@ def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor: Returns: - torch.Tensor: The positional embedding tensor. 
""" - attn_func = getattr(self.transformer.blocks[layer].attn, attn_type) + tmp = self.transformer._get_positional_embedding(layer, attn_type, self.pos_emb).view( + 1, self.config.max_tokens, -1, self.embed_dim // self.num_heads + ) if torch.cuda.is_available(): - return attn_func(self.pos_emb.weight).view( - 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads - ).transpose(1, 2).to(self.device).detach() + return tmp.transpose(1, 2).to(self.device).detach() else: - return attn_func(self.pos_emb.weight).view( - 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads - ).transpose(1, 2).detach() + return tmp.transpose(1, 2).detach() def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]], past_keys_values: Optional[torch.Tensor] = None, @@ -1168,6 +1171,18 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar # reconstructed_images) latent_recon_loss = self.latent_recon_loss + elif self.obs_type == 'text': + perceptual_loss = torch.tensor(0., device=batch['observations'].device, + dtype=torch.float32) + + # Reconstruct observations from latent state representations + # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings.reshape(-1, self.embed_dim)) + + # # Calculate reconstruction loss + # latent_recon_loss = self.tokenizer.reconstruction_loss(batch['observations'].reshape(-1, 25), + # reconstructed_images) + latent_recon_loss = self.latent_recon_loss + elif self.obs_type == 'image_memory': # Reconstruct observations from latent state representations # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings) diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py index cf95c46d9..90af62ec0 100644 --- a/lzero/policy/unizero.py +++ b/lzero/policy/unizero.py @@ -383,7 +383,19 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in # Convert to categorical distributions target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward) - target_value_categorical = phi_transform(self.value_support, transformed_target_value) + # print(f'transformed_target_value:{transformed_target_value}') + # print("self.value_support:", self.value_support) + + try: + target_value_categorical = phi_transform(self.value_support, transformed_target_value) + except Exception as e: + print('='*20) + print(e) + # print(f'transformed_target_value:{transformed_target_value}') + # print("self.value_support:", self.value_support) + print('='*20) + # target_value_categorical = phi_transform(self.value_support, transformed_target_value) + # Prepare batch for GPT model batch_for_gpt = {} @@ -405,9 +417,15 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in batch_for_gpt['target_policy'] = target_policy[:, :-1] # Extract valid target policy data and compute entropy - valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] - target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) - average_target_policy_entropy = target_policy_entropy.mean() + try: + valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']] + target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1) + average_target_policy_entropy = target_policy_entropy.mean() + except Exception as e: + print('='*20) + print(e) + average_target_policy_entropy = 0. 
+ # Update world model losses = self._learn_model.world_model.compute_loss( @@ -433,8 +451,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model'] latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms'] - assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" - assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" + # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values" + # assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values" # Core learn model update step self._optimizer_world_model.zero_grad() diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py index eff413df6..c94220d69 100644 --- a/lzero/worker/muzero_collector.py +++ b/lzero/worker/muzero_collector.py @@ -11,6 +11,7 @@ allreduce_data from ding.worker.collector.base_serial_collector import ISerialCollector from torch.nn import L1Loss +import torch.distributed as dist from lzero.mcts.buffer.game_segment import GameSegment from lzero.mcts.utils import prepare_observation @@ -327,6 +328,9 @@ def collect(self, Returns: - return_data (:obj:`List[Any]`): Collected data in the form of a list. """ + # Before starting collection + # self._logger.info(f"Rank {self._rank} starting collection for {n_episode} episodes.") + # TODO: collect_with_pure_policy as a separate collector if n_episode is None: if self._default_n_episode is None: @@ -716,11 +720,21 @@ def collect(self, break collected_duration = sum([d['time'] for d in self._episode_info]) + + # Before allreduce + self._logger.info(f"Rank {self._rank} before allreduce: collected_step={collected_step}, collected_episode={collected_episode}") + # reduce data when enables DDP if self._world_size > 1: + dist.barrier() + # print(f"Rank {dist.get_rank()} collected_step: {collected_step}, collected_episode: {collected_episode}, collected_duration: {collected_duration}") collected_step = allreduce_data(collected_step, 'sum') collected_episode = allreduce_data(collected_episode, 'sum') collected_duration = allreduce_data(collected_duration, 'sum') + + # After allreduce + self._logger.info(f"Rank {self._rank} after allreduce: collected_step={collected_step}, collected_episode={collected_episode}") + self._total_envstep_count += collected_step self._total_episode_count += collected_episode self._total_duration += collected_duration diff --git a/requirements.txt b/requirements.txt index 831ae67c5..ffb8dd582 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ pytest line_profiler xxhash einops +openai==1.57.1 +jericho diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py index 56e89d1a3..ed3121bc3 100644 --- a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py +++ b/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py @@ -15,22 +15,56 @@ def main(env_id, seed): continuous_action_space = True K = 20 # num_of_sampled_actions + # K = 16 # num_of_sampled_actions + collector_env_num = 8 n_episode = 8 num_segments = 8 - game_segment_length = 100 + game_segment_length = 125 + # game_segment_length = 500 + + collector_env_num = 16 + n_episode = 16 + num_segments = 16 + game_segment_length = 125 + + # collector_env_num = 16 + # n_episode = 16 + # num_segments = 16 + # game_segment_length = 125 + + evaluator_env_num = 3 - num_simulations = 50 - replay_ratio = 0.1 - max_env_step = 
int(5e5) + num_simulations = 50 # TODO + + # max_env_step = int(5e5) + max_env_step = int(1e6) + # max_env_step = int(3e6) # TODO + + reanalyze_ratio = 0 batch_size = 64 num_layers = 2 + # num_layers = 4 + num_unroll_steps = 5 + # num_unroll_steps = 10 infer_context_length = 2 + + # replay_ratio = 0.25 + # num_unroll_steps = 10 + # infer_context_length = 4 + norm_type = 'LN' # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. - buffer_reanalyze_freq = 1/100000 + buffer_reanalyze_freq = 1/1000000000 # TODO + # replay_ratio = 0.1 + replay_ratio = 0.25 + + + # buffer_reanalyze_freq = 1/10 + # replay_ratio = 0.1 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) reanalyze_batch_size = 160 # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. @@ -54,7 +88,9 @@ def main(env_id, seed): domain_name=domain_name, task_name=task_name, from_pixels=False, # vector/state obs - frame_skip=2, + # from_pixels=True, # vector/state obs + # frame_skip=2, + frame_skip=8, continuous=True, save_replay_gif=False, replay_path_gif='./replay_gif', @@ -75,13 +111,18 @@ def main(env_id, seed): num_of_sampled_actions=K, model_type='mlp', world_model_cfg=dict( - policy_loss_type='kl', + num_simulations=num_simulations, + policy_loss_type='kl', # 'simple' + # policy_loss_type='simple', # 'simple' obs_type='vector', num_unroll_steps=num_unroll_steps, + # policy_entropy_weight=0, + # policy_entropy_weight=5e-3, policy_entropy_weight=5e-2, continuous_action_space=continuous_action_space, num_of_sampled_actions=K, sigma_type='conditioned', + # sigma_type='fixed', fixed_sigma_value=0.5, bound_type=None, model_type='mlp', @@ -93,7 +134,8 @@ def main(env_id, seed): action_space_size=action_space_size, num_layers=num_layers, num_heads=8, - embed_dim=768, + embed_dim=768, # original + # embed_dim=512, env_num=max(collector_env_num, evaluator_env_num), ), ), @@ -108,21 +150,31 @@ def main(env_id, seed): replay_ratio=replay_ratio, batch_size=batch_size, discount_factor=0.99, - td_steps=5, - piecewise_decay_lr_scheduler=False, + # discount_factor=1, + # td_steps=5, + # td_steps=10, + td_steps=game_segment_length, # TODO + + lr_piecewise_constant_decay=False, learning_rate=1e-4, grad_clip_value=5, + # grad_clip_value=0.3, # TODO + # manual_temperature_decay=False, manual_temperature_decay=True, threshold_training_steps_for_final_temperature=int(2.5e4), - cos_lr_scheduler=True, + + # cos_lr_scheduler=True, + cos_lr_scheduler=False, + num_segments=num_segments, train_start_after_envsteps=2000, game_segment_length=game_segment_length, num_simulations=num_simulations, - reanalyze_ratio=0, + reanalyze_ratio=reanalyze_ratio, n_episode=n_episode, eval_freq=int(5e3), replay_buffer_size=int(1e6), + # replay_buffer_size=int(5e4), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, # ============= The key different params for ReZero ============= @@ -151,7 +203,8 @@ def main(env_id, seed): # ============ use muzero_segment_collector instead of muzero_collector ============= from lzero.entry import train_unizero_segment - main_config.exp_name=f'data_sampled_unizero/dmc2gym_{env_id}_brf{buffer_reanalyze_freq}_state_cont_suz_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_K{K}_ns{num_simulations}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_{norm_type}_seed{seed}_learnsigma' + 
main_config.exp_name=f'data_suz_1216/dmc2gym_{env_id}_state_cont_suz_fs8_act-simnorm_td{game_segment_length}_dc099_learn-sigma_gcv5_rbs1e6_no-corlr_embed768_temp2.5e4_pew5e-2_19prior1flatten_obs10value01_clamp4_brf{buffer_reanalyze_freq}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_K{K}_ns{num_simulations}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_{norm_type}_seed{seed}' + train_unizero_segment([main_config, create_config], model_path=main_config.policy.model_path, seed=seed, max_env_step=max_env_step) @@ -161,6 +214,15 @@ def main(env_id, seed): parser.add_argument('--env', type=str, help='The environment to use', default='cartpole-swingup') parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + # args.env = 'cheetah-run' + # args.env = 'walker-walk' + # args.env = 'finger-spin' + # args.env = 'pendulum-swingup' + + # args.env = 'hopper-hop' + # args.env = 'acrobot-swingup' main(args.env, args.seed) \ No newline at end of file diff --git a/zoo/jericho/__init__.py b/zoo/jericho/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py new file mode 100644 index 000000000..8dd791303 --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -0,0 +1,162 @@ +import os +from easydict import EasyDict +import os +os.environ["HF_HOME"] = "/mnt/afs/zhangshenghan/.cache/huggingface/hub" + +def main(env_id='detective.z5', seed=0): + action_space_size = 50 + max_steps = 51 + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + collector_env_num = 2 + num_segments = 2 + game_segment_length = 20 + evaluator_env_num = 2 + num_simulations = 50 + max_env_step = int(10e6) + batch_size = 32 + num_unroll_steps = 10 + infer_context_length = 4 + num_layers = 2 + replay_ratio = 0.25 + update_per_collect = 20 # NOTE: very important for ddp + embed_dim = 768 + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1/10 + buffer_reanalyze_freq = 1/100000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition = 0.75 + # model_name = 'BAAI/bge-base-en-v1.5' + model_name = 'google-bert/bert-base-uncased' + # =========== TODO: only for debug =========== + # collector_env_num = 2 + # num_segments = 2 + # game_segment_length = 20 + # evaluator_env_num = 2 + # max_env_step = int(5e5) + # batch_size = 10 + # num_simulations = 5 + # num_unroll_steps = 5 + # infer_context_length = 2 + # max_steps = 10 + # num_layers = 1 + # replay_ratio = 0.05 + # embed_dim = 768 + # TODO: MCTS内部的action_space受限于root节点的legal action + + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + jericho_unizero_config = dict( + env=dict( + stop_value=int(1e6), + observation_shape=512, + max_steps=max_steps, + max_action_num=action_space_size, + tokenizer_path=model_name, + # tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594", + max_seq_len=512, + # game_path="z-machine-games-master/jericho-game-suite/" + env_id, + game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/"+ env_id, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ) + ), + policy=dict( + multi_gpu=False, # ======== Very important for ddp ============= + # multi_gpu=True, # ======== Very important for ddp ============= + # default is 10000 + use_wandb=False, + learn=dict(learner=dict( + hook=dict(save_ckpt_after_iter=1000000, ), ), ), + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_url=model_name, + # encoder_url='google-bert/bert-base-uncased', + # encoder_url='/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594', + # The input of the model is text, whose shape is identical to the mlp model. + model_type='mlp', + continuous_action_space=False, + world_model_cfg=dict( + policy_entropy_weight=5e-3, + continuous_action_space=False, + max_blocks=num_unroll_steps, + # NOTE: each timestep has 2 tokens: obs and action + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + use_hf=False, + num_layers=num_layers, + num_heads=8, + embed_dim=embed_dim, + obs_type='text', # TODO: Change it. + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + update_per_collect=update_per_collect, + action_type='varied_action_space', + model_path=None, + num_unroll_steps=num_unroll_steps, + reanalyze_ratio=0, + replay_ratio=replay_ratio, + batch_size=batch_size, + learning_rate=0.0001, + num_simulations=num_simulations, + num_segments=num_segments, + train_start_after_envsteps=0, # TODO + game_segment_length=game_segment_length, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. 
E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=reanalyze_partition, + ), + ) + jericho_unizero_config = EasyDict(jericho_unizero_config) + + jericho_unizero_create_config = dict( + env=dict( + type='jericho', + import_names=['zoo.jericho.envs.jericho_env'], + ), + # NOTE: use base env manager to avoid the bug of subprocess env manager. + env_manager=dict(type='base'), + # env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + jericho_unizero_create_config = EasyDict(jericho_unizero_create_config) + main_config = jericho_unizero_config + create_config = jericho_unizero_create_config + + main_config.exp_name = f'data_unizero_detective_20241220/{env_id[:8]}_ms{max_steps}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + from lzero.entry import train_unizero + train_unizero([main_config, create_config], seed=seed, + model_path=main_config.policy.model_path, max_env_step=max_env_step) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process some environment.') + parser.add_argument('--env', type=str, + help='The environment to use', default='detective.z5') # 'detective.z5' 'zork1.z5' + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + main(args.env, args.seed) diff --git a/zoo/jericho/configs/jericho_unizero_pretrained_config.py b/zoo/jericho/configs/jericho_unizero_pretrained_config.py new file mode 100644 index 000000000..c02d9072e --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_pretrained_config.py @@ -0,0 +1,166 @@ +import os +from easydict import EasyDict +import os +os.environ["HF_HOME"] = "/mnt/afs/zhangshenghan/.cache/huggingface/hub" + +def main(env_id='detective.z5', seed=0): + action_space_size = 50 + max_steps = 51 + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + collector_env_num = 2 + num_segments = 2 + game_segment_length = 20 + evaluator_env_num = 2 + num_simulations = 50 + max_env_step = int(10e6) + batch_size = 32 + num_unroll_steps = 10 + infer_context_length = 4 + replay_ratio = 0.25 + update_per_collect = 20 # NOTE: very important for ddp + embed_dim = 2048 + num_layers = 16 + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1/10 + buffer_reanalyze_freq = 1/100000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition = 0.75 + # model_name = 'BAAI/bge-base-en-v1.5' + model_name = 'google-bert/bert-base-uncased' + # =========== TODO: only for debug =========== + # collector_env_num = 2 + # num_segments = 2 + # game_segment_length = 20 + # evaluator_env_num = 2 + # max_env_step = int(5e5) + # batch_size = 10 + # num_simulations = 5 + # num_unroll_steps = 5 + # infer_context_length = 2 + # max_steps = 10 + # num_layers = 1 + # replay_ratio = 0.05 + # embed_dim = 768 + # TODO: MCTS内部的action_space受限于root节点的legal action + + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + jericho_unizero_config = dict( + env=dict( + stop_value=int(1e6), + observation_shape=512, + max_steps=max_steps, + max_action_num=action_space_size, + tokenizer_path=model_name, + # tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594", + max_seq_len=512, + game_path="z-machine-games-master/jericho-game-suite/" + env_id, + # game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/"+ env_id, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ) + ), + policy=dict( + multi_gpu=False, # ======== Very important for ddp ============= + # multi_gpu=True, # ======== Very important for ddp ============= + # default is 10000 + use_wandb=False, + learn=dict(learner=dict( + hook=dict(save_ckpt_after_iter=1000000, ), ), ), + model=dict( + observation_shape=512, + action_space_size=action_space_size, + encoder_url=model_name, + # encoder_url='google-bert/bert-base-uncased', + # encoder_url='/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594', + # The input of the model is text, whose shape is identical to the mlp model. + model_type='mlp', + continuous_action_space=False, + world_model_cfg=dict( + policy_entropy_weight=5e-3, + continuous_action_space=False, + max_blocks=num_unroll_steps, + # NOTE: each timestep has 2 tokens: obs and action + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + use_hf=True, + pretrained_path='/data/share/Llama-3.2-1B', + # These parameters should be the same as the config file of original model. + num_layers=num_layers, + # Note: llama uses GQA, and the number of heads should equal to the number of key-value heads. + num_heads=8, + embed_dim=embed_dim, + hidden_size=64, # The dim for each head is 64 in this case. + obs_type='text', # TODO: Change it. + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + update_per_collect=update_per_collect, + action_type='varied_action_space', + model_path=None, + num_unroll_steps=num_unroll_steps, + reanalyze_ratio=0, + replay_ratio=replay_ratio, + batch_size=batch_size, + learning_rate=0.0001, + num_simulations=num_simulations, + num_segments=num_segments, + train_start_after_envsteps=0, # TODO + game_segment_length=game_segment_length, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. 
E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. + reanalyze_partition=reanalyze_partition, + ), + ) + jericho_unizero_config = EasyDict(jericho_unizero_config) + + jericho_unizero_create_config = dict( + env=dict( + type='jericho', + import_names=['zoo.jericho.envs.jericho_env'], + ), + # NOTE: use base env manager to avoid the bug of subprocess env manager. + env_manager=dict(type='base'), + # env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + jericho_unizero_create_config = EasyDict(jericho_unizero_create_config) + main_config = jericho_unizero_config + create_config = jericho_unizero_create_config + + main_config.exp_name = f'data_unizero_detective_20241220/{env_id[:8]}_ms{max_steps}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + from lzero.entry import train_unizero + train_unizero([main_config, create_config], seed=seed, + model_path=main_config.policy.model_path, max_env_step=max_env_step) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process some environment.') + parser.add_argument('--env', type=str, + help='The environment to use', default='detective.z5') # 'detective.z5' 'zork1.z5' + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + main(args.env, args.seed) diff --git a/zoo/jericho/configs/jericho_unizero_segment_config.py b/zoo/jericho/configs/jericho_unizero_segment_config.py new file mode 100644 index 000000000..4b6fb4da5 --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_segment_config.py @@ -0,0 +1,153 @@ +import os +from easydict import EasyDict + + +def main(env_id='detective.z5', seed=0): + action_space_size = 50 + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + collector_env_num = 8 + game_segment_length = 20 + evaluator_env_num = 5 + num_segments = 8 + num_simulations = 50 + max_env_step = int(5e5) + batch_size = 64 + num_unroll_steps = 10 + infer_context_length = 4 + num_layers = 2 + replay_ratio = 0.25 + embed_dim = 768 + max_steps = 100 + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + # buffer_reanalyze_freq = 1/10 + buffer_reanalyze_freq = 1/100000 + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size = 160 + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+ reanalyze_partition = 0.75 + + # =========== TODO: only for debug =========== + collector_env_num = 2 + num_segments = 2 + max_steps=20 + game_segment_length = 20 + evaluator_env_num = 2 + num_simulations = 5 + max_env_step = int(5e5) + batch_size = 10 + num_unroll_steps = 5 + infer_context_length = 2 + num_layers = 1 + replay_ratio = 0.05 + embed_dim = 32 + + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + jericho_unizero_config = dict( + env=dict( + stop_value=int(1e6), + max_steps=max_steps, + observation_shape=512, + max_action_num=action_space_size, + # tokenizer_path="google-bert/bert-base-uncased", + tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594", + max_seq_len=512, + # game_path="z-machine-games-master/jericho-game-suite/" + env_id, + game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/"+ env_id, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ) + ), + policy=dict( + # default is 10000 + learn=dict(learner=dict( + hook=dict(save_ckpt_after_iter=1000000, ), ), ), + model=dict( + observation_shape=512, + action_space_size=action_space_size, + # encoder_url='google-bert/bert-base-uncased', + encoder_url='/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594', + # The input of the model is text, whose shape is identical to the mlp model. + model_type='mlp', + world_model_cfg=dict( + policy_entropy_weight=5e-3, + continuous_action_space=False, + max_blocks=num_unroll_steps, + # NOTE: each timestep has 2 tokens: obs and action + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + use_hf=False, + num_layers=num_layers, + num_heads=8, + embed_dim=embed_dim, + obs_type='text', # TODO: Change it. + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + action_type='varied_action_space', + model_path=None, + num_unroll_steps=num_unroll_steps, + reanalyze_ratio=0, + replay_ratio=replay_ratio, + batch_size=batch_size, + learning_rate=0.0001, + num_simulations=num_simulations, + num_segments=num_segments, + # train_start_after_envsteps=2000, + train_start_after_envsteps=0, # TODO: only for debug + game_segment_length=game_segment_length, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + # ============= The key different params for reanalyze ============= + # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs. + buffer_reanalyze_freq=buffer_reanalyze_freq, + # Each reanalyze process will reanalyze sequences ( transitions per sequence) + reanalyze_batch_size=reanalyze_batch_size, + # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer. 
+            reanalyze_partition=reanalyze_partition,
+        ),
+    )
+    jericho_unizero_config = EasyDict(jericho_unizero_config)
+
+    jericho_unizero_create_config = dict(
+        env=dict(
+            type='jericho',
+            import_names=['zoo.jericho.envs.jericho_env'],
+        ),
+        # NOTE: use base env manager to avoid the bug of subprocess env manager.
+        env_manager=dict(type='base'),
+        # env_manager=dict(type='subprocess'),
+        policy=dict(
+            type='unizero',
+            import_names=['lzero.policy.unizero'],
+        ),
+    )
+    jericho_unizero_create_config = EasyDict(jericho_unizero_create_config)
+    main_config = jericho_unizero_config
+    create_config = jericho_unizero_create_config
+
+    main_config.exp_name = f'data_unizero/{env_id[:8]}/{env_id[:8]}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+    from lzero.entry import train_unizero_segment
+    train_unizero_segment([main_config, create_config], seed=seed,
+                          model_path=main_config.policy.model_path, max_env_step=max_env_step)
+
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description='Process some environment.')
+    parser.add_argument('--env', type=str,
+                        help='The environment to use', default='detective.z5')  # 'zork1.z5'
+    parser.add_argument('--seed', type=int, help='The seed to use', default=0)
+    args = parser.parse_args()
+
+    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+    main(args.env, args.seed)
diff --git a/zoo/jericho/envs/__init__.py b/zoo/jericho/envs/__init__.py
new file mode 100644
index 000000000..740dab512
--- /dev/null
+++ b/zoo/jericho/envs/__init__.py
@@ -0,0 +1 @@
+from .jericho_env import JerichoEnv
diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py
new file mode 100644
index 000000000..4fdd7c853
--- /dev/null
+++ b/zoo/jericho/envs/jericho_env.py
@@ -0,0 +1,168 @@
+"""
+Note: the observation returned by this env is a sequence of token ids, not a raw string.
+"""
+import copy
+from typing import List
+
+import gym
+import numpy as np
+from transformers import AutoTokenizer
+from ding.utils import ENV_REGISTRY
+from ding.envs import BaseEnv, BaseEnvTimestep
+from jericho import FrotzEnv
+from ding.utils import set_pkg_seed, get_rank, get_world_size
+
+
+@ENV_REGISTRY.register('jericho')
+class JerichoEnv(BaseEnv):
+    """
+    Overview:
+        The environment for Jericho games. For more details about the games, please refer to the \
+        `Jericho <https://github.com/microsoft/jericho>`_ documentation.
+    """
+    tokenizer = None
+
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.max_steps = cfg.max_steps
+        self.game_path = cfg.game_path
+        self.max_action_num = cfg.max_action_num
+        self.max_seq_len = cfg.max_seq_len
+
+        if JerichoEnv.tokenizer is None:
+            JerichoEnv.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path)
+
+        self._env = FrotzEnv(self.game_path)
+        self._action_list = None
+        self.finished = False
+        self._init_flag = False
+        self.episode_return = 0
+        self.env_step = 0
+
+        self.observation_space = gym.spaces.Dict()
+        self.action_space = gym.spaces.Discrete(self.max_action_num)
+        self.reward_space = gym.spaces.Box(
+            low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32)
+
+    def prepare_obs(self, obs, return_str: bool = False):
+        if self._action_list is None:
+            self._action_list = self._env.get_valid_actions()
+        full_obs = obs + "\nValid actions: " + str(self._action_list)
+        if not return_str:
+            full_obs = JerichoEnv.tokenizer(
+                [full_obs], truncation=True, padding="max_length", max_length=self.max_seq_len)
+            obs_attn_mask = full_obs['attention_mask']
+            full_obs = np.array(full_obs['input_ids'][0], dtype=np.int32)  # TODO: attn_mask
+        if len(self._action_list) <= self.max_action_num:
+            action_mask = [1] * len(self._action_list) + [0] * \
+                (self.max_action_num - len(self._action_list))
+        else:
+            action_mask = [1] * len(self._action_list)
+
+        action_mask = np.array(action_mask, dtype=np.int8)
+        if return_str:
+            return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}
+        else:
+            return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1}
+
+    def reset(self, return_str: bool = False):
+        initial_observation, info = self._env.reset()
+        self.finished = False
+        self._init_flag = True
+        self._action_list = None
+        self.episode_return = 0
+        self.env_step = 0
+
+        # Get the current world_size and rank.
+        self.world_size = get_world_size()
+        self.rank = get_rank()
+
+        return self.prepare_obs(initial_observation, return_str)
+
+    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+        """
+        Set the seed for the environment.
+        """
+        self._seed = seed
+        self._env.seed(seed)
+
+    def close(self) -> None:
+        self._init_flag = False
+
+    def __repr__(self) -> str:
+        return "LightZero Jericho Env"
+
+    def step(self, action: int, return_str: bool = False):
+        try:
+            action_str = self._action_list[action]
+        except Exception as e:
+            # TODO: investigate why an illegal action index can be selected here.
+            print('=' * 20)
+            print(e, f'rank {self.rank}: action {action} is illegal; randomly choosing a legal action from {self._action_list}!')
+            action = np.random.choice(len(self._action_list))
+            action_str = self._action_list[action]
+
+        observation, reward, done, info = self._env.step(action_str)
+        self.env_step += 1
+        self.episode_return += reward
+        self._action_list = None
+        observation = self.prepare_obs(observation, return_str)
+
+        # print(f'rank {self.rank}, step: {self.env_step}')
+        # print(f'self._action_list:{self._action_list}')
+        # print(f'rank {self.rank}, step: {self.env_step}, observation:{observation}, action:{action}, reward:{reward}')
+
+        if self.env_step >= self.max_steps:
+            done = True
+
+        if done:
+            print('=' * 20)
+            print(f'rank {self.rank}: one episode done!')
+            # print(f'self._action_list:{self._action_list}, action:{action}, reward:{reward}')
+            self.finished = True
+            info['eval_episode_return'] = self.episode_return
+
+        return BaseEnvTimestep(observation, reward, done, info)
+
+    @staticmethod
+    def create_collector_env_cfg(cfg: dict) -> List[dict]:
+        collector_env_num = cfg.pop('collector_env_num')
+        cfg = copy.deepcopy(cfg)
+        # When in the collect phase, sometimes we need to normalize the reward;
+        # reward_normalize is determined by the config.
+        cfg.is_collect = True
+        return [cfg for _ in range(collector_env_num)]
+
+    @staticmethod
+    def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+        evaluator_env_num = cfg.pop('evaluator_env_num')
+        cfg = copy.deepcopy(cfg)
+        # When in the evaluate phase, we don't need to normalize the reward.
+        cfg.reward_normalize = False
+        cfg.is_collect = False
+        return [cfg for _ in range(evaluator_env_num)]
+
+
+if __name__ == '__main__':
+    from easydict import EasyDict
+    env_cfg = EasyDict(
+        dict(
+            max_steps=100,
+            # game_path="z-machine-games-master/jericho-game-suite/zork1.z5",
+            game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/detective.z5",
+            # game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/905.z5",
+            max_action_num=50,
+            max_env_step=100,
+            tokenizer_path="google-bert/bert-base-uncased",
+            max_seq_len=512
+        )
+    )
+    env = JerichoEnv(env_cfg)
+    obs = env.reset(return_str=True)
+    print(f'[OBS]:\n{obs["observation"]}')
+    while True:
+        action_id = int(input('Please input the action id:'))
+        obs, reward, done, info = env.step(action_id, return_str=True)
+        print(f'[OBS]:\n{obs["observation"]}')
+        if done:
+            break
diff --git a/zoo/jericho/envs/test_jericho_env.py b/zoo/jericho/envs/test_jericho_env.py
new file mode 100644
index 000000000..28db93b53
--- /dev/null
+++ b/zoo/jericho/envs/test_jericho_env.py
@@ -0,0 +1,41 @@
+from easydict import EasyDict
+from .jericho_env import JerichoEnv
+import numpy as np
+import pytest
+
+
+@pytest.mark.unittest
+class TestJerichoEnv():
+    def setup(self) -> None:
+        # Configuration for the Jericho environment
+        cfg = EasyDict(
+            dict(
+                game_path="z-machine-games-master/jericho-game-suite/zork1.z5",
+                max_steps=100,  # required by JerichoEnv.__init__ (cfg.max_steps)
+                max_action_num=50,
+                tokenizer_path="google-bert/bert-base-uncased",
+                max_seq_len=512
+            )
+        )
+        # Create a Jericho environment that will be used in the following tests.
+ self.env = JerichoEnv(cfg) + + # Test the initialization of the Jericho environment. + def test_initialization(self): + assert isinstance(self.env, JerichoEnv) + + # Test the reset method of the Jericho environment. + # Ensure that the shape of the observation is as expected. + def test_reset(self): + obs = self.env.reset() + assert obs['observation'].shape == (512,) + + # Test the step method of the Jericho environment. + # Ensure that the shape of the observation, the type of the reward, + # the type of the done flag and the type of the info are as expected. + def test_step_shape(self): + self.env.reset() + obs, reward, done, info = self.env.step(1) + assert obs['observation'].shape == (512,) + assert isinstance(reward, np.ndarray) + assert isinstance(done, bool) + assert isinstance(info, dict)
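Editor's note on the observation contract exercised by these tests: the Jericho configs set observation_shape=512 and action_space_size=50 because JerichoEnv.prepare_obs always emits a fixed-length vector of token ids (padded/truncated to max_seq_len) plus an action mask of length max_action_num, which is what lets the text observation reuse the existing 'mlp' model path. The standalone sketch below illustrates that contract outside the env. It is only a sketch: it assumes the transformers package and the google-bert/bert-base-uncased tokenizer are available locally, the helper name build_jericho_obs is illustrative rather than part of this patch, and (unlike the env) it simply truncates the action mask when more than 50 actions are valid.

import numpy as np
from transformers import AutoTokenizer

MAX_SEQ_LEN = 512     # must match env.observation_shape / env.max_seq_len in the configs
MAX_ACTION_NUM = 50   # must match env.max_action_num / policy action_space_size


def build_jericho_obs(scene_text, valid_actions, tokenizer):
    # Same convention as JerichoEnv.prepare_obs: append the valid-action list to the scene text,
    # then tokenize to a fixed length so the observation shape is always (MAX_SEQ_LEN,).
    full_obs = scene_text + "\nValid actions: " + str(valid_actions)
    enc = tokenizer([full_obs], truncation=True, padding="max_length", max_length=MAX_SEQ_LEN)
    obs_ids = np.array(enc['input_ids'][0], dtype=np.int32)
    obs_attn_mask = np.array(enc['attention_mask'][0], dtype=np.int8)
    # Mark the first len(valid_actions) slots as legal; the remaining slots stay masked out.
    action_mask = np.zeros(MAX_ACTION_NUM, dtype=np.int8)
    action_mask[:min(len(valid_actions), MAX_ACTION_NUM)] = 1
    return {'observation': obs_ids, 'obs_attn_mask': obs_attn_mask,
            'action_mask': action_mask, 'to_play': -1}


if __name__ == '__main__':
    tok = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
    obs = build_jericho_obs('You are standing in an open field west of a white house.',
                            ['open mailbox', 'go north'], tok)
    assert obs['observation'].shape == (MAX_SEQ_LEN,)
    assert obs['action_mask'].sum() == 2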