diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html
similarity index 100%
rename from docs/source/_templates/layout.html
rename to docs/source/_templates/layout.html
diff --git a/lzero/entry/train_unizero.py b/lzero/entry/train_unizero.py
index cd7ff7605..edd67d6f8 100644
--- a/lzero/entry/train_unizero.py
+++ b/lzero/entry/train_unizero.py
@@ -21,6 +21,8 @@
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroCollector as Collector
from .utils import random_collect
+import torch.distributed as dist
+from ding.utils import set_pkg_seed, get_rank, get_world_size
def train_unizero(
@@ -52,33 +54,43 @@ def train_unizero(
cfg, create_cfg = input_cfg
- # Ensure the specified policy type is supported
- assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'"
+    logging.info("===== Starting UniZero training =====")
+
+    # Ensure the specified policy type is supported
+    assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero only supports the following algorithms: 'unizero', 'sampled_unizero'"
+    logging.info(f"Policy type in use: {create_cfg.policy.type}")
- # Import the correct GameBuffer class based on the policy type
+    # Import the GameBuffer class corresponding to the policy type
game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
-
GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
game_buffer_classes[create_cfg.policy.type])
- # Set device based on CUDA availability
+    # Set the device according to CUDA availability
cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
- logging.info(f'cfg.policy.device: {cfg.policy.device}')
+    logging.info(f"Device set to: {cfg.policy.device}")
- # Compile the configuration
+    # Compile the configuration
+    logging.info("Compiling configuration...")
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+    logging.info("Configuration compiled!")
- # Create main components: env, policy
+    # Create environment managers
+    logging.info("Creating environment managers...")
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+    logging.info("Environment managers created!")
+    # Initialize environments and random seeds
+    logging.info("Initializing environments and random seeds...")
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())
+    logging.info("Environments and random seeds initialized!")
+    # Initialize wandb if enabled
if cfg.policy.use_wandb:
- # Initialize wandb
+        logging.info("Initializing wandb...")
wandb.init(
project="LightZero",
config=cfg,
@@ -86,72 +98,99 @@ def train_unizero(
monitor_gym=False,
save_code=True,
)
+        logging.info("wandb initialized!")
+    # Create the policy
+    logging.info("Creating policy...")
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+    logging.info("Policy created!")
- # Load pretrained model if specified
+    # Load the pretrained model if a model path is specified
if model_path is not None:
- logging.info(f'Loading model from {model_path} begin...')
+        logging.info(f"Loading pretrained model from {model_path}...")
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
- logging.info(f'Loading model from {model_path} end!')
+        logging.info("Pretrained model loaded!")
- # Create worker components: learner, collector, evaluator, replay buffer, commander
+    # Create the core training components
+    logging.info("Creating core training components...")
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
-
- # MCTS+RL algorithms related core code
- policy_config = cfg.policy
- replay_buffer = GameBuffer(policy_config)
+ replay_buffer = GameBuffer(cfg.policy)
collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
- policy_config=policy_config)
+ policy_config=cfg.policy)
evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
- tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)
+ tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=cfg.policy)
+    logging.info("Core training components created!")
- # Learner's before_run hook
+    # Learner's before_run hook
+    logging.info("Executing the learner's before_run hook...")
learner.call_hook('before_run')
- if policy_config.use_wandb:
+    logging.info("Learner's before_run hook executed!")
+
+ if cfg.policy.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
- # Collect random data before training
+    # Collect random data before training
if cfg.policy.random_collect_episode_num > 0:
+        logging.info("Collecting random data...")
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+        logging.info("Random data collection finished!")
batch_size = policy._cfg.batch_size
+ if cfg.policy.multi_gpu:
+        # Get the current world_size and rank
+ world_size = get_world_size()
+ rank = get_rank()
+ else:
+ world_size = 1
+ rank = 0
+
while True:
- # Log buffer memory usage
+ # torch.cuda.empty_cache()
+
+        # Log the replay buffer's memory usage
+        # logging.info(f"Train iter {learner.train_iter}: logging replay buffer memory usage...")
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+        # logging.info(f"Train iter {learner.train_iter}: memory usage logged!")
- # Set temperature for visit count distributions
+        # Set the temperature for visit count distributions
collect_kwargs = {
'temperature': visit_count_temperature(
- policy_config.manual_temperature_decay,
- policy_config.fixed_temperature_value,
- policy_config.threshold_training_steps_for_final_temperature,
+ cfg.policy.manual_temperature_decay,
+ cfg.policy.fixed_temperature_value,
+ cfg.policy.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
),
- 'epsilon': 0.0 # Default epsilon value
+            'epsilon': 0.0  # Default epsilon value
}
+        # logging.info(f"Train iter {learner.train_iter}: temperature set to {collect_kwargs['temperature']}")
- # Configure epsilon for epsilon-greedy exploration
- if policy_config.eps.eps_greedy_exploration_in_collect:
+        # Configure epsilon-greedy exploration
+ if cfg.policy.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
- start=policy_config.eps.start,
- end=policy_config.eps.end,
- decay=policy_config.eps.decay,
- type_=policy_config.eps.type
+ start=cfg.policy.eps.start,
+ end=cfg.policy.eps.end,
+ decay=cfg.policy.eps.decay,
+ type_=cfg.policy.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+            # logging.info(f"Train iter {learner.train_iter}: epsilon set to {collect_kwargs['epsilon']}")
- # Evaluate policy performance
+        # Evaluate the policy's performance
if evaluator.should_eval(learner.train_iter):
+            logging.info(f"Train iter {learner.train_iter}: starting evaluation...")
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            logging.info(f"Train iter {learner.train_iter}: evaluation finished, stop: {stop}, current reward: {reward}")
if stop:
+                logging.info("Stopping criterion met, training finished!")
break
- # Collect new data
+        # Collect new data
+        # logging.info(f"Rank {rank}, train iter {learner.train_iter}: collecting new data...")
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+        logging.info(f"Rank {rank}, train iter {learner.train_iter}: new data collected!")
# Determine updates per collection
update_per_collect = cfg.policy.update_per_collect
@@ -162,44 +201,60 @@ def train_unizero(
collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
- # Update replay buffer
+        # Update the replay buffer
+        # logging.info(f"Train iter {learner.train_iter}: updating replay buffer...")
replay_buffer.push_game_segments(new_data)
replay_buffer.remove_oldest_data_to_fit()
+        # logging.info(f"Train iter {learner.train_iter}: replay buffer updated!")
+
+ if world_size > 1:
+            # Synchronize all ranks before entering the training stage
+ try:
+ dist.barrier()
+                # logging.info(f'Rank {rank}: passed the pre-training synchronization barrier')
+ except Exception as e:
+                logging.error(f'Rank {rank}: synchronization barrier failed with error: {e}')
+ break
- # Train the policy if sufficient data is available
+        # Train the policy only if sufficient data is available
if collector.envstep > cfg.policy.train_start_after_envsteps:
if cfg.policy.sample_type == 'episode':
data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
else:
data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
+
if not data_sufficient:
- logging.warning(
- f'The data in replay_buffer is not sufficient to sample a mini-batch: '
- f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
- )
+                # NOTE: In DDP training, some ranks may have insufficient data in their replay buffer and skip the training stage, which causes communication timeouts; all ranks must enter the training stage together.
+                logging.warning(f"Rank {rank}: train iter {learner.train_iter}: insufficient data in replay buffer, continuing to collect...")
continue
+            logging.info(f"Rank {rank}, train iter {learner.train_iter}: starting training!")
+
+            # Perform multiple training rounds
for i in range(update_per_collect):
train_data = replay_buffer.sample(batch_size, policy)
if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
- # Clear caches and precompute positional embedding matrices
- policy.recompute_pos_emb_diff_and_clear_cache() # TODO
-
- if policy_config.use_wandb:
+ policy.recompute_pos_emb_diff_and_clear_cache()
+
+ if cfg.policy.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
train_data.append({'train_which_component': 'transformer'})
log_vars = learner.train(train_data, collector.envstep)
-
if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+            logging.info(f"Rank {rank}, train iter {learner.train_iter}: training finished!")
policy.recompute_pos_emb_diff_and_clear_cache()
- # Check stopping criteria
+        # Check stopping criteria
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+            logging.info("Stopping criterion met, training finished!")
break
learner.call_hook('after_run')
- wandb.finish()
- return policy
+ if cfg.policy.use_wandb:
+ wandb.finish()
+    logging.info("===== Training finished =====")
+ return policy
\ No newline at end of file
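
The multi-GPU branch added above gates training behind a `torch.distributed` barrier so that all ranks enter the training stage together. Below is a minimal sketch of that synchronization pattern, assuming the default process group has already been initialized by the launcher; the helper name `barrier_or_abort` is ours and not part of the diff.

import logging
import torch.distributed as dist

def barrier_or_abort(rank: int) -> bool:
    """Block until every rank reaches this point; return False if the barrier fails."""
    if not (dist.is_available() and dist.is_initialized()):
        return True  # single-process run, nothing to synchronize
    try:
        dist.barrier()
        return True
    except Exception as e:  # e.g. a NCCL timeout because one rank fell behind
        logging.error(f"Rank {rank}: synchronization barrier failed with error: {e}")
        return False

In the loop above, a False return corresponds to the `break` that aborts training on the failing rank.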
diff --git a/lzero/entry/train_unizero_bkp.py b/lzero/entry/train_unizero_bkp.py
new file mode 100644
index 000000000..75ee9bcfb
--- /dev/null
+++ b/lzero/entry/train_unizero_bkp.py
@@ -0,0 +1,206 @@
+import logging
+import os
+from functools import partial
+from typing import Tuple, Optional
+
+import torch
+import wandb
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import set_pkg_seed, get_rank
+from ding.worker import BaseLearner
+from tensorboardX import SummaryWriter
+from torch.utils.tensorboard import SummaryWriter
+
+from lzero.entry.utils import log_buffer_memory_usage
+from lzero.policy import visit_count_temperature
+from lzero.policy.random_policy import LightZeroRandomPolicy
+from lzero.worker import MuZeroEvaluator as Evaluator
+from lzero.worker import MuZeroCollector as Collector
+from .utils import random_collect
+
+
+def train_unizero(
+ input_cfg: Tuple[dict, dict],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ model_path: Optional[str] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy':
+ """
+ Overview:
+ The train entry for UniZero, proposed in our paper UniZero: Generalized and Efficient Planning with Scalable Latent World Models.
+ UniZero aims to enhance the planning capabilities of reinforcement learning agents by addressing the limitations found in MuZero-style algorithms,
+ particularly in environments requiring the capture of long-term dependencies. More details can be found in https://arxiv.org/abs/2406.10667.
+ Arguments:
+ - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - model_path (:obj:`Optional[str]`): The pretrained model path, which should
+ point to the ckpt file of the pretrained model, and an absolute path is recommended.
+ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+
+ cfg, create_cfg = input_cfg
+
+ # Ensure the specified policy type is supported
+ assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'"
+
+ # Import the correct GameBuffer class based on the policy type
+ game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}
+
+ GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
+ game_buffer_classes[create_cfg.policy.type])
+
+ # Set device based on CUDA availability
+ cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
+ logging.info(f'cfg.policy.device: {cfg.policy.device}')
+
+ # Compile the configuration
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())
+
+ if cfg.policy.use_wandb:
+ # Initialize wandb
+ wandb.init(
+ project="LightZero",
+ config=cfg,
+ sync_tensorboard=False,
+ monitor_gym=False,
+ save_code=True,
+ )
+
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+
+ # Load pretrained model if specified
+ if model_path is not None:
+ logging.info(f'Loading model from {model_path} begin...')
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+ logging.info(f'Loading model from {model_path} end!')
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+
+ # MCTS+RL algorithms related core code
+ policy_config = cfg.policy
+ replay_buffer = GameBuffer(policy_config)
+ collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
+ policy_config=policy_config)
+ evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
+ stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
+ tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)
+
+ # Learner's before_run hook
+ learner.call_hook('before_run')
+ if policy_config.use_wandb:
+ policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
+
+ # Collect random data before training
+ if cfg.policy.random_collect_episode_num > 0:
+ random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+
+ batch_size = policy._cfg.batch_size
+
+ while True:
+ # Log buffer memory usage
+ log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+
+ # Set temperature for visit count distributions
+ collect_kwargs = {
+ 'temperature': visit_count_temperature(
+ policy_config.manual_temperature_decay,
+ policy_config.fixed_temperature_value,
+ policy_config.threshold_training_steps_for_final_temperature,
+ trained_steps=learner.train_iter
+ ),
+ 'epsilon': 0.0 # Default epsilon value
+ }
+
+ # Configure epsilon for epsilon-greedy exploration
+ if policy_config.eps.eps_greedy_exploration_in_collect:
+ epsilon_greedy_fn = get_epsilon_greedy_fn(
+ start=policy_config.eps.start,
+ end=policy_config.eps.end,
+ decay=policy_config.eps.decay,
+ type_=policy_config.eps.type
+ )
+ collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+
+ # Evaluate policy performance
+ # if learner.train_iter == 0 or evaluator.should_eval(learner.train_iter):
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+
+ # Collect new data
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+
+ # Determine updates per collection
+ update_per_collect = cfg.policy.update_per_collect
+ if update_per_collect is None:
+ # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the replay_ratio.
+ # The length of game_segment (i.e., len(game_segment.action_segment)) can be smaller than cfg.policy.game_segment_length if it represents the final segment of the game.
+ # On the other hand, its length will be less than cfg.policy.game_segment_length + padding_length when it is not the last game segment. Typically, padding_length is the sum of unroll_steps and td_steps.
+ collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
+ update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
+
+ # Update replay buffer
+ replay_buffer.push_game_segments(new_data)
+ replay_buffer.remove_oldest_data_to_fit()
+
+ # Train the policy if sufficient data is available
+ if collector.envstep > cfg.policy.train_start_after_envsteps:
+ if cfg.policy.sample_type == 'episode':
+ data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
+ else:
+ data_sufficient = replay_buffer.get_num_of_transitions() > batch_size
+ if not data_sufficient:
+ logging.warning(
+ f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+ f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
+ )
+ continue
+
+ for i in range(update_per_collect):
+ train_data = replay_buffer.sample(batch_size, policy)
+ if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
+ # Clear caches and precompute positional embedding matrices
+ policy.recompute_pos_emb_diff_and_clear_cache() # TODO
+
+ if policy_config.use_wandb:
+ policy.set_train_iter_env_step(learner.train_iter, collector.envstep)
+
+ train_data.append({'train_which_component': 'transformer'})
+ log_vars = learner.train(train_data, collector.envstep)
+
+ if cfg.policy.use_priority:
+ replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+ policy.recompute_pos_emb_diff_and_clear_cache()
+
+ # Check stopping criteria
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ learner.call_hook('after_run')
+ wandb.finish()
+ return policy
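
When `update_per_collect` is left as `None`, both entries derive it from the number of collected transitions and `replay_ratio`. A small worked example of that arithmetic; all numbers below are made up for illustration.

game_segment_length = 400
replay_ratio = 0.25
segment_lengths = [400, 400, 137]   # the last segment ends the episode early

# Count at most game_segment_length transitions per segment, as in the entries above.
collected_transitions_num = sum(min(l, game_segment_length) for l in segment_lengths)  # 937
update_per_collect = int(collected_transitions_num * replay_ratio)                     # 234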
diff --git a/lzero/mcts/buffer/game_buffer.py b/lzero/mcts/buffer/game_buffer.py
index 2678066e9..db548049a 100644
--- a/lzero/mcts/buffer/game_buffer.py
+++ b/lzero/mcts/buffer/game_buffer.py
@@ -401,8 +401,13 @@ def _preprocess_to_play_and_action_mask(
unroll_steps + 1]
)
if len(action_mask_tmp) < unroll_steps + 1:
+ # action_mask_tmp += [
+ # list(np.ones(self._cfg.model.action_space_size, dtype=np.int8))
+ # for _ in range(unroll_steps + 1 - len(action_mask_tmp))
+ # ]
+ # TODO: padded data
action_mask_tmp += [
- list(np.ones(self._cfg.model.action_space_size, dtype=np.int8))
+ list(np.zeros(self._cfg.model.action_space_size, dtype=np.int8))
for _ in range(unroll_steps + 1 - len(action_mask_tmp))
]
action_mask.append(action_mask_tmp)
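
The change above pads short `action_mask` sequences with all-zero masks instead of all-one masks, so padded positions expose no legal actions. A small illustration of the resulting masks (values are made up):

import numpy as np

action_space_size = 4
unroll_steps = 5
action_mask_tmp = [[1, 1, 0, 1], [1, 0, 1, 1]]   # masks observed inside the segment

# Pad to unroll_steps + 1 entries with all-zero masks, mirroring the new behaviour.
action_mask_tmp += [
    list(np.zeros(action_space_size, dtype=np.int8))
    for _ in range(unroll_steps + 1 - len(action_mask_tmp))
]

legal_actions = [[i for i, x in enumerate(mask) if x == 1] for mask in action_mask_tmp]
print(legal_actions)   # [[0, 1, 3], [0, 2, 3], [], [], [], []]

Downstream code that builds target policies from these legal-action lists therefore has to tolerate empty lists at padded steps, which is what the try/except added in game_buffer_muzero.py below guards against.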
diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py
index 2ba8180de..aa923c56b 100644
--- a/lzero/mcts/buffer/game_buffer_muzero.py
+++ b/lzero/mcts/buffer/game_buffer_muzero.py
@@ -711,7 +711,9 @@ def _compute_target_policy_non_reanalyzed(
]
else:
legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
-
+ # print(f'='*20)
+ # print(f'buffer_muzero: action_mask:{action_mask}')
+
with torch.no_grad():
policy_index = 0
# 0 -> Invalid target policy for padding outside of game segments,
@@ -730,11 +732,17 @@ def _compute_target_policy_non_reanalyzed(
# for atari/classic_control/box2d environments that only have one player.
target_policies.append(distributions)
else:
- # for board games that have two players.
+ # for board games that have two players or envs that have varied action space.
policy_tmp = [0 for _ in range(policy_shape)]
for index, legal_action in enumerate(legal_actions[policy_index]):
# only the action in ``legal_action`` the policy logits is nonzero
- policy_tmp[legal_action] = distributions[index]
+ # policy_tmp[legal_action] = distributions[index]
+ try:
+ policy_tmp[legal_action] = distributions[index]
+ except Exception as e:
+ print('='*20)
+ print(f'Exception:{e}, distributions:{distributions}, legal_action:{legal_actions[policy_index]}')
+                                        # TODO: This happens because the tail of a sampled sequence may be padded, and the padded action_mask is filled with np.zeros(self._cfg.model.action_space_size, dtype=np.int8).
target_policies.append(policy_tmp)
else:
# NOTE: the invalid padding target policy, O is to make sure the corresponding cross_entropy_loss=0
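
The try/except above tolerates padded steps whose zero action masks leave the legal-action list empty or mismatched with the visit-count distribution. A sketch of the same guard written as an explicit check; the values are illustrative and this is not the repository code.

policy_shape = 4
distributions = [0.7, 0.3]        # visit-count distribution over legal actions
legal_action_list = [0, 2]        # may be [] for zero-padded steps

policy_tmp = [0 for _ in range(policy_shape)]
if len(legal_action_list) == len(distributions):
    for index, legal_action in enumerate(legal_action_list):
        policy_tmp[legal_action] = distributions[index]
# otherwise leave the all-zero target so its cross-entropy loss stays 0
print(policy_tmp)                 # [0.7, 0, 0.3, 0]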
diff --git a/lzero/model/common.py b/lzero/model/common.py
index 22afa95fe..f2ed32798 100644
--- a/lzero/model/common.py
+++ b/lzero/model/common.py
@@ -274,6 +274,86 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return output
+"""
+使用 google-bert/bert-base-uncased , 模型的输入为id
+"""
+class HFLanguageRepresentationNetwork(nn.Module):
+ def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8):
+        """
+        Initialize the language representation network.
+
+        Arguments:
+        - url (str): The identifier of the pretrained Hugging Face model, defaults to 'google-bert/bert-base-uncased'.
+        - embedding_size (int): The dimension of the output embedding, defaults to 768.
+        """
+ super().__init__()
+ from transformers import AutoModel
+        # Load the pretrained Hugging Face model
+ self.model = AutoModel.from_pretrained(url)
+
+        # Set the embedding dimension; if the target dimension is not 768, add a linear layer for projection
+ self.embedding_size = embedding_size
+ if self.embedding_size != 768:
+ self.embed_head = nn.Linear(768, self.embedding_size)
+
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
+        """
+        Forward pass that obtains the language representation of the input sequence.
+
+        Arguments:
+        - x (torch.Tensor): Input tensor, usually the token indices of the sequence, with shape [batch_size, seq_len].
+        - no_grad (bool): Whether to run in no-gradient mode, defaults to True.
+
+        Returns:
+        - torch.Tensor: The processed language embedding with shape [batch_size, embedding_size].
+        """
+ if no_grad:
+            # Disable gradient computation in no_grad mode to save GPU memory
+            with torch.no_grad():
+                x = x.long()  # ensure the input tensor is of long (integer) type
+                outputs = self.model(x)  # get the model output
+
+                # The model's last_hidden_state has shape [batch_size, seq_len, hidden_size]
+                # We usually take the vector corresponding to the [CLS] token, i.e. outputs.last_hidden_state[:, 0, :]
+ cls_embedding = outputs.last_hidden_state[:, 0, :]
+
+                # If the target embedding_size is not 768, apply the linear projection
+ if self.embedding_size == 768:
+ # NOTE: very important for training stability.
+ cls_embedding = self.sim_norm(cls_embedding)
+
+ return cls_embedding
+ else:
+ cls_embedding = self.embed_head(cls_embedding)
+
+ # NOTE: very important for training stability.
+ cls_embedding = self.sim_norm(cls_embedding)
+
+ return cls_embedding
+ else:
+            # Outside no_grad mode, gradients are enabled
+            x = x.long()  # ensure the input tensor is of long (integer) type
+ outputs = self.model(x)
+ cls_embedding = outputs.last_hidden_state[:, 0, :]
+
+            # If the target embedding_size is not 768, apply the linear projection
+ if self.embedding_size == 768:
+ # NOTE: very important for training stability.
+ cls_embedding = self.sim_norm(cls_embedding)
+
+ return cls_embedding
+ else:
+ cls_embedding = self.embed_head(cls_embedding)
+
+ # NOTE: very important for training stability.
+ cls_embedding = self.sim_norm(cls_embedding)
+
+ return cls_embedding
+
+
+
class RepresentationNetworkUniZero(nn.Module):
def __init__(
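
A usage sketch for the `HFLanguageRepresentationNetwork` added above: tokenize a sentence with the matching Hugging Face tokenizer and feed the token ids to the network. The import path assumes the class lives in `lzero.model.common` as in this diff; the example sentence is arbitrary.

import torch
from transformers import AutoTokenizer
from lzero.model.common import HFLanguageRepresentationNetwork

tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
net = HFLanguageRepresentationNetwork(url='google-bert/bert-base-uncased', embedding_size=768)

tokens = tokenizer(["go to the red door"], return_tensors='pt', padding=True)
with torch.no_grad():
    embedding = net(tokens['input_ids'])   # shape: [1, 768], SimNorm-normalized

print(embedding.shape)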
diff --git a/lzero/model/common_noserve_input_id.py b/lzero/model/common_noserve_input_id.py
new file mode 100644
index 000000000..f2ed32798
--- /dev/null
+++ b/lzero/model/common_noserve_input_id.py
@@ -0,0 +1,1211 @@
+"""
+Overview:
+ In this Python file, we provide a collection of reusable model templates designed to streamline the development
+ process for various custom algorithms. By utilizing these pre-built model templates, users can quickly adapt and
+ customize their custom algorithms, ensuring efficient and effective development.
+ BTW, users can refer to the unittest of these model templates to learn how to use them.
+"""
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from ding.torch_utils import MLP, ResBlock
+from ding.utils import SequenceType
+from ditk import logging
+
+
+# use dataclass to make the output of network more convenient to use
+@dataclass
+class MZRNNNetworkOutput:
+ # output format of the MuZeroRNN model
+ value: torch.Tensor
+ value_prefix: torch.Tensor
+ policy_logits: torch.Tensor
+ latent_state: torch.Tensor
+ predict_next_latent_state: torch.Tensor
+ reward_hidden_state: Tuple[torch.Tensor]
+
+
+@dataclass
+class EZNetworkOutput:
+ # output format of the EfficientZero model
+ value: torch.Tensor
+ value_prefix: torch.Tensor
+ policy_logits: torch.Tensor
+ latent_state: torch.Tensor
+ reward_hidden_state: Tuple[torch.Tensor]
+
+
+@dataclass
+class MZNetworkOutput:
+ # output format of the MuZero model
+ value: torch.Tensor
+ reward: torch.Tensor
+ policy_logits: torch.Tensor
+ latent_state: torch.Tensor
+
+
+class SimNorm(nn.Module):
+
+ def __init__(self, simnorm_dim: int) -> None:
+ """
+ Overview:
+ Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616.
+ Arguments:
+ - simnorm_dim (:obj:`int`): The dimension for simplicial normalization.
+ """
+ super().__init__()
+ self.dim = simnorm_dim
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Forward pass of the SimNorm layer.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor to normalize.
+ Returns:
+ - x (:obj:`torch.Tensor`): The normalized tensor.
+ """
+ shp = x.shape
+ # Ensure that there is at least one simplex to normalize across.
+ if shp[1] != 0:
+ x = x.view(*shp[:-1], -1, self.dim)
+ x = F.softmax(x, dim=-1)
+ return x.view(*shp)
+ else:
+ return x
+
+ def __repr__(self) -> str:
+ """
+ Overview:
+ String representation of the SimNorm layer.
+ Returns:
+ - output (:obj:`str`): The string representation.
+ """
+ return f"SimNorm(dim={self.dim})"
+
+
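
A small worked example of SimNorm (illustrative shapes, assuming SimNorm is importable from lzero.model.common): softmax is applied within each group of `simnorm_dim` consecutive features, so every group sums to 1.

import torch
from lzero.model.common import SimNorm

sim_norm = SimNorm(simnorm_dim=4)
x = torch.randn(2, 8)                  # batch of 2, feature dim 8 -> 2 groups of 4
y = sim_norm(x)
print(y.shape)                         # torch.Size([2, 8])
print(y.view(2, 2, 4).sum(dim=-1))     # each group sums to 1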
+def AvgL1Norm(x, eps=1e-8):
+ """
+ Overview:
+ Normalize the input tensor by the L1 norm.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor to normalize.
+ - eps (:obj:`float`): The epsilon value to prevent division by zero.
+ Returns:
+ - :obj:`torch.Tensor`: The normalized tensor.
+ """
+ return x / x.abs().mean(-1, keepdim=True).clamp(min=eps)
+
+
+class FeatureAndGradientHook:
+
+ def __init__(self):
+ """
+ Overview:
+ Class to capture features and gradients at SimNorm.
+ """
+ self.features_before = []
+ self.features_after = []
+ self.grads_before = []
+ self.grads_after = []
+
+ def setup_hooks(self, model):
+ # Hooks to capture features and gradients at SimNorm
+ self.forward_handler = model.sim_norm.register_forward_hook(self.forward_hook)
+ self.backward_handler = model.sim_norm.register_full_backward_hook(self.backward_hook)
+
+ def forward_hook(self, module, input, output):
+ with torch.no_grad():
+ self.features_before.append(input[0])
+ self.features_after.append(output)
+
+ def backward_hook(self, module, grad_input, grad_output):
+ with torch.no_grad():
+ self.grads_before.append(grad_input[0] if grad_input[0] is not None else None)
+ self.grads_after.append(grad_output[0] if grad_output[0] is not None else None)
+
+ def analyze(self):
+ # Calculate L2 norms of features
+ l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_before]))
+ l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_after]))
+
+ # Calculate norms of gradients
+ grad_norm_before = torch.mean(
+ torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before if g is not None]))
+ grad_norm_after = torch.mean(
+ torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after if g is not None]))
+
+ # Clear stored data and delete tensors to free memory
+ self.clear_data()
+
+ # Optionally clear CUDA cache
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after
+
+ def clear_data(self):
+ del self.features_before[:]
+ del self.features_after[:]
+ del self.grads_before[:]
+ del self.grads_after[:]
+
+ def remove_hooks(self):
+ self.forward_handler.remove()
+ self.backward_handler.remove()
+
+
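
A usage sketch for FeatureAndGradientHook (illustrative, assuming the classes are importable from lzero.model.common): the only requirement is that the monitored model exposes a `sim_norm` submodule.

import torch
import torch.nn as nn
from lzero.model.common import FeatureAndGradientHook, SimNorm

class TinyEncoder(nn.Module):
    """Toy module whose output passes through SimNorm, just to exercise the hook."""
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 8)
        self.sim_norm = SimNorm(simnorm_dim=4)

    def forward(self, x):
        return self.sim_norm(self.fc(x))

model = TinyEncoder()
hook = FeatureAndGradientHook()
hook.setup_hooks(model)                      # registers hooks on model.sim_norm

loss = model(torch.randn(4, 16)).sum()
loss.backward()

l2_before, l2_after, grad_before, grad_after = hook.analyze()
hook.remove_hooks()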
+class DownSample(nn.Module):
+
+ def __init__(self, observation_shape: SequenceType, out_channels: int,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ num_resblocks: int = 1,
+ ) -> None:
+ """
+ Overview:
+ Define downSample convolution network. Encode the observation into hidden state.
+ This network is often used in video games like Atari. In board games like go and chess,
+ we don't need this module.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[12, 96, 96]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - out_channels (:obj:`int`): The output channels of output hidden state.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \
+ Use the inplace operation to speed up.
+ - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'.
+ - num_resblocks (:obj:`int`): The number of residual blocks. Defaults to 1.
+ """
+ super().__init__()
+ assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']"
+
+ self.observation_shape = observation_shape
+ self.conv1 = nn.Conv2d(
+ observation_shape[0],
+ out_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False, # disable bias for better convergence
+ )
+ if norm_type == 'BN':
+ self.norm1 = nn.BatchNorm2d(out_channels // 2)
+ elif norm_type == 'LN':
+ self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2],
+ eps=1e-5)
+
+ self.resblocks1 = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=out_channels // 2,
+ activation=activation,
+ norm_type=norm_type,
+ res_type='basic',
+ bias=False
+ ) for _ in range(num_resblocks)
+ ]
+ )
+ self.downsample_block = ResBlock(
+ in_channels=out_channels // 2,
+ out_channels=out_channels,
+ activation=activation,
+ norm_type=norm_type,
+ res_type='downsample',
+ bias=False
+ )
+ self.resblocks2 = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_resblocks)
+ ]
+ )
+ self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+ self.resblocks3 = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(1)
+ ]
+ )
+ self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+ self.activation = activation
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \
+ H is height.
+ - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \
+ output width, H_ is output height.
+ """
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.activation(x)
+
+ for block in self.resblocks1:
+ x = block(x)
+ x = self.downsample_block(x)
+ for block in self.resblocks2:
+ x = block(x)
+ x = self.pooling1(x)
+ for block in self.resblocks3:
+ x = block(x)
+
+ # 64, 84, 96 are the most common observation shapes in Atari games.
+ if self.observation_shape[1] == 64:
+ output = x
+ elif self.observation_shape[1] == 84:
+ x = self.pooling2(x)
+ output = x
+ elif self.observation_shape[1] == 96:
+ x = self.pooling2(x)
+ output = x
+ else:
+ raise NotImplementedError(f"DownSample for observation shape {self.observation_shape} is not implemented now. "
+ f"You should transform the observation shape to 64 or 96 in the env.")
+
+ return output
+
+
+"""
+使用 google-bert/bert-base-uncased , 模型的输入为id
+"""
+class HFLanguageRepresentationNetwork(nn.Module):
+ def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768, group_size: int = 8):
+        """
+        Initialize the language representation network.
+
+        Arguments:
+        - url (str): The identifier of the pretrained Hugging Face model, defaults to 'google-bert/bert-base-uncased'.
+        - embedding_size (int): The dimension of the output embedding, defaults to 768.
+        """
+ super().__init__()
+ from transformers import AutoModel
+        # Load the pretrained Hugging Face model
+ self.model = AutoModel.from_pretrained(url)
+
+        # Set the embedding dimension; if the target dimension is not 768, add a linear layer for projection
+ self.embedding_size = embedding_size
+ if self.embedding_size != 768:
+ self.embed_head = nn.Linear(768, self.embedding_size)
+
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, x: torch.Tensor, no_grad: bool = True) -> torch.Tensor:
+        """
+        Forward pass that obtains the language representation of the input sequence.
+
+        Arguments:
+        - x (torch.Tensor): Input tensor, usually the token indices of the sequence, with shape [batch_size, seq_len].
+        - no_grad (bool): Whether to run in no-gradient mode, defaults to True.
+
+        Returns:
+        - torch.Tensor: The processed language embedding with shape [batch_size, embedding_size].
+        """
+ if no_grad:
+            # Disable gradient computation in no_grad mode to save GPU memory
+            with torch.no_grad():
+                x = x.long()  # ensure the input tensor is of long (integer) type
+                outputs = self.model(x)  # get the model output
+
+                # The model's last_hidden_state has shape [batch_size, seq_len, hidden_size]
+                # We usually take the vector corresponding to the [CLS] token, i.e. outputs.last_hidden_state[:, 0, :]
+ cls_embedding = outputs.last_hidden_state[:, 0, :]
+
+                # If the target embedding_size is not 768, apply the linear projection
+ if self.embedding_size == 768:
+ # NOTE: very important for training stability.
+ cls_embedding = self.sim_norm(cls_embedding)
+
+ return cls_embedding
+ else:
+ cls_embedding = self.embed_head(cls_embedding)
+
+ # NOTE: very important for training stability.
+ cls_embedding = self.sim_norm(cls_embedding)
+
+ return cls_embedding
+ else:
+            # Outside no_grad mode, gradients are enabled
+            x = x.long()  # ensure the input tensor is of long (integer) type
+ outputs = self.model(x)
+ cls_embedding = outputs.last_hidden_state[:, 0, :]
+
+            # If the target embedding_size is not 768, apply the linear projection
+ if self.embedding_size == 768:
+ # NOTE: very important for training stability.
+ cls_embedding = self.sim_norm(cls_embedding)
+
+ return cls_embedding
+ else:
+ cls_embedding = self.embed_head(cls_embedding)
+
+ # NOTE: very important for training stability.
+ cls_embedding = self.sim_norm(cls_embedding)
+
+ return cls_embedding
+
+
+
+class RepresentationNetworkUniZero(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType = (3, 64, 64),
+ num_res_blocks: int = 1,
+ num_channels: int = 64,
+ downsample: bool = True,
+ activation: nn.Module = nn.GELU(approximate='tanh'),
+ norm_type: str = 'BN',
+ embedding_dim: int = 256,
+ group_size: int = 8,
+ ) -> None:
+ """
+ Overview:
+ Representation network used in UniZero. Encode the 2D image obs into latent state.
+ Currently, the network only supports obs images with both a width and height of 64.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel.
+ - num_res_blocks (:obj:`int`): The number of residual blocks.
+ - num_channels (:obj:`int`): The channel of output hidden state.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
+ defaults to True. This option is often used in video games like Atari. In board games like go, \
+ we don't need this module.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \
+ Use the inplace operation to speed up.
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ - embedding_dim (:obj:`int`): The dimension of the latent state.
+ - group_size (:obj:`int`): The dimension for simplicial normalization.
+ """
+ super().__init__()
+ assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']"
+ logging.info(f"Using norm type: {norm_type}")
+ logging.info(f"Using activation type: {activation}")
+
+ self.observation_shape = observation_shape
+ self.downsample = downsample
+ if self.downsample:
+ self.downsample_net = DownSample(
+ observation_shape,
+ num_channels,
+ activation=activation,
+ norm_type=norm_type,
+ )
+ else:
+ self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
+
+ if norm_type == 'BN':
+ self.norm = nn.BatchNorm2d(num_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm = nn.LayerNorm(
+ [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
+ eps=1e-5)
+ else:
+ self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)
+
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+ self.activation = activation
+ self.embedding_dim = embedding_dim
+
+ if self.observation_shape[1] == 64:
+ self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False)
+
+ elif self.observation_shape[1] in [84, 96]:
+ self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)
+
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \
+ H is height.
+ - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \
+ output width, H_ is output height.
+ """
+ if self.downsample:
+ x = self.downsample_net(x)
+ else:
+ x = self.conv(x)
+ x = self.norm(x)
+ x = self.activation(x)
+ for block in self.resblocks:
+ x = block(x)
+
+ # Important: Transform the output feature plane to the latent state.
+ # For example, for an Atari feature plane of shape (64, 8, 8),
+ # flattening results in a size of 4096, which is then transformed to 768.
+ x = self.last_linear(x.view(x.size(0), -1))
+
+ x = x.view(-1, self.embedding_dim)
+
+ # NOTE: very important for training stability.
+ x = self.sim_norm(x)
+
+ return x
+
+
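
A usage sketch for RepresentationNetworkUniZero (illustrative shapes, assuming the class is importable from lzero.model.common): encode a batch of 64x64 RGB observations into 768-dimensional latent states.

import torch
from lzero.model.common import RepresentationNetworkUniZero

encoder = RepresentationNetworkUniZero(
    observation_shape=(3, 64, 64), num_channels=64, embedding_dim=768, group_size=8
)
obs = torch.randn(4, 3, 64, 64)        # batch of 4 observations
latent = encoder(obs)
print(latent.shape)                    # torch.Size([4, 768])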
+class RepresentationNetwork(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType = (4, 96, 96),
+ num_res_blocks: int = 1,
+ num_channels: int = 64,
+ downsample: bool = True,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: str = 'BN',
+ embedding_dim: int = 256,
+ group_size: int = 8,
+ use_sim_norm: bool = False,
+ ) -> None:
+ """
+ Overview:
+ Representation network used in MuZero and derived algorithms. Encode the 2D image obs into latent state.
+ Currently, the network only supports obs images with both a width and height of 96.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[4, 96, 96]
+ for video games like atari, 1 gray channel times stack 4 frames.
+ - num_res_blocks (:obj:`int`): The number of residual blocks.
+ - num_channels (:obj:`int`): The channel of output hidden state.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
+ defaults to True. This option is often used in video games like Atari. In board games like go, \
+ we don't need this module.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \
+ Use the inplace operation to speed up.
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ - embedding_dim (:obj:`int`): The dimension of the output hidden state.
+ - group_size (:obj:`int`): The size of group in the SimNorm layer.
+ - use_sim_norm (:obj:`bool`): Whether to use SimNorm layer, defaults to False.
+ """
+ super().__init__()
+ assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']"
+
+ self.downsample = downsample
+ if self.downsample:
+ self.downsample_net = DownSample(
+ observation_shape,
+ num_channels,
+ activation=activation,
+ norm_type=norm_type,
+ )
+ else:
+ self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
+
+ if norm_type == 'BN':
+ self.norm = nn.BatchNorm2d(num_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm = nn.LayerNorm(
+ [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
+ eps=1e-5)
+ else:
+ self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)
+
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+ self.activation = activation
+
+ self.use_sim_norm = use_sim_norm
+
+ if self.use_sim_norm:
+ self.embedding_dim = embedding_dim
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \
+ H is height.
+ - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \
+ output width, H_ is output height.
+ """
+ if self.downsample:
+ x = self.downsample_net(x)
+ else:
+ x = self.conv(x)
+ x = self.norm(x)
+ x = self.activation(x)
+
+ for block in self.resblocks:
+ x = block(x)
+
+ if self.use_sim_norm:
+ # NOTE: very important.
+ # for atari 64,8,8 = 4096 -> 768
+ x = self.sim_norm(x)
+
+ return x
+
+
+class RepresentationNetworkMLP(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: int,
+ hidden_channels: int = 64,
+ layer_num: int = 2,
+ activation: nn.Module = nn.GELU(approximate='tanh'),
+ norm_type: Optional[str] = 'BN',
+ group_size: int = 8,
+ ) -> torch.Tensor:
+ """
+ Overview:
+ Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \
+ with Multi-Layer Perceptron (MLP).
+ Arguments:
+ - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10.
+ - num_res_blocks (:obj:`int`): The number of residual blocks.
+ - hidden_channels (:obj:`int`): The channel of output hidden state.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
+ defaults to True. This option is often used in video games like Atari. In board games like go, \
+ we don't need this module.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \
+ Use the inplace operation to speed up.
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super().__init__()
+ self.fc_representation = MLP(
+ in_channels=observation_shape,
+ hidden_channels=hidden_channels,
+ out_channels=hidden_channels,
+ layer_num=layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ # don't use activation and norm in the last layer of representation network is important for convergence.
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=True,
+ )
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation.
+ - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
+ """
+ x = self.fc_representation(x)
+ # TODO
+ x = self.sim_norm(x)
+ return x
+
+
+class LatentDecoder(nn.Module):
+
+ def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh')):
+ """
+ Overview:
+ Decoder network used in UniZero. Decode the latent state into 2D image obs.
+ Arguments:
+ - embedding_dim (:obj:`int`): The dimension of the latent state.
+ - output_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - num_channels (:obj:`int`): The channel of output hidden state.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh').
+ """
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.output_shape = output_shape # (C, H, W)
+ self.num_channels = num_channels
+ self.activation = activation
+
+ # Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256
+ # We will reverse the process of the representation network
+ self.initial_size = (
+ num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder
+ self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size))
+
+ # Upsampling blocks
+ self.conv_blocks = nn.ModuleList([
+ # Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4)
+ nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1),
+ self.activation,
+ nn.BatchNorm2d(num_channels // 2),
+ # Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2)
+ nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1,
+ output_padding=1),
+ self.activation,
+ nn.BatchNorm2d(num_channels // 4),
+ # Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W)
+ nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1,
+ output_padding=1),
+ ])
+ # TODO: last layer use sigmoid?
+
+ def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
+ # Map embeddings back to the image space
+ x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8)
+ x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8)
+
+ # Apply conv blocks
+ for block in self.conv_blocks:
+ x = block(x) # Upsample progressively
+
+ # The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2])
+ return x
+
+
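
A usage sketch for LatentDecoder (illustrative shapes, assuming the class is importable from lzero.model.common): map a 256-dimensional latent state back to a (3, 64, 64) observation.

import torch
from lzero.model.common import LatentDecoder

decoder = LatentDecoder(embedding_dim=256, output_shape=(3, 64, 64), num_channels=64)
latent = torch.randn(4, 256)
recon = decoder(latent)
print(recon.shape)                     # torch.Size([4, 3, 64, 64])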
+class LatentEncoderForMemoryEnv(nn.Module):
+
+ def __init__(
+ self,
+ image_shape=(3, 5, 5),
+ embedding_size=100,
+ channels=[16, 32, 64],
+ kernel_sizes=[3, 3, 3],
+ strides=[1, 1, 1],
+ activation: nn.Module = nn.GELU(approximate='tanh'),
+ normalize_pixel=False,
+ group_size: int = 8,
+ **kwargs,
+ ):
+ """
+ Overview:
+ Encoder network used in UniZero in MemoryEnv. Encode the 2D image obs into latent state.
+ Arguments:
+ - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - embedding_size (:obj:`int`): The dimension of the latent state.
+ - channels (:obj:`List[int]`): The channel of output hidden state.
+ - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers.
+ - strides (:obj:`List[int]`): The stride of convolution layers.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). \
+ Use the inplace operation to speed up.
+ - normalize_pixel (:obj:`bool`): Whether to normalize the pixel values to [0, 1], defaults to False.
+ - group_size (:obj:`int`): The dimension for simplicial normalization
+ """
+ super(LatentEncoderForMemoryEnv, self).__init__()
+ self.shape = image_shape
+ self.channels = [image_shape[0]] + list(channels)
+
+ layers = []
+ for i in range(len(self.channels) - 1):
+ layers.append(
+ nn.Conv2d(
+ self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i],
+ padding=kernel_sizes[i] // 2 # keep the same size of feature map
+ )
+ )
+ layers.append(nn.BatchNorm2d(self.channels[i + 1]))
+ layers.append(activation)
+
+ layers.append(nn.AdaptiveAvgPool2d(1))
+
+ self.cnn = nn.Sequential(*layers)
+ self.linear = nn.Sequential(
+ nn.Linear(self.channels[-1], embedding_size, bias=False),
+ )
+ init.kaiming_normal_(self.linear[0].weight, mode='fan_out', nonlinearity='relu')
+
+ self.normalize_pixel = normalize_pixel
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, image):
+ if self.normalize_pixel:
+ image = image / 255.0
+ x = self.cnn(image.float()) # (B, C, 1, 1)
+ x = torch.flatten(x, start_dim=1) # (B, C)
+ x = self.linear(x) # (B, embedding_size)
+ x = self.sim_norm(x)
+ return x
+
+
+class LatentDecoderForMemoryEnv(nn.Module):
+
+ def __init__(
+ self,
+ image_shape=(3, 5, 5),
+ embedding_size=256,
+ channels=[64, 32, 16],
+ kernel_sizes=[3, 3, 3],
+ strides=[1, 1, 1],
+ activation: nn.Module = nn.LeakyReLU(negative_slope=0.01),
+ **kwargs,
+ ):
+ """
+ Overview:
+ Decoder network used in UniZero in MemoryEnv. Decode the latent state into 2D image obs.
+ Arguments:
+ - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - embedding_size (:obj:`int`): The dimension of the latent state.
+ - channels (:obj:`List[int]`): The channel of output hidden state.
+ - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers.
+ - strides (:obj:`List[int]`): The stride of convolution layers.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.LeakyReLU(). \
+ Use the inplace operation to speed up.
+ """
+ super(LatentDecoderForMemoryEnv, self).__init__()
+ self.shape = image_shape
+ self.channels = list(channels) + [image_shape[0]]
+
+ self.linear = nn.Linear(embedding_size, channels[0] * image_shape[1] * image_shape[2])
+
+ layers = []
+ for i in range(len(self.channels) - 1):
+ layers.append(
+ nn.ConvTranspose2d(
+ self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i],
+ padding=kernel_sizes[i] // 2, output_padding=strides[i] - 1
+ )
+ )
+ if i < len(self.channels) - 2:
+ layers.append(nn.BatchNorm2d(self.channels[i + 1]))
+ layers.append(activation)
+ else:
+ layers.append(nn.Sigmoid())
+
+ self.deconv = nn.Sequential(*layers)
+
+ def forward(self, embedding):
+ x = self.linear(embedding)
+ x = x.view(-1, self.channels[0], self.shape[1], self.shape[2])
+ x = self.deconv(x) # (B, C, H, W)
+ return x
+
+
+class VectorDecoderForMemoryEnv(nn.Module):
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ output_shape: SequenceType,
+ hidden_channels: int = 64,
+ layer_num: int = 2,
+ activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), # TODO
+ norm_type: Optional[str] = 'BN',
+ ) -> torch.Tensor:
+ """
+ Overview:
+ Decoder network used in UniZero in MemoryEnv. Decode the latent state into vector obs.
+ Arguments:
+ - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10.
+ - num_res_blocks (:obj:`int`): The number of residual blocks.
+ - hidden_channels (:obj:`int`): The channel of output hidden state.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
+ defaults to True. This option is often used in video games like Atari. In board games like go, \
+ we don't need this module.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(). \
+ Use the inplace operation to speed up.
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super().__init__()
+ self.fc_representation = MLP(
+ in_channels=embedding_dim,
+ hidden_channels=hidden_channels,
+ out_channels=output_shape,
+ layer_num=layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ # don't use activation and norm in the last layer of representation network is important for convergence.
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=True,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation.
+ - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
+ """
+ x = self.fc_representation(x)
+ return x
+
+
+class PredictionNetwork(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType,
+ action_space_size: int,
+ num_res_blocks: int,
+ num_channels: int,
+ value_head_channels: int,
+ policy_head_channels: int,
+ fc_value_layers: int,
+ fc_policy_layers: int,
+ output_support_size: int,
+ flatten_output_size_for_value_head: int,
+ flatten_output_size_for_policy_head: int,
+ downsample: bool = False,
+ last_linear_layer_init_zero: bool = True,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ ) -> None:
+ """
+ Overview:
+ The definition of policy and value prediction network, which is used to predict value and policy by the
+ given latent state.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image.
+ - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space.
+ - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model.
+ - num_channels (:obj:`int`): The channels of hidden states.
+ - value_head_channels (:obj:`int`): The channels of value head.
+ - policy_head_channels (:obj:`int`): The channels of policy head.
+ - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+ - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
+ - output_support_size (:obj:`int`): The size of categorical value output.
+ - self_supervised_learning_loss (:obj:`bool`): Whether to use self_supervised_learning related networks \
+ - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the value head.
+ - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the policy head.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``.
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
+ dynamics/prediction mlp, default sets it to True.
+ - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \
+ operation to speedup, e.g. ReLU(inplace=True).
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super(PredictionNetwork, self).__init__()
+ assert norm_type in ['BN', 'LN'], "norm_type must in ['BN', 'LN']"
+
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+
+ self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
+ self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
+
+ if observation_shape[1] == 96:
+ latent_shape = (observation_shape[1] / 16, observation_shape[2] / 16)
+ elif observation_shape[1] == 64:
+ latent_shape = (observation_shape[1] / 8, observation_shape[2] / 8)
+
+ if norm_type == 'BN':
+ self.norm_value = nn.BatchNorm2d(value_head_channels)
+ self.norm_policy = nn.BatchNorm2d(policy_head_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm_value = nn.LayerNorm(
+ [value_head_channels, *latent_shape],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, *latent_shape], eps=1e-5)
+ else:
+ self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+
+ self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
+ self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
+
+ self.activation = activation
+
+ self.fc_value = MLP(
+ in_channels=self.flatten_output_size_for_value_head,
+ hidden_channels=fc_value_layers[0],
+ out_channels=output_support_size,
+ layer_num=len(fc_value_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+ self.fc_policy = MLP(
+ in_channels=self.flatten_output_size_for_policy_head,
+ hidden_channels=fc_policy_layers[0],
+ out_channels=action_space_size,
+ layer_num=len(fc_policy_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+
+ def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Forward computation of the prediction network.
+ Arguments:
+            - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, num_channels, latent_H, latent_W).
+ Returns:
+ - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
+ - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
+ """
+ for res_block in self.resblocks:
+ latent_state = res_block(latent_state)
+
+ value = self.conv1x1_value(latent_state)
+ value = self.norm_value(value)
+ value = self.activation(value)
+
+ policy = self.conv1x1_policy(latent_state)
+ policy = self.norm_policy(policy)
+ policy = self.activation(policy)
+
+ value = value.reshape(-1, self.flatten_output_size_for_value_head)
+ policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)
+
+ value = self.fc_value(value)
+ policy = self.fc_policy(policy)
+ return policy, value
+
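+# A minimal usage sketch (illustrative only; all sizes below are assumptions, not values taken from a config).
+# With downsampling, 64x64 observations yield an 8x8 latent plane (96x96 yields 6x6), so
+# flatten_output_size_for_value_head = value_head_channels * 8 * 8 in this example.
+#   prediction_net = PredictionNetwork(
+#       observation_shape=(3, 64, 64), action_space_size=6, num_res_blocks=1, num_channels=64,
+#       value_head_channels=16, policy_head_channels=16, fc_value_layers=[32], fc_policy_layers=[32],
+#       output_support_size=601, flatten_output_size_for_value_head=16 * 8 * 8,
+#       flatten_output_size_for_policy_head=16 * 8 * 8, downsample=True,
+#   )
+#   policy_logits, value = prediction_net(torch.randn(4, 64, 8, 8))   # -> (4, 6), (4, 601)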
+
+class PredictionNetworkMLP(nn.Module):
+
+ def __init__(
+ self,
+ action_space_size,
+ num_channels,
+ common_layer_num: int = 2,
+ fc_value_layers: SequenceType = [32],
+ fc_policy_layers: SequenceType = [32],
+ output_support_size: int = 601,
+ last_linear_layer_init_zero: bool = True,
+ activation: Optional[nn.Module] = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ ):
+ """
+ Overview:
+ The definition of policy and value prediction network with Multi-Layer Perceptron (MLP),
+ which is used to predict value and policy by the given latent state.
+ Arguments:
+ - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \
+ space, it is the number of discrete actions.
+ - num_channels (:obj:`int`): The channels of latent states.
+ - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+ - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
+ - output_support_size (:obj:`int`): The size of categorical value output.
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
+ dynamics/prediction mlp, default sets it to True.
+ - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \
+ operation to speedup, e.g. ReLU(inplace=True).
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super().__init__()
+ self.num_channels = num_channels
+
+ # ******* common backbone ******
+ self.fc_prediction_common = MLP(
+ in_channels=self.num_channels,
+ hidden_channels=self.num_channels,
+ out_channels=self.num_channels,
+ layer_num=common_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ output_activation=True,
+ output_norm=True,
+ # last_linear_layer_init_zero=False is important for convergence
+ last_linear_layer_init_zero=False,
+ )
+
+ # ******* value and policy head ******
+ self.fc_value_head = MLP(
+ in_channels=self.num_channels,
+ hidden_channels=fc_value_layers[0],
+ out_channels=output_support_size,
+ layer_num=len(fc_value_layers) + 1,
+ activation=activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+ self.fc_policy_head = MLP(
+ in_channels=self.num_channels,
+ hidden_channels=fc_policy_layers[0],
+ out_channels=action_space_size,
+ layer_num=len(fc_policy_layers) + 1,
+ activation=activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+
+ def forward(self, latent_state: torch.Tensor):
+ """
+ Overview:
+ Forward computation of the prediction network.
+ Arguments:
+ - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim).
+ Returns:
+ - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
+ - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
+ """
+ x_prediction_common = self.fc_prediction_common(latent_state)
+
+ value = self.fc_value_head(x_prediction_common)
+ policy = self.fc_policy_head(x_prediction_common)
+ return policy, value
+
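+# A minimal usage sketch (illustrative only; the sizes are assumptions):
+#   prediction_mlp = PredictionNetworkMLP(action_space_size=6, num_channels=128)
+#   policy_logits, value = prediction_mlp(torch.randn(4, 128))
+#   # policy_logits: (4, 6); value: (4, 601) over the categorical value support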
+
+class PredictionHiddenNetwork(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType,
+ action_space_size: int,
+ num_res_blocks: int,
+ num_channels: int,
+ value_head_channels: int,
+ policy_head_channels: int,
+ fc_value_layers: int,
+ fc_policy_layers: int,
+ output_support_size: int,
+ flatten_output_size_for_value_head: int,
+ flatten_output_size_for_policy_head: int,
+ downsample: bool = False,
+ last_linear_layer_init_zero: bool = True,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ gru_hidden_size: int = 512,
+ ) -> None:
+ """
+ Overview:
+ The definition of policy and value prediction network, which is used to predict value and policy by the
+ given latent state.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image.
+ - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space.
+ - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model.
+ - num_channels (:obj:`int`): The channels of hidden states.
+ - value_head_channels (:obj:`int`): The channels of value head.
+ - policy_head_channels (:obj:`int`): The channels of policy head.
+ - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+ - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
+ - output_support_size (:obj:`int`): The size of categorical value output.
+ - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the value head.
+ - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the policy head.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``.
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
+ dynamics/prediction mlp, default sets it to True.
+ - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \
+ operation to speedup, e.g. ReLU(inplace=True).
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super(PredictionHiddenNetwork, self).__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
+
+ self.observation_shape = observation_shape
+ self.gru_hidden_size = gru_hidden_size
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+
+ self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
+ self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
+
+ if norm_type == 'BN':
+ self.norm_value = nn.BatchNorm2d(value_head_channels)
+ self.norm_policy = nn.BatchNorm2d(policy_head_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm_value = nn.LayerNorm(
+ [value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, math.ceil(observation_shape[-2] / 16),
+ math.ceil(observation_shape[-1] / 16)], eps=1e-5)
+ else:
+ self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+
+ self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
+ self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
+
+ self.activation = activation
+
+ self.fc_value = MLP(
+ in_channels=self.flatten_output_size_for_value_head + self.gru_hidden_size,
+ hidden_channels=fc_value_layers[0],
+ out_channels=output_support_size,
+ layer_num=len(fc_value_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+ self.fc_policy = MLP(
+ in_channels=self.flatten_output_size_for_policy_head + self.gru_hidden_size,
+ hidden_channels=fc_policy_layers[0],
+ out_channels=action_space_size,
+ layer_num=len(fc_policy_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+
+ def forward(self, latent_state: torch.Tensor, world_model_latent_history: torch.Tensor) -> Tuple[
+ torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Forward computation of the prediction network.
+ Arguments:
+            - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, num_channels, latent_H, latent_W).
+            - world_model_latent_history (:obj:`torch.Tensor`): the world model's GRU hidden state with shape \
+                (num_layers * num_directions, B, gru_hidden_size).
+ Returns:
+ - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
+ - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
+ """
+ for res_block in self.resblocks:
+ latent_state = res_block(latent_state)
+
+ value = self.conv1x1_value(latent_state)
+ value = self.norm_value(value)
+ value = self.activation(value)
+
+ policy = self.conv1x1_policy(latent_state)
+ policy = self.norm_policy(policy)
+ policy = self.activation(policy)
+
+ latent_state_value = value.reshape(-1, self.flatten_output_size_for_value_head)
+ latent_state_policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)
+
+ # TODO: world_model_latent_history.squeeze(0) shape: (num_layers * num_directions, batch_size, hidden_size) -> ( batch_size, hidden_size)
+ latent_history_value = torch.cat([latent_state_value, world_model_latent_history.squeeze(0)], dim=1)
+ latent_history_policy = torch.cat([latent_state_policy, world_model_latent_history.squeeze(0)], dim=1)
+
+ value = self.fc_value(latent_history_value)
+ policy = self.fc_policy(latent_history_policy)
+ return policy, value
+
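+# A minimal usage sketch (illustrative only; all sizes are assumptions):
+#   net = PredictionHiddenNetwork(
+#       observation_shape=(3, 64, 64), action_space_size=6, num_res_blocks=1, num_channels=64,
+#       value_head_channels=16, policy_head_channels=16, fc_value_layers=[32], fc_policy_layers=[32],
+#       output_support_size=601, flatten_output_size_for_value_head=16 * 8 * 8,
+#       flatten_output_size_for_policy_head=16 * 8 * 8, downsample=True, gru_hidden_size=512,
+#   )
+#   latent = torch.randn(4, 64, 8, 8)                 # (B, C, H, W) latent feature map
+#   gru_hidden = torch.randn(1, 4, 512)               # (num_layers, B, hidden) from the world model GRU
+#   policy_logits, value = net(latent, gru_hidden)    # the GRU hidden state is concatenated to both head inputs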
+
diff --git a/lzero/model/common_serve_input_id.py b/lzero/model/common_serve_input_id.py
new file mode 100644
index 000000000..c77d9cb3b
--- /dev/null
+++ b/lzero/model/common_serve_input_id.py
@@ -0,0 +1,1247 @@
+"""
+Overview:
+ In this Python file, we provide a collection of reusable model templates designed to streamline the development
+ process for various custom algorithms. By utilizing these pre-built model templates, users can quickly adapt and
+ customize their custom algorithms, ensuring efficient and effective development.
+    Additionally, users can refer to the unit tests of these model templates to learn how to use them.
+"""
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from ding.torch_utils import MLP, ResBlock
+from ding.utils import SequenceType
+from ditk import logging
+from openai import OpenAI
+from transformers import AutoTokenizer
+
+# use dataclass to make the output of network more convenient to use
+@dataclass
+class MZRNNNetworkOutput:
+ # output format of the MuZeroRNN model
+ value: torch.Tensor
+ value_prefix: torch.Tensor
+ policy_logits: torch.Tensor
+ latent_state: torch.Tensor
+ predict_next_latent_state: torch.Tensor
+ reward_hidden_state: Tuple[torch.Tensor]
+
+
+@dataclass
+class EZNetworkOutput:
+ # output format of the EfficientZero model
+ value: torch.Tensor
+ value_prefix: torch.Tensor
+ policy_logits: torch.Tensor
+ latent_state: torch.Tensor
+ reward_hidden_state: Tuple[torch.Tensor]
+
+
+@dataclass
+class MZNetworkOutput:
+ # output format of the MuZero model
+ value: torch.Tensor
+ reward: torch.Tensor
+ policy_logits: torch.Tensor
+ latent_state: torch.Tensor
+
+
+class SimNorm(nn.Module):
+
+ def __init__(self, simnorm_dim: int) -> None:
+ """
+ Overview:
+ Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616.
+ Arguments:
+ - simnorm_dim (:obj:`int`): The dimension for simplicial normalization.
+ """
+ super().__init__()
+ self.dim = simnorm_dim
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Forward pass of the SimNorm layer.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor to normalize.
+ Returns:
+ - x (:obj:`torch.Tensor`): The normalized tensor.
+ """
+ shp = x.shape
+ # Ensure that there is at least one simplex to normalize across.
+ if shp[1] != 0:
+ x = x.view(*shp[:-1], -1, self.dim)
+ x = F.softmax(x, dim=-1)
+ return x.view(*shp)
+ else:
+ return x
+
+ def __repr__(self) -> str:
+ """
+ Overview:
+ String representation of the SimNorm layer.
+ Returns:
+ - output (:obj:`str`): The string representation.
+ """
+ return f"SimNorm(dim={self.dim})"
+
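+# A minimal usage sketch (illustrative only): with simnorm_dim=8, a (B, 64) tensor is reshaped to
+# (B, 8, 8) groups, softmax is applied within each group of 8, and the result is flattened back to
+# (B, 64), so every consecutive group of 8 entries lies on a probability simplex.
+#   sim_norm = SimNorm(simnorm_dim=8)
+#   z = sim_norm(torch.randn(4, 64))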
+
+def AvgL1Norm(x, eps=1e-8):
+ """
+ Overview:
+ Normalize the input tensor by the L1 norm.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor to normalize.
+ - eps (:obj:`float`): The epsilon value to prevent division by zero.
+ Returns:
+ - :obj:`torch.Tensor`: The normalized tensor.
+ """
+ return x / x.abs().mean(-1, keepdim=True).clamp(min=eps)
+
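+# A minimal usage sketch (illustrative only): each row is rescaled so that the mean absolute value of its
+# entries is 1 (clamped by ``eps`` to avoid division by zero).
+#   y = AvgL1Norm(torch.randn(4, 64))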
+
+class FeatureAndGradientHook:
+
+ def __init__(self):
+ """
+ Overview:
+ Class to capture features and gradients at SimNorm.
+ """
+ self.features_before = []
+ self.features_after = []
+ self.grads_before = []
+ self.grads_after = []
+
+ def setup_hooks(self, model):
+ # Hooks to capture features and gradients at SimNorm
+ self.forward_handler = model.sim_norm.register_forward_hook(self.forward_hook)
+ self.backward_handler = model.sim_norm.register_full_backward_hook(self.backward_hook)
+
+ def forward_hook(self, module, input, output):
+ with torch.no_grad():
+ self.features_before.append(input[0])
+ self.features_after.append(output)
+
+ def backward_hook(self, module, grad_input, grad_output):
+ with torch.no_grad():
+ self.grads_before.append(grad_input[0] if grad_input[0] is not None else None)
+ self.grads_after.append(grad_output[0] if grad_output[0] is not None else None)
+
+ def analyze(self):
+ # Calculate L2 norms of features
+ l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_before]))
+ l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_after]))
+
+ # Calculate norms of gradients
+ grad_norm_before = torch.mean(
+ torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before if g is not None]))
+ grad_norm_after = torch.mean(
+ torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after if g is not None]))
+
+ # Clear stored data and delete tensors to free memory
+ self.clear_data()
+
+ # Optionally clear CUDA cache
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after
+
+ def clear_data(self):
+ del self.features_before[:]
+ del self.features_after[:]
+ del self.grads_before[:]
+ del self.grads_after[:]
+
+ def remove_hooks(self):
+ self.forward_handler.remove()
+ self.backward_handler.remove()
+
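+# A minimal usage sketch (illustrative only; assumes ``model`` exposes a ``sim_norm`` submodule that is
+# exercised by its forward pass):
+#   hook = FeatureAndGradientHook()
+#   hook.setup_hooks(model)
+#   ...                                               # run forward and backward passes as usual
+#   l2_before, l2_after, grad_before, grad_after = hook.analyze()
+#   hook.remove_hooks()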
+
+class DownSample(nn.Module):
+
+ def __init__(self, observation_shape: SequenceType, out_channels: int,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ num_resblocks: int = 1,
+ ) -> None:
+ """
+ Overview:
+ Define downSample convolution network. Encode the observation into hidden state.
+ This network is often used in video games like Atari. In board games like go and chess,
+ we don't need this module.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[12, 96, 96]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - out_channels (:obj:`int`): The output channels of output hidden state.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \
+ Use the inplace operation to speed up.
+ - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'.
+ - num_resblocks (:obj:`int`): The number of residual blocks. Defaults to 1.
+ """
+ super().__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
+
+ self.observation_shape = observation_shape
+ self.conv1 = nn.Conv2d(
+ observation_shape[0],
+ out_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False, # disable bias for better convergence
+ )
+ if norm_type == 'BN':
+ self.norm1 = nn.BatchNorm2d(out_channels // 2)
+ elif norm_type == 'LN':
+ self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2],
+ eps=1e-5)
+
+ self.resblocks1 = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=out_channels // 2,
+ activation=activation,
+ norm_type=norm_type,
+ res_type='basic',
+ bias=False
+ ) for _ in range(num_resblocks)
+ ]
+ )
+ self.downsample_block = ResBlock(
+ in_channels=out_channels // 2,
+ out_channels=out_channels,
+ activation=activation,
+ norm_type=norm_type,
+ res_type='downsample',
+ bias=False
+ )
+ self.resblocks2 = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_resblocks)
+ ]
+ )
+ self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+ self.resblocks3 = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(1)
+ ]
+ )
+ self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+ self.activation = activation
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \
+ H is height.
+ - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \
+ output width, H_ is output height.
+ """
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.activation(x)
+
+ for block in self.resblocks1:
+ x = block(x)
+ x = self.downsample_block(x)
+ for block in self.resblocks2:
+ x = block(x)
+ x = self.pooling1(x)
+ for block in self.resblocks3:
+ x = block(x)
+
+ # 64, 84, 96 are the most common observation shapes in Atari games.
+ if self.observation_shape[1] == 64:
+ output = x
+ elif self.observation_shape[1] == 84:
+ x = self.pooling2(x)
+ output = x
+ elif self.observation_shape[1] == 96:
+ x = self.pooling2(x)
+ output = x
+ else:
+            raise NotImplementedError(f"DownSample for observation shape {self.observation_shape} is not implemented now. "
+                                      f"You should transform the observation shape to 64, 84 or 96 in the env.")
+
+ return output
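+
+# Downsampling factor sketch (illustrative only): a 96x96 input is reduced by a factor of 16
+# (96 -> 48 -> 24 -> 12 -> 6); an 84x84 input follows the same path and also ends at 6x6; a 64x64 input
+# skips the last pooling and is reduced by a factor of 8 (64 -> 32 -> 16 -> 8).
+#   feats = DownSample((3, 64, 64), out_channels=64)(torch.randn(2, 3, 64, 64))   # -> (2, 64, 8, 8)
+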
+"""
+This network uses a vLLM server hosting BAAI/bge-base-en-v1.5; the model input is token ids, which must
+first be decoded back into strings before requesting embeddings.
+"""
+class HFLanguageRepresentationNetwork(nn.Module):
+ def __init__(
+ self,
+ url: str = 'BAAI/bge-base-en-v1.5',
+ embedding_size: int = 768,
+ group_size: int = 8,
+ api_base: str = "http://10.119.30.189:8081/v1",
+ api_key: str = "EMPTY"
+ ):
+ """
+        Initialize the language representation network, which obtains embeddings via a vLLM API service.
+
+        Arguments:
+        - url (str): The model name served by vLLM, defaults to 'BAAI/bge-base-en-v1.5'.
+        - embedding_size (int): The dimension of the output embedding, defaults to 768.
+        - group_size (int): The group size of SimNorm, defaults to 8.
+        - api_base (str): The base URL of the vLLM API server, defaults to "http://10.119.30.189:8081/v1".
+        - api_key (str): The API key, defaults to "EMPTY".
+ """
+ super().__init__()
+ self.url = url
+ self.embedding_size = embedding_size
+ self.api_base = api_base
+ self.api_key = api_key
+
+        # Initialize the OpenAI client to connect to the vLLM API server.
+ self.client = OpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ )
+
+        # Retrieve the model ID from the server.
+ models = self.client.models.list()
+ self.model_id = models.data[0].id if models.data else url
+
+        # Initialize the linear projection layer (if needed).
+ if self.embedding_size != 768:
+ self.embed_head = nn.Linear(768, self.embedding_size)
+ else:
+ self.embed_head = None
+
+        # Initialize SimNorm.
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+        # Initialize the tokenizer, which is used to decode token indices back into strings.
+ self.tokenizer = AutoTokenizer.from_pretrained(url)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+        Forward pass that obtains the language representation of the input sequence.
+
+        Arguments:
+        - x (torch.Tensor): The input tensor of shape [batch_size, seq_len], with dtype torch.long.
+
+        Returns:
+        - torch.Tensor: The processed language embedding of shape [batch_size, embedding_size].
+ """
+ with torch.no_grad():
+            x = x.long()  # Ensure the input tensor contains long-typed token indices
+
+            # # Sanity checks on the token index range (disabled by default):
+ # min_idx = x.min().item()
+ # max_idx = x.max().item()
+ # # print(f"min_idx: {min_idx}, max_idx: {max_idx}")
+ # assert min_idx >= 0, "Negative token indices found."
+ # assert max_idx < self.tokenizer.vocab_size, f"Token index {max_idx} exceeds vocab size {self.tokenizer.vocab_size}."
+
+            # Decode the token indices back into strings.
+            # It is assumed that each sample's [CLS] token is at position 0;
+            # adjust this according to the actual tokenization if needed.
+ batch_size = x.size(0)
+ sentences: List[str] = []
+ for i in range(batch_size):
+                # Decode the token ids into a string.
+ tokens = x[i].tolist()
+ sentence = self.tokenizer.decode(tokens, skip_special_tokens=True)
+ sentences.append(sentence)
+
+            # Call the vLLM embedding API.
+ response = self.client.embeddings.create(
+ input=sentences,
+ model=self.model_id,
+ )
+
+            # Extract the embeddings and convert them into a tensor.
+ embeddings = [data.embedding for data in response.data] # List[List[float]]
+ embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32, device=x.device) # [batch_size, 768]
+
+            # Project to a lower or higher dimension if needed.
+ if self.embed_head is not None:
+ embeddings_tensor = self.embed_head(embeddings_tensor) # [batch_size, embedding_size]
+
+            # Apply SimNorm.
+ embeddings_tensor = self.sim_norm(embeddings_tensor) # [batch_size, embedding_size]
+
+ return embeddings_tensor
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+        # Remove objects that cannot be pickled.
+ del state['client']
+ del state['tokenizer']
+ return state
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+        # Re-create the objects that could not be pickled.
+ self.client = OpenAI(
+ api_key=self.api_key,
+ base_url=self.api_base,
+ )
+ self.tokenizer = AutoTokenizer.from_pretrained(self.url)
+
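+# A minimal usage sketch (illustrative only; the endpoint below is an assumption and must point to a
+# running vLLM embedding server for BAAI/bge-base-en-v1.5, since the constructor queries the server):
+#   net = HFLanguageRepresentationNetwork(api_base="http://localhost:8000/v1")
+#   token_ids = torch.randint(0, 30000, (2, 16))      # (batch_size, seq_len) token ids
+#   emb = net(token_ids)                               # -> (2, 768), SimNorm-normalized embeddings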
+
+class RepresentationNetworkUniZero(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType = (3, 64, 64),
+ num_res_blocks: int = 1,
+ num_channels: int = 64,
+ downsample: bool = True,
+ activation: nn.Module = nn.GELU(approximate='tanh'),
+ norm_type: str = 'BN',
+ embedding_dim: int = 256,
+ group_size: int = 8,
+ ) -> None:
+ """
+ Overview:
+ Representation network used in UniZero. Encode the 2D image obs into latent state.
+ Currently, the network only supports obs images with both a width and height of 64.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel.
+ - num_res_blocks (:obj:`int`): The number of residual blocks.
+ - num_channels (:obj:`int`): The channel of output hidden state.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
+ defaults to True. This option is often used in video games like Atari. In board games like go, \
+ we don't need this module.
+            - activation (:obj:`nn.Module`): The activation function used in network, defaults to \
+                nn.GELU(approximate='tanh').
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ - embedding_dim (:obj:`int`): The dimension of the latent state.
+ - group_size (:obj:`int`): The dimension for simplicial normalization.
+ """
+ super().__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
+ logging.info(f"Using norm type: {norm_type}")
+ logging.info(f"Using activation type: {activation}")
+
+ self.observation_shape = observation_shape
+ self.downsample = downsample
+ if self.downsample:
+ self.downsample_net = DownSample(
+ observation_shape,
+ num_channels,
+ activation=activation,
+ norm_type=norm_type,
+ )
+ else:
+ self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
+
+ if norm_type == 'BN':
+ self.norm = nn.BatchNorm2d(num_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm = nn.LayerNorm(
+ [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
+ eps=1e-5)
+ else:
+ self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)
+
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+ self.activation = activation
+ self.embedding_dim = embedding_dim
+
+ if self.observation_shape[1] == 64:
+ self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False)
+
+ elif self.observation_shape[1] in [84, 96]:
+ self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)
+
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \
+ H is height.
+ - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \
+ output width, H_ is output height.
+ """
+ if self.downsample:
+ x = self.downsample_net(x)
+ else:
+ x = self.conv(x)
+ x = self.norm(x)
+ x = self.activation(x)
+ for block in self.resblocks:
+ x = block(x)
+
+ # Important: Transform the output feature plane to the latent state.
+ # For example, for an Atari feature plane of shape (64, 8, 8),
+ # flattening results in a size of 4096, which is then transformed to 768.
+ x = self.last_linear(x.view(x.size(0), -1))
+
+ x = x.view(-1, self.embedding_dim)
+
+ # NOTE: very important for training stability.
+ x = self.sim_norm(x)
+
+ return x
+
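+# A minimal usage sketch (illustrative only): for 64x64 observations the downsampled feature plane is
+# (num_channels, 8, 8), which is flattened (64 * 8 * 8 = 4096) and projected to ``embedding_dim``.
+#   repr_net = RepresentationNetworkUniZero(observation_shape=(3, 64, 64), embedding_dim=256)
+#   latent = repr_net(torch.randn(2, 3, 64, 64))       # -> (2, 256), SimNorm applied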
+
+class RepresentationNetwork(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType = (4, 96, 96),
+ num_res_blocks: int = 1,
+ num_channels: int = 64,
+ downsample: bool = True,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: str = 'BN',
+ embedding_dim: int = 256,
+ group_size: int = 8,
+ use_sim_norm: bool = False,
+ ) -> None:
+ """
+ Overview:
+ Representation network used in MuZero and derived algorithms. Encode the 2D image obs into latent state.
+ Currently, the network only supports obs images with both a width and height of 96.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[4, 96, 96]
+ for video games like atari, 1 gray channel times stack 4 frames.
+ - num_res_blocks (:obj:`int`): The number of residual blocks.
+ - num_channels (:obj:`int`): The channel of output hidden state.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
+ defaults to True. This option is often used in video games like Atari. In board games like go, \
+ we don't need this module.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \
+ Use the inplace operation to speed up.
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ - embedding_dim (:obj:`int`): The dimension of the output hidden state.
+ - group_size (:obj:`int`): The size of group in the SimNorm layer.
+ - use_sim_norm (:obj:`bool`): Whether to use SimNorm layer, defaults to False.
+ """
+ super().__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
+
+ self.downsample = downsample
+ if self.downsample:
+ self.downsample_net = DownSample(
+ observation_shape,
+ num_channels,
+ activation=activation,
+ norm_type=norm_type,
+ )
+ else:
+ self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
+
+ if norm_type == 'BN':
+ self.norm = nn.BatchNorm2d(num_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm = nn.LayerNorm(
+ [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
+ eps=1e-5)
+ else:
+ self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)
+
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+ self.activation = activation
+
+ self.use_sim_norm = use_sim_norm
+
+ if self.use_sim_norm:
+ self.embedding_dim = embedding_dim
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \
+ H is height.
+ - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \
+ output width, H_ is output height.
+ """
+ if self.downsample:
+ x = self.downsample_net(x)
+ else:
+ x = self.conv(x)
+ x = self.norm(x)
+ x = self.activation(x)
+
+ for block in self.resblocks:
+ x = block(x)
+
+ if self.use_sim_norm:
+            # NOTE: SimNorm here is very important for training stability.
+ x = self.sim_norm(x)
+
+ return x
+
+
+class RepresentationNetworkMLP(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: int,
+ hidden_channels: int = 64,
+ layer_num: int = 2,
+ activation: nn.Module = nn.GELU(approximate='tanh'),
+ norm_type: Optional[str] = 'BN',
+ group_size: int = 8,
+    ) -> None:
+ """
+ Overview:
+ Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \
+ with Multi-Layer Perceptron (MLP).
+ Arguments:
+ - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10.
+            - hidden_channels (:obj:`int`): The channel of output hidden state.
+            - layer_num (:obj:`int`): The number of layers in the MLP.
+            - group_size (:obj:`int`): The group size used by the SimNorm layer.
+            - activation (:obj:`nn.Module`): The activation function used in network, defaults to \
+                nn.GELU(approximate='tanh').
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super().__init__()
+ self.fc_representation = MLP(
+ in_channels=observation_shape,
+ hidden_channels=hidden_channels,
+ out_channels=hidden_channels,
+ layer_num=layer_num,
+ activation=activation,
+ norm_type=norm_type,
+            # Omitting activation and norm in the last layer of the representation network is important for convergence.
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=True,
+ )
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation.
+ - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
+ """
+ x = self.fc_representation(x)
+ # TODO
+ x = self.sim_norm(x)
+ return x
+
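+# A minimal usage sketch (illustrative only):
+#   repr_mlp = RepresentationNetworkMLP(observation_shape=10, hidden_channels=64)
+#   latent = repr_mlp(torch.randn(4, 10))              # -> (4, 64), SimNorm applied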
+
+class LatentDecoder(nn.Module):
+
+ def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh')):
+ """
+ Overview:
+ Decoder network used in UniZero. Decode the latent state into 2D image obs.
+ Arguments:
+ - embedding_dim (:obj:`int`): The dimension of the latent state.
+ - output_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - num_channels (:obj:`int`): The channel of output hidden state.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh').
+ """
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.output_shape = output_shape # (C, H, W)
+ self.num_channels = num_channels
+ self.activation = activation
+
+ # Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256
+ # We will reverse the process of the representation network
+ self.initial_size = (
+ num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder
+ self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size))
+
+ # Upsampling blocks
+ self.conv_blocks = nn.ModuleList([
+ # Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4)
+ nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1),
+ self.activation,
+ nn.BatchNorm2d(num_channels // 2),
+ # Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2)
+ nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1,
+ output_padding=1),
+ self.activation,
+ nn.BatchNorm2d(num_channels // 4),
+ # Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W)
+ nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1,
+ output_padding=1),
+ ])
+ # TODO: last layer use sigmoid?
+
+ def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
+ # Map embeddings back to the image space
+ x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8)
+ x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8)
+
+ # Apply conv blocks
+ for block in self.conv_blocks:
+ x = block(x) # Upsample progressively
+
+ # The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2])
+ return x
+
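+# A minimal usage sketch (illustrative only): the decoder mirrors the encoder and upsamples by 8x,
+# so a 256-dim embedding is mapped back to a (C, H, W) observation.
+#   decoder = LatentDecoder(embedding_dim=256, output_shape=(3, 64, 64))
+#   recon = decoder(torch.randn(2, 256))                # -> (2, 3, 64, 64)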
+
+class LatentEncoderForMemoryEnv(nn.Module):
+
+ def __init__(
+ self,
+ image_shape=(3, 5, 5),
+ embedding_size=100,
+ channels=[16, 32, 64],
+ kernel_sizes=[3, 3, 3],
+ strides=[1, 1, 1],
+ activation: nn.Module = nn.GELU(approximate='tanh'),
+ normalize_pixel=False,
+ group_size: int = 8,
+ **kwargs,
+ ):
+ """
+ Overview:
+ Encoder network used in UniZero in MemoryEnv. Encode the 2D image obs into latent state.
+ Arguments:
+ - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - embedding_size (:obj:`int`): The dimension of the latent state.
+ - channels (:obj:`List[int]`): The channel of output hidden state.
+ - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers.
+ - strides (:obj:`List[int]`): The stride of convolution layers.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). \
+ Use the inplace operation to speed up.
+ - normalize_pixel (:obj:`bool`): Whether to normalize the pixel values to [0, 1], defaults to False.
+ - group_size (:obj:`int`): The dimension for simplicial normalization
+ """
+ super(LatentEncoderForMemoryEnv, self).__init__()
+ self.shape = image_shape
+ self.channels = [image_shape[0]] + list(channels)
+
+ layers = []
+ for i in range(len(self.channels) - 1):
+ layers.append(
+ nn.Conv2d(
+ self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i],
+ padding=kernel_sizes[i] // 2 # keep the same size of feature map
+ )
+ )
+ layers.append(nn.BatchNorm2d(self.channels[i + 1]))
+ layers.append(activation)
+
+ layers.append(nn.AdaptiveAvgPool2d(1))
+
+ self.cnn = nn.Sequential(*layers)
+ self.linear = nn.Sequential(
+ nn.Linear(self.channels[-1], embedding_size, bias=False),
+ )
+ init.kaiming_normal_(self.linear[0].weight, mode='fan_out', nonlinearity='relu')
+
+ self.normalize_pixel = normalize_pixel
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, image):
+ if self.normalize_pixel:
+ image = image / 255.0
+ x = self.cnn(image.float()) # (B, C, 1, 1)
+ x = torch.flatten(x, start_dim=1) # (B, C)
+ x = self.linear(x) # (B, embedding_size)
+ x = self.sim_norm(x)
+ return x
+
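+# A minimal usage sketch (illustrative only; embedding_size is chosen to be divisible by group_size):
+#   encoder = LatentEncoderForMemoryEnv(image_shape=(3, 5, 5), embedding_size=96)
+#   z = encoder(torch.randn(2, 3, 5, 5))                # -> (2, 96), SimNorm applied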
+
+class LatentDecoderForMemoryEnv(nn.Module):
+
+ def __init__(
+ self,
+ image_shape=(3, 5, 5),
+ embedding_size=256,
+ channels=[64, 32, 16],
+ kernel_sizes=[3, 3, 3],
+ strides=[1, 1, 1],
+ activation: nn.Module = nn.LeakyReLU(negative_slope=0.01),
+ **kwargs,
+ ):
+ """
+ Overview:
+ Decoder network used in UniZero in MemoryEnv. Decode the latent state into 2D image obs.
+ Arguments:
+ - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - embedding_size (:obj:`int`): The dimension of the latent state.
+ - channels (:obj:`List[int]`): The channel of output hidden state.
+ - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers.
+ - strides (:obj:`List[int]`): The stride of convolution layers.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.LeakyReLU(). \
+ Use the inplace operation to speed up.
+ """
+ super(LatentDecoderForMemoryEnv, self).__init__()
+ self.shape = image_shape
+ self.channels = list(channels) + [image_shape[0]]
+
+ self.linear = nn.Linear(embedding_size, channels[0] * image_shape[1] * image_shape[2])
+
+ layers = []
+ for i in range(len(self.channels) - 1):
+ layers.append(
+ nn.ConvTranspose2d(
+ self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i],
+ padding=kernel_sizes[i] // 2, output_padding=strides[i] - 1
+ )
+ )
+ if i < len(self.channels) - 2:
+ layers.append(nn.BatchNorm2d(self.channels[i + 1]))
+ layers.append(activation)
+ else:
+ layers.append(nn.Sigmoid())
+
+ self.deconv = nn.Sequential(*layers)
+
+ def forward(self, embedding):
+ x = self.linear(embedding)
+ x = x.view(-1, self.channels[0], self.shape[1], self.shape[2])
+ x = self.deconv(x) # (B, C, H, W)
+ return x
+
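+# A minimal usage sketch (illustrative only):
+#   decoder = LatentDecoderForMemoryEnv(image_shape=(3, 5, 5), embedding_size=256)
+#   obs_recon = decoder(torch.randn(2, 256))            # -> (2, 3, 5, 5), sigmoid output in [0, 1]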
+
+class VectorDecoderForMemoryEnv(nn.Module):
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ output_shape: SequenceType,
+ hidden_channels: int = 64,
+ layer_num: int = 2,
+ activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), # TODO
+ norm_type: Optional[str] = 'BN',
+    ) -> None:
+ """
+ Overview:
+ Decoder network used in UniZero in MemoryEnv. Decode the latent state into vector obs.
+ Arguments:
+            - embedding_dim (:obj:`int`): The dimension of the input latent state.
+            - output_shape (:obj:`int`): The dimension of the decoded vector observation, e.g. N = 10.
+            - hidden_channels (:obj:`int`): The channel of hidden layers in the MLP.
+            - layer_num (:obj:`int`): The number of layers in the MLP.
+            - activation (:obj:`nn.Module`): The activation function used in network, defaults to \
+                nn.LeakyReLU(negative_slope=0.01).
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super().__init__()
+ self.fc_representation = MLP(
+ in_channels=embedding_dim,
+ hidden_channels=hidden_channels,
+ out_channels=output_shape,
+ layer_num=layer_num,
+ activation=activation,
+ norm_type=norm_type,
+            # Omitting activation and norm in the last layer of the decoder network is important for convergence.
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=True,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+            - x (:obj:`torch.Tensor`): :math:`(B, embedding_dim)`, where B is batch size.
+            - output (:obj:`torch.Tensor`): :math:`(B, output_shape)`, where B is batch size.
+ """
+ x = self.fc_representation(x)
+ return x
+
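+# A minimal usage sketch (illustrative only):
+#   vec_decoder = VectorDecoderForMemoryEnv(embedding_dim=256, output_shape=25)
+#   obs_recon = vec_decoder(torch.randn(2, 256))        # -> (2, 25)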
+
+class PredictionNetwork(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType,
+ action_space_size: int,
+ num_res_blocks: int,
+ num_channels: int,
+ value_head_channels: int,
+ policy_head_channels: int,
+ fc_value_layers: int,
+ fc_policy_layers: int,
+ output_support_size: int,
+ flatten_output_size_for_value_head: int,
+ flatten_output_size_for_policy_head: int,
+ downsample: bool = False,
+ last_linear_layer_init_zero: bool = True,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ ) -> None:
+ """
+ Overview:
+ The definition of policy and value prediction network, which is used to predict value and policy by the
+ given latent state.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image.
+ - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space.
+ - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model.
+ - num_channels (:obj:`int`): The channels of hidden states.
+ - value_head_channels (:obj:`int`): The channels of value head.
+ - policy_head_channels (:obj:`int`): The channels of policy head.
+ - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+ - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
+ - output_support_size (:obj:`int`): The size of categorical value output.
+ - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the value head.
+ - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the policy head.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``.
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
+ dynamics/prediction mlp, default sets it to True.
+ - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \
+ operation to speedup, e.g. ReLU(inplace=True).
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super(PredictionNetwork, self).__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
+
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+
+ self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
+ self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
+
+        if observation_shape[1] == 96:
+            latent_shape = (observation_shape[1] // 16, observation_shape[2] // 16)
+        elif observation_shape[1] == 64:
+            latent_shape = (observation_shape[1] // 8, observation_shape[2] // 8)
+
+ if norm_type == 'BN':
+ self.norm_value = nn.BatchNorm2d(value_head_channels)
+ self.norm_policy = nn.BatchNorm2d(policy_head_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm_value = nn.LayerNorm(
+ [value_head_channels, *latent_shape],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, *latent_shape], eps=1e-5)
+ else:
+ self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+
+ self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
+ self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
+
+ self.activation = activation
+
+ self.fc_value = MLP(
+ in_channels=self.flatten_output_size_for_value_head,
+ hidden_channels=fc_value_layers[0],
+ out_channels=output_support_size,
+ layer_num=len(fc_value_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+ self.fc_policy = MLP(
+ in_channels=self.flatten_output_size_for_policy_head,
+ hidden_channels=fc_policy_layers[0],
+ out_channels=action_space_size,
+ layer_num=len(fc_policy_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+
+ def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Forward computation of the prediction network.
+ Arguments:
+            - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, num_channels, latent_H, latent_W).
+ Returns:
+ - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
+ - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
+ """
+ for res_block in self.resblocks:
+ latent_state = res_block(latent_state)
+
+ value = self.conv1x1_value(latent_state)
+ value = self.norm_value(value)
+ value = self.activation(value)
+
+ policy = self.conv1x1_policy(latent_state)
+ policy = self.norm_policy(policy)
+ policy = self.activation(policy)
+
+ value = value.reshape(-1, self.flatten_output_size_for_value_head)
+ policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)
+
+ value = self.fc_value(value)
+ policy = self.fc_policy(policy)
+ return policy, value
+
+
+class PredictionNetworkMLP(nn.Module):
+
+ def __init__(
+ self,
+ action_space_size,
+ num_channels,
+ common_layer_num: int = 2,
+ fc_value_layers: SequenceType = [32],
+ fc_policy_layers: SequenceType = [32],
+ output_support_size: int = 601,
+ last_linear_layer_init_zero: bool = True,
+ activation: Optional[nn.Module] = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ ):
+ """
+ Overview:
+ The definition of policy and value prediction network with Multi-Layer Perceptron (MLP),
+ which is used to predict value and policy by the given latent state.
+ Arguments:
+ - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \
+ space, it is the number of discrete actions.
+ - num_channels (:obj:`int`): The channels of latent states.
+ - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+ - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
+ - output_support_size (:obj:`int`): The size of categorical value output.
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
+ dynamics/prediction mlp, default sets it to True.
+ - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \
+ operation to speedup, e.g. ReLU(inplace=True).
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super().__init__()
+ self.num_channels = num_channels
+
+ # ******* common backbone ******
+ self.fc_prediction_common = MLP(
+ in_channels=self.num_channels,
+ hidden_channels=self.num_channels,
+ out_channels=self.num_channels,
+ layer_num=common_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ output_activation=True,
+ output_norm=True,
+ # last_linear_layer_init_zero=False is important for convergence
+ last_linear_layer_init_zero=False,
+ )
+
+ # ******* value and policy head ******
+ self.fc_value_head = MLP(
+ in_channels=self.num_channels,
+ hidden_channels=fc_value_layers[0],
+ out_channels=output_support_size,
+ layer_num=len(fc_value_layers) + 1,
+ activation=activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+ self.fc_policy_head = MLP(
+ in_channels=self.num_channels,
+ hidden_channels=fc_policy_layers[0],
+ out_channels=action_space_size,
+ layer_num=len(fc_policy_layers) + 1,
+ activation=activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+
+ def forward(self, latent_state: torch.Tensor):
+ """
+ Overview:
+ Forward computation of the prediction network.
+ Arguments:
+ - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim).
+ Returns:
+ - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
+ - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
+ """
+ x_prediction_common = self.fc_prediction_common(latent_state)
+
+ value = self.fc_value_head(x_prediction_common)
+ policy = self.fc_policy_head(x_prediction_common)
+ return policy, value
+
+
+class PredictionHiddenNetwork(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType,
+ action_space_size: int,
+ num_res_blocks: int,
+ num_channels: int,
+ value_head_channels: int,
+ policy_head_channels: int,
+ fc_value_layers: int,
+ fc_policy_layers: int,
+ output_support_size: int,
+ flatten_output_size_for_value_head: int,
+ flatten_output_size_for_policy_head: int,
+ downsample: bool = False,
+ last_linear_layer_init_zero: bool = True,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ gru_hidden_size: int = 512,
+ ) -> None:
+ """
+ Overview:
+ The definition of policy and value prediction network, which is used to predict value and policy by the
+ given latent state.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image.
+ - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space.
+ - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model.
+ - num_channels (:obj:`int`): The channels of hidden states.
+ - value_head_channels (:obj:`int`): The channels of value head.
+ - policy_head_channels (:obj:`int`): The channels of policy head.
+ - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+ - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
+ - output_support_size (:obj:`int`): The size of categorical value output.
+ - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the value head.
+ - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the policy head.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``.
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
+ dynamics/prediction mlp, default sets it to True.
+ - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \
+ operation to speedup, e.g. ReLU(inplace=True).
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super(PredictionHiddenNetwork, self).__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
+
+ self.observation_shape = observation_shape
+ self.gru_hidden_size = gru_hidden_size
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+
+ self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
+ self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
+
+ if norm_type == 'BN':
+ self.norm_value = nn.BatchNorm2d(value_head_channels)
+ self.norm_policy = nn.BatchNorm2d(policy_head_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm_value = nn.LayerNorm(
+ [value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, math.ceil(observation_shape[-2] / 16),
+ math.ceil(observation_shape[-1] / 16)], eps=1e-5)
+ else:
+ self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+
+ self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
+ self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
+
+ self.activation = activation
+
+ self.fc_value = MLP(
+ in_channels=self.flatten_output_size_for_value_head + self.gru_hidden_size,
+ hidden_channels=fc_value_layers[0],
+ out_channels=output_support_size,
+ layer_num=len(fc_value_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+ self.fc_policy = MLP(
+ in_channels=self.flatten_output_size_for_policy_head + self.gru_hidden_size,
+ hidden_channels=fc_policy_layers[0],
+ out_channels=action_space_size,
+ layer_num=len(fc_policy_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+
+ def forward(self, latent_state: torch.Tensor, world_model_latent_history: torch.Tensor) -> Tuple[
+ torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Forward computation of the prediction network.
+ Arguments:
+            - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, num_channels, latent_H, latent_W).
+            - world_model_latent_history (:obj:`torch.Tensor`): the world model's GRU hidden state with shape \
+                (num_layers * num_directions, B, gru_hidden_size).
+ Returns:
+ - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
+ - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
+ """
+ for res_block in self.resblocks:
+ latent_state = res_block(latent_state)
+
+ value = self.conv1x1_value(latent_state)
+ value = self.norm_value(value)
+ value = self.activation(value)
+
+ policy = self.conv1x1_policy(latent_state)
+ policy = self.norm_policy(policy)
+ policy = self.activation(policy)
+
+ latent_state_value = value.reshape(-1, self.flatten_output_size_for_value_head)
+ latent_state_policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)
+
+ # TODO: world_model_latent_history.squeeze(0) shape: (num_layers * num_directions, batch_size, hidden_size) -> ( batch_size, hidden_size)
+ latent_history_value = torch.cat([latent_state_value, world_model_latent_history.squeeze(0)], dim=1)
+ latent_history_policy = torch.cat([latent_state_policy, world_model_latent_history.squeeze(0)], dim=1)
+
+ value = self.fc_value(latent_history_value)
+ policy = self.fc_policy(latent_history_policy)
+ return policy, value
+
+
diff --git a/lzero/model/common_serve_input_str.py b/lzero/model/common_serve_input_str.py
new file mode 100644
index 000000000..c55a8d8e5
--- /dev/null
+++ b/lzero/model/common_serve_input_str.py
@@ -0,0 +1,1235 @@
+"""
+Overview:
+ In this Python file, we provide a collection of reusable model templates designed to streamline the development
+ process for various custom algorithms. By utilizing these pre-built model templates, users can quickly adapt and
+ customize their custom algorithms, ensuring efficient and effective development.
+    Additionally, users can refer to the unit tests of these model templates to learn how to use them.
+"""
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from ding.torch_utils import MLP, ResBlock
+from ding.utils import SequenceType
+from ditk import logging
+from openai import OpenAI
+from transformers import AutoTokenizer
+
+
+# use dataclass to make the output of network more convenient to use
+@dataclass
+class MZRNNNetworkOutput:
+ # output format of the MuZeroRNN model
+ value: torch.Tensor
+ value_prefix: torch.Tensor
+ policy_logits: torch.Tensor
+ latent_state: torch.Tensor
+ predict_next_latent_state: torch.Tensor
+ reward_hidden_state: Tuple[torch.Tensor]
+
+
+@dataclass
+class EZNetworkOutput:
+ # output format of the EfficientZero model
+ value: torch.Tensor
+ value_prefix: torch.Tensor
+ policy_logits: torch.Tensor
+ latent_state: torch.Tensor
+ reward_hidden_state: Tuple[torch.Tensor]
+
+
+@dataclass
+class MZNetworkOutput:
+ # output format of the MuZero model
+ value: torch.Tensor
+ reward: torch.Tensor
+ policy_logits: torch.Tensor
+ latent_state: torch.Tensor
+
+
+class SimNorm(nn.Module):
+
+ def __init__(self, simnorm_dim: int) -> None:
+ """
+ Overview:
+ Simplicial normalization. Adapted from https://arxiv.org/abs/2204.00616.
+ Arguments:
+ - simnorm_dim (:obj:`int`): The dimension for simplicial normalization.
+ """
+ super().__init__()
+ self.dim = simnorm_dim
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Forward pass of the SimNorm layer.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor to normalize.
+ Returns:
+ - x (:obj:`torch.Tensor`): The normalized tensor.
+ """
+ shp = x.shape
+ # Ensure that there is at least one simplex to normalize across.
+ if shp[1] != 0:
+ x = x.view(*shp[:-1], -1, self.dim)
+ x = F.softmax(x, dim=-1)
+ return x.view(*shp)
+ else:
+ return x
+
+ def __repr__(self) -> str:
+ """
+ Overview:
+ String representation of the SimNorm layer.
+ Returns:
+ - output (:obj:`str`): The string representation.
+ """
+ return f"SimNorm(dim={self.dim})"
+
+
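+# ----------------------------------------------------------------------------------------------------------------
+# Editor-added illustrative sketch (not part of the library's tested code): it shows how SimNorm partitions the
+# last dimension into groups of ``simnorm_dim`` and applies a softmax within each group, so that every group of
+# the output sums to one. The helper below is only an example and is never called by the library.
+# ----------------------------------------------------------------------------------------------------------------
+def _simnorm_usage_example() -> None:
+    sim_norm = SimNorm(simnorm_dim=8)
+    z = torch.randn(4, 64)  # (B, D); D must be divisible by the group size
+    out = sim_norm(z)  # (4, 64)
+    # Each group of 8 entries forms a probability simplex and sums to 1.
+    group_sums = out.view(4, -1, 8).sum(dim=-1)
+    assert torch.allclose(group_sums, torch.ones_like(group_sums), atol=1e-5)
+
+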
+def AvgL1Norm(x, eps=1e-8):
+ """
+ Overview:
+ Normalize the input tensor by the L1 norm.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor to normalize.
+ - eps (:obj:`float`): The epsilon value to prevent division by zero.
+ Returns:
+ - :obj:`torch.Tensor`: The normalized tensor.
+ """
+ return x / x.abs().mean(-1, keepdim=True).clamp(min=eps)
+
+
+class FeatureAndGradientHook:
+
+ def __init__(self):
+ """
+ Overview:
+ Class to capture features and gradients at SimNorm.
+ """
+ self.features_before = []
+ self.features_after = []
+ self.grads_before = []
+ self.grads_after = []
+
+ def setup_hooks(self, model):
+ # Hooks to capture features and gradients at SimNorm
+ self.forward_handler = model.sim_norm.register_forward_hook(self.forward_hook)
+ self.backward_handler = model.sim_norm.register_full_backward_hook(self.backward_hook)
+
+ def forward_hook(self, module, input, output):
+ with torch.no_grad():
+ self.features_before.append(input[0])
+ self.features_after.append(output)
+
+ def backward_hook(self, module, grad_input, grad_output):
+ with torch.no_grad():
+ self.grads_before.append(grad_input[0] if grad_input[0] is not None else None)
+ self.grads_after.append(grad_output[0] if grad_output[0] is not None else None)
+
+ def analyze(self):
+ # Calculate L2 norms of features
+ l2_norm_before = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_before]))
+ l2_norm_after = torch.mean(torch.stack([torch.norm(f, p=2, dim=1).mean() for f in self.features_after]))
+
+ # Calculate norms of gradients
+ grad_norm_before = torch.mean(
+ torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_before if g is not None]))
+ grad_norm_after = torch.mean(
+ torch.stack([torch.norm(g, p=2, dim=1).mean() for g in self.grads_after if g is not None]))
+
+ # Clear stored data and delete tensors to free memory
+ self.clear_data()
+
+ # Optionally clear CUDA cache
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ return l2_norm_before, l2_norm_after, grad_norm_before, grad_norm_after
+
+ def clear_data(self):
+ del self.features_before[:]
+ del self.features_after[:]
+ del self.grads_before[:]
+ del self.grads_after[:]
+
+ def remove_hooks(self):
+ self.forward_handler.remove()
+ self.backward_handler.remove()
+
+
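+# ----------------------------------------------------------------------------------------------------------------
+# Editor-added illustrative sketch (not part of the library's tested code): FeatureAndGradientHook registers
+# forward/backward hooks on a module's ``sim_norm`` layer and reports the L2 norms of the features and gradients
+# flowing through it. The sketch assumes any encoder exposing a ``sim_norm`` attribute, e.g.
+# ``RepresentationNetworkMLP`` defined later in this file (the name is resolved at call time).
+# ----------------------------------------------------------------------------------------------------------------
+def _feature_and_gradient_hook_usage_example() -> None:
+    encoder = RepresentationNetworkMLP(observation_shape=10, hidden_channels=64)
+    hook = FeatureAndGradientHook()
+    hook.setup_hooks(encoder)
+
+    obs = torch.randn(16, 10)
+    latent = encoder(obs)
+    latent.sum().backward()  # trigger the backward hook on sim_norm
+
+    l2_before, l2_after, grad_before, grad_after = hook.analyze()
+    print(l2_before, l2_after, grad_before, grad_after)
+    hook.remove_hooks()
+
+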
+class DownSample(nn.Module):
+
+ def __init__(self, observation_shape: SequenceType, out_channels: int,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ num_resblocks: int = 1,
+ ) -> None:
+ """
+ Overview:
+ Define downSample convolution network. Encode the observation into hidden state.
+ This network is often used in video games like Atari. In board games like go and chess,
+ we don't need this module.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[12, 96, 96]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - out_channels (:obj:`int`): The output channels of output hidden state.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \
+ Use the inplace operation to speed up.
+ - norm_type (:obj:`Optional[str]`): The normalization type used in network, defaults to 'BN'.
+ - num_resblocks (:obj:`int`): The number of residual blocks. Defaults to 1.
+ """
+ super().__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
+
+ self.observation_shape = observation_shape
+ self.conv1 = nn.Conv2d(
+ observation_shape[0],
+ out_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False, # disable bias for better convergence
+ )
+ if norm_type == 'BN':
+ self.norm1 = nn.BatchNorm2d(out_channels // 2)
+ elif norm_type == 'LN':
+ self.norm1 = nn.LayerNorm([out_channels // 2, observation_shape[-2] // 2, observation_shape[-1] // 2],
+ eps=1e-5)
+
+ self.resblocks1 = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=out_channels // 2,
+ activation=activation,
+ norm_type=norm_type,
+ res_type='basic',
+ bias=False
+ ) for _ in range(num_resblocks)
+ ]
+ )
+ self.downsample_block = ResBlock(
+ in_channels=out_channels // 2,
+ out_channels=out_channels,
+ activation=activation,
+ norm_type=norm_type,
+ res_type='downsample',
+ bias=False
+ )
+ self.resblocks2 = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_resblocks)
+ ]
+ )
+ self.pooling1 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+ self.resblocks3 = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=out_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(1)
+ ]
+ )
+ self.pooling2 = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+ self.activation = activation
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \
+ H is height.
+ - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \
+ output width, H_ is output height.
+ """
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.activation(x)
+
+ for block in self.resblocks1:
+ x = block(x)
+ x = self.downsample_block(x)
+ for block in self.resblocks2:
+ x = block(x)
+ x = self.pooling1(x)
+ for block in self.resblocks3:
+ x = block(x)
+
+ # 64, 84, 96 are the most common observation shapes in Atari games.
+ if self.observation_shape[1] == 64:
+ output = x
+ elif self.observation_shape[1] == 84:
+ x = self.pooling2(x)
+ output = x
+ elif self.observation_shape[1] == 96:
+ x = self.pooling2(x)
+ output = x
+ else:
+            raise NotImplementedError(f"DownSample for observation shape {self.observation_shape} is not implemented now. "
+                                      f"You should transform the observation shape to 64, 84 or 96 in the env.")
+
+ return output
+
+"""
+使用vllm的BAAI/bge-base-en-v1.5 server, 模型的输入为string
+"""
+class HFLanguageRepresentationNetwork(nn.Module):
+ def __init__(
+ self,
+ url: str = 'BAAI/bge-base-en-v1.5',
+ embedding_size: int = 768,
+ group_size: int = 8,
+ api_base: str = "http://10.119.30.189:8081/v1",
+ api_key: str = "EMPTY"
+ ):
+        """
+        Overview:
+            Initialize the language representation network, which obtains embeddings from a vLLM API service.
+        Arguments:
+            - url (:obj:`str`): The model name served by vLLM, defaults to 'BAAI/bge-base-en-v1.5'.
+            - embedding_size (:obj:`int`): The dimension of the output embedding, defaults to 768.
+            - group_size (:obj:`int`): The group size used by SimNorm, defaults to 8.
+            - api_base (:obj:`str`): The base URL of the vLLM API server, defaults to "http://10.119.30.189:8081/v1".
+            - api_key (:obj:`str`): The API key, defaults to "EMPTY".
+        """
+ super().__init__()
+ self.url = url
+ self.embedding_size = embedding_size
+ self.api_base = api_base
+ self.api_key = api_key
+
+        # Initialize the OpenAI client used to connect to the vLLM API server.
+ self.client = OpenAI(
+ api_key=api_key,
+ base_url=api_base,
+ )
+
+        # Query the served model ID.
+ models = self.client.models.list()
+ self.model_id = models.data[0].id if models.data else url
+
+        # Initialize a linear projection if the requested embedding size differs from the encoder's 768.
+ if self.embedding_size != 768:
+ self.embed_head = nn.Linear(768, self.embedding_size)
+ else:
+ self.embed_head = None
+
+        # Initialize SimNorm.
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+        # Initialize the tokenizer (kept for potential local tokenization; the embedding API is called with raw strings).
+ self.tokenizer = AutoTokenizer.from_pretrained(url)
+
+    def forward(self, x: List[str]) -> torch.Tensor:
+        """
+        Overview:
+            Forward pass that computes language representations for the input sequences.
+        Arguments:
+            - x (:obj:`List[str]`): A list of input strings of length batch_size.
+        Returns:
+            - torch.Tensor: The processed language embeddings with shape [batch_size, embedding_size].
+        """
+        with torch.no_grad():
+            # Request embeddings from the vLLM embedding API. The raw strings are sent directly;
+            # no local tokenization is required here.
+            response = self.client.embeddings.create(
+                input=x,
+                model=self.model_id,
+            )
+
+            # Extract the embeddings and convert them to a tensor.
+            embeddings = [data.embedding for data in response.data]  # List[List[float]]
+            # Place the tensor on the projection layer's device if one exists, otherwise keep it on CPU.
+            device = self.embed_head.weight.device if self.embed_head is not None else torch.device('cpu')
+            embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32, device=device)  # [batch_size, 768]
+
+            # Project to the target embedding size if necessary.
+            if self.embed_head is not None:
+                embeddings_tensor = self.embed_head(embeddings_tensor)  # [batch_size, embedding_size]
+
+            # Apply SimNorm.
+            embeddings_tensor = self.sim_norm(embeddings_tensor)  # [batch_size, embedding_size]
+
+        return embeddings_tensor
+
+ def __getstate__(self):
+ state = self.__dict__.copy()
+        # Remove objects that cannot be pickled.
+ del state['client']
+ del state['tokenizer']
+ return state
+
+ def __setstate__(self, state):
+ self.__dict__.update(state)
+        # Re-initialize the objects that could not be pickled.
+ self.client = OpenAI(
+ api_key=self.api_key,
+ base_url=self.api_base,
+ )
+ self.tokenizer = AutoTokenizer.from_pretrained(self.url)
+
+
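+# ----------------------------------------------------------------------------------------------------------------
+# Editor-added illustrative sketch (not part of the library's tested code): the intended call pattern of
+# ``HFLanguageRepresentationNetwork``. It assumes a vLLM server serving ``BAAI/bge-base-en-v1.5`` through the
+# OpenAI-compatible embeddings API is reachable at ``api_base``; the address below is a placeholder, not the
+# default used above.
+# ----------------------------------------------------------------------------------------------------------------
+def _hf_language_representation_usage_example() -> None:
+    encoder = HFLanguageRepresentationNetwork(
+        url='BAAI/bge-base-en-v1.5',
+        embedding_size=768,
+        api_base="http://localhost:8081/v1",  # placeholder; point this at your own vLLM server
+        api_key="EMPTY",
+    )
+    obs = ["You are standing in an open field west of a white house.", "Go north."]
+    embeddings = encoder(obs)  # (2, 768), SimNorm-normalized
+    print(embeddings.shape)
+
+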
+class RepresentationNetworkUniZero(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType = (3, 64, 64),
+ num_res_blocks: int = 1,
+ num_channels: int = 64,
+ downsample: bool = True,
+ activation: nn.Module = nn.GELU(approximate='tanh'),
+ norm_type: str = 'BN',
+ embedding_dim: int = 256,
+ group_size: int = 8,
+ ) -> None:
+ """
+ Overview:
+ Representation network used in UniZero. Encode the 2D image obs into latent state.
+            Currently, the network supports obs images with a width and height of 64, 84 or 96.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel.
+ - num_res_blocks (:obj:`int`): The number of residual blocks.
+ - num_channels (:obj:`int`): The channel of output hidden state.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
+ defaults to True. This option is often used in video games like Atari. In board games like go, \
+ we don't need this module.
+            - activation (:obj:`nn.Module`): The activation function used in network, defaults to \
+                nn.GELU(approximate='tanh').
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ - embedding_dim (:obj:`int`): The dimension of the latent state.
+ - group_size (:obj:`int`): The dimension for simplicial normalization.
+ """
+ super().__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
+ logging.info(f"Using norm type: {norm_type}")
+ logging.info(f"Using activation type: {activation}")
+
+ self.observation_shape = observation_shape
+ self.downsample = downsample
+ if self.downsample:
+ self.downsample_net = DownSample(
+ observation_shape,
+ num_channels,
+ activation=activation,
+ norm_type=norm_type,
+ )
+ else:
+ self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
+
+ if norm_type == 'BN':
+ self.norm = nn.BatchNorm2d(num_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm = nn.LayerNorm(
+ [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
+ eps=1e-5)
+ else:
+ self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)
+
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+ self.activation = activation
+ self.embedding_dim = embedding_dim
+
+ if self.observation_shape[1] == 64:
+ self.last_linear = nn.Linear(64 * 8 * 8, self.embedding_dim, bias=False)
+
+ elif self.observation_shape[1] in [84, 96]:
+ self.last_linear = nn.Linear(64 * 6 * 6, self.embedding_dim, bias=False)
+
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \
+ H is height.
+ - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \
+ output width, H_ is output height.
+ """
+ if self.downsample:
+ x = self.downsample_net(x)
+ else:
+ x = self.conv(x)
+ x = self.norm(x)
+ x = self.activation(x)
+ for block in self.resblocks:
+ x = block(x)
+
+ # Important: Transform the output feature plane to the latent state.
+ # For example, for an Atari feature plane of shape (64, 8, 8),
+ # flattening results in a size of 4096, which is then transformed to 768.
+ x = self.last_linear(x.view(x.size(0), -1))
+
+ x = x.view(-1, self.embedding_dim)
+
+ # NOTE: very important for training stability.
+ x = self.sim_norm(x)
+
+ return x
+
+
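+# ----------------------------------------------------------------------------------------------------------------
+# Editor-added illustrative sketch (not part of the library's tested code): encoding a batch of 64x64 RGB frames
+# into a SimNorm-normalized latent state. With downsampling enabled, a (3, 64, 64) observation is reduced to a
+# (64, 8, 8) feature plane, which is flattened (4096) and projected to ``embedding_dim``.
+# ----------------------------------------------------------------------------------------------------------------
+def _representation_network_unizero_usage_example() -> None:
+    encoder = RepresentationNetworkUniZero(
+        observation_shape=(3, 64, 64),
+        num_res_blocks=1,
+        num_channels=64,
+        downsample=True,
+        norm_type='BN',
+        embedding_dim=768,
+        group_size=8,
+    )
+    obs = torch.randn(4, 3, 64, 64)
+    latent = encoder(obs)
+    print(latent.shape)  # torch.Size([4, 768])
+
+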
+class RepresentationNetwork(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType = (4, 96, 96),
+ num_res_blocks: int = 1,
+ num_channels: int = 64,
+ downsample: bool = True,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: str = 'BN',
+ embedding_dim: int = 256,
+ group_size: int = 8,
+ use_sim_norm: bool = False,
+ ) -> None:
+ """
+ Overview:
+ Representation network used in MuZero and derived algorithms. Encode the 2D image obs into latent state.
+            Currently, the network supports obs images with a width and height of 64, 84 or 96.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[4, 96, 96]
+ for video games like atari, 1 gray channel times stack 4 frames.
+ - num_res_blocks (:obj:`int`): The number of residual blocks.
+ - num_channels (:obj:`int`): The channel of output hidden state.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``, \
+ defaults to True. This option is often used in video games like Atari. In board games like go, \
+ we don't need this module.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.ReLU(inplace=True). \
+ Use the inplace operation to speed up.
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ - embedding_dim (:obj:`int`): The dimension of the output hidden state.
+ - group_size (:obj:`int`): The size of group in the SimNorm layer.
+ - use_sim_norm (:obj:`bool`): Whether to use SimNorm layer, defaults to False.
+ """
+ super().__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
+
+ self.downsample = downsample
+ if self.downsample:
+ self.downsample_net = DownSample(
+ observation_shape,
+ num_channels,
+ activation=activation,
+ norm_type=norm_type,
+ )
+ else:
+ self.conv = nn.Conv2d(observation_shape[0], num_channels, kernel_size=3, stride=1, padding=1, bias=False)
+
+ if norm_type == 'BN':
+ self.norm = nn.BatchNorm2d(num_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm = nn.LayerNorm(
+ [num_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
+ eps=1e-5)
+ else:
+ self.norm = nn.LayerNorm([num_channels, observation_shape[-2], observation_shape[-1]], eps=1e-5)
+
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+ self.activation = activation
+
+ self.use_sim_norm = use_sim_norm
+
+ if self.use_sim_norm:
+ self.embedding_dim = embedding_dim
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, C_in, W, H)`, where B is batch size, C_in is channel, W is width, \
+ H is height.
+ - output (:obj:`torch.Tensor`): :math:`(B, C_out, W_, H_)`, where B is batch size, C_out is channel, W_ is \
+ output width, H_ is output height.
+ """
+ if self.downsample:
+ x = self.downsample_net(x)
+ else:
+ x = self.conv(x)
+ x = self.norm(x)
+ x = self.activation(x)
+
+ for block in self.resblocks:
+ x = block(x)
+
+ if self.use_sim_norm:
+ # NOTE: very important.
+ # for atari 64,8,8 = 4096 -> 768
+ x = self.sim_norm(x)
+
+ return x
+
+
+class RepresentationNetworkMLP(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: int,
+ hidden_channels: int = 64,
+ layer_num: int = 2,
+ activation: nn.Module = nn.GELU(approximate='tanh'),
+ norm_type: Optional[str] = 'BN',
+ group_size: int = 8,
+    ) -> None:
+ """
+ Overview:
+ Representation network used in MuZero and derived algorithms. Encode the vector obs into latent state \
+ with Multi-Layer Perceptron (MLP).
+ Arguments:
+            - observation_shape (:obj:`int`): The shape of vector observation space, e.g. N = 10.
+            - hidden_channels (:obj:`int`): The channel of output hidden state.
+            - layer_num (:obj:`int`): The number of layers in the MLP.
+            - activation (:obj:`nn.Module`): The activation function used in network, defaults to \
+                nn.GELU(approximate='tanh').
+            - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+            - group_size (:obj:`int`): The group size used by SimNorm.
+ """
+ super().__init__()
+ self.fc_representation = MLP(
+ in_channels=observation_shape,
+ hidden_channels=hidden_channels,
+ out_channels=hidden_channels,
+ layer_num=layer_num,
+ activation=activation,
+ norm_type=norm_type,
+            # Not using activation and norm in the last layer of the representation network is important for convergence.
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=True,
+ )
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation.
+ - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
+ """
+ x = self.fc_representation(x)
+        # TODO: confirm whether SimNorm is needed here; it is applied for training stability, as in RepresentationNetworkUniZero.
+ x = self.sim_norm(x)
+ return x
+
+
+class LatentDecoder(nn.Module):
+
+ def __init__(self, embedding_dim: int, output_shape: SequenceType, num_channels: int = 64, activation: nn.Module = nn.GELU(approximate='tanh')):
+ """
+ Overview:
+ Decoder network used in UniZero. Decode the latent state into 2D image obs.
+ Arguments:
+ - embedding_dim (:obj:`int`): The dimension of the latent state.
+ - output_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - num_channels (:obj:`int`): The channel of output hidden state.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh').
+ """
+ super().__init__()
+ self.embedding_dim = embedding_dim
+ self.output_shape = output_shape # (C, H, W)
+ self.num_channels = num_channels
+ self.activation = activation
+
+ # Assuming that the output shape is (C, H, W) = (12, 96, 96) and embedding_dim is 256
+ # We will reverse the process of the representation network
+ self.initial_size = (
+ num_channels, output_shape[1] // 8, output_shape[2] // 8) # This should match the last layer of the encoder
+ self.fc = nn.Linear(self.embedding_dim, np.prod(self.initial_size))
+
+ # Upsampling blocks
+ self.conv_blocks = nn.ModuleList([
+ # Block 1: (num_channels, H/8, W/8) -> (num_channels//2, H/4, W/4)
+ nn.ConvTranspose2d(num_channels, num_channels // 2, kernel_size=3, stride=2, padding=1, output_padding=1),
+ self.activation,
+ nn.BatchNorm2d(num_channels // 2),
+ # Block 2: (num_channels//2, H/4, W/4) -> (num_channels//4, H/2, W/2)
+ nn.ConvTranspose2d(num_channels // 2, num_channels // 4, kernel_size=3, stride=2, padding=1,
+ output_padding=1),
+ self.activation,
+ nn.BatchNorm2d(num_channels // 4),
+ # Block 3: (num_channels//4, H/2, W/2) -> (output_shape[0], H, W)
+ nn.ConvTranspose2d(num_channels // 4, output_shape[0], kernel_size=3, stride=2, padding=1,
+ output_padding=1),
+ ])
+ # TODO: last layer use sigmoid?
+
+ def forward(self, embeddings: torch.Tensor) -> torch.Tensor:
+ # Map embeddings back to the image space
+ x = self.fc(embeddings) # (B, embedding_dim) -> (B, C*H/8*W/8)
+ x = x.view(-1, *self.initial_size) # (B, C*H/8*W/8) -> (B, C, H/8, W/8)
+
+ # Apply conv blocks
+ for block in self.conv_blocks:
+ x = block(x) # Upsample progressively
+
+ # The output x should have the shape of (B, output_shape[0], output_shape[1], output_shape[2])
+ return x
+
+
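+# ----------------------------------------------------------------------------------------------------------------
+# Editor-added illustrative sketch (not part of the library's tested code): decoding a latent state back to image
+# space with ``LatentDecoder``. It mirrors the encoder above: a 768-dim embedding is mapped to a (64, 8, 8)
+# feature plane and upsampled by three transposed convolutions to the (3, 64, 64) observation shape.
+# ----------------------------------------------------------------------------------------------------------------
+def _latent_decoder_usage_example() -> None:
+    decoder = LatentDecoder(embedding_dim=768, output_shape=(3, 64, 64), num_channels=64)
+    latent = torch.randn(4, 768)
+    reconstruction = decoder(latent)
+    print(reconstruction.shape)  # torch.Size([4, 3, 64, 64])
+
+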
+class LatentEncoderForMemoryEnv(nn.Module):
+
+ def __init__(
+ self,
+ image_shape=(3, 5, 5),
+ embedding_size=100,
+ channels=[16, 32, 64],
+ kernel_sizes=[3, 3, 3],
+ strides=[1, 1, 1],
+ activation: nn.Module = nn.GELU(approximate='tanh'),
+ normalize_pixel=False,
+ group_size: int = 8,
+ **kwargs,
+ ):
+ """
+ Overview:
+ Encoder network used in UniZero in MemoryEnv. Encode the 2D image obs into latent state.
+ Arguments:
+ - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - embedding_size (:obj:`int`): The dimension of the latent state.
+ - channels (:obj:`List[int]`): The channel of output hidden state.
+ - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers.
+ - strides (:obj:`List[int]`): The stride of convolution layers.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.GELU(approximate='tanh'). \
+ Use the inplace operation to speed up.
+ - normalize_pixel (:obj:`bool`): Whether to normalize the pixel values to [0, 1], defaults to False.
+ - group_size (:obj:`int`): The dimension for simplicial normalization
+ """
+ super(LatentEncoderForMemoryEnv, self).__init__()
+ self.shape = image_shape
+ self.channels = [image_shape[0]] + list(channels)
+
+ layers = []
+ for i in range(len(self.channels) - 1):
+ layers.append(
+ nn.Conv2d(
+ self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i],
+ padding=kernel_sizes[i] // 2 # keep the same size of feature map
+ )
+ )
+ layers.append(nn.BatchNorm2d(self.channels[i + 1]))
+ layers.append(activation)
+
+ layers.append(nn.AdaptiveAvgPool2d(1))
+
+ self.cnn = nn.Sequential(*layers)
+ self.linear = nn.Sequential(
+ nn.Linear(self.channels[-1], embedding_size, bias=False),
+ )
+ init.kaiming_normal_(self.linear[0].weight, mode='fan_out', nonlinearity='relu')
+
+ self.normalize_pixel = normalize_pixel
+ self.sim_norm = SimNorm(simnorm_dim=group_size)
+
+ def forward(self, image):
+ if self.normalize_pixel:
+ image = image / 255.0
+ x = self.cnn(image.float()) # (B, C, 1, 1)
+ x = torch.flatten(x, start_dim=1) # (B, C)
+ x = self.linear(x) # (B, embedding_size)
+ x = self.sim_norm(x)
+ return x
+
+
+class LatentDecoderForMemoryEnv(nn.Module):
+
+ def __init__(
+ self,
+ image_shape=(3, 5, 5),
+ embedding_size=256,
+ channels=[64, 32, 16],
+ kernel_sizes=[3, 3, 3],
+ strides=[1, 1, 1],
+ activation: nn.Module = nn.LeakyReLU(negative_slope=0.01),
+ **kwargs,
+ ):
+ """
+ Overview:
+ Decoder network used in UniZero in MemoryEnv. Decode the latent state into 2D image obs.
+ Arguments:
+ - image_shape (:obj:`SequenceType`): The shape of observation space, e.g. [C, W, H]=[3, 64, 64]
+ for video games like atari, RGB 3 channel times stack 4 frames.
+ - embedding_size (:obj:`int`): The dimension of the latent state.
+ - channels (:obj:`List[int]`): The channel of output hidden state.
+ - kernel_sizes (:obj:`List[int]`): The kernel size of convolution layers.
+ - strides (:obj:`List[int]`): The stride of convolution layers.
+ - activation (:obj:`nn.Module`): The activation function used in network, defaults to nn.LeakyReLU(). \
+ Use the inplace operation to speed up.
+ """
+ super(LatentDecoderForMemoryEnv, self).__init__()
+ self.shape = image_shape
+ self.channels = list(channels) + [image_shape[0]]
+
+ self.linear = nn.Linear(embedding_size, channels[0] * image_shape[1] * image_shape[2])
+
+ layers = []
+ for i in range(len(self.channels) - 1):
+ layers.append(
+ nn.ConvTranspose2d(
+ self.channels[i], self.channels[i + 1], kernel_sizes[i], strides[i],
+ padding=kernel_sizes[i] // 2, output_padding=strides[i] - 1
+ )
+ )
+ if i < len(self.channels) - 2:
+ layers.append(nn.BatchNorm2d(self.channels[i + 1]))
+ layers.append(activation)
+ else:
+ layers.append(nn.Sigmoid())
+
+ self.deconv = nn.Sequential(*layers)
+
+ def forward(self, embedding):
+ x = self.linear(embedding)
+ x = x.view(-1, self.channels[0], self.shape[1], self.shape[2])
+ x = self.deconv(x) # (B, C, H, W)
+ return x
+
+
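+# ----------------------------------------------------------------------------------------------------------------
+# Editor-added illustrative sketch (not part of the library's tested code): a round trip through the MemoryEnv
+# encoder/decoder pair. Note that ``embedding_size`` must be divisible by the SimNorm ``group_size`` (8 here),
+# so this sketch uses 64 rather than the encoder's default of 100.
+# ----------------------------------------------------------------------------------------------------------------
+def _memory_env_encoder_decoder_usage_example() -> None:
+    encoder = LatentEncoderForMemoryEnv(image_shape=(3, 5, 5), embedding_size=64, group_size=8)
+    decoder = LatentDecoderForMemoryEnv(image_shape=(3, 5, 5), embedding_size=64)
+    image = torch.rand(4, 3, 5, 5)
+    latent = encoder(image)  # (4, 64), SimNorm-normalized
+    reconstruction = decoder(latent)  # (4, 3, 5, 5), in [0, 1] due to the final Sigmoid
+    print(latent.shape, reconstruction.shape)
+
+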
+class VectorDecoderForMemoryEnv(nn.Module):
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ output_shape: SequenceType,
+ hidden_channels: int = 64,
+ layer_num: int = 2,
+ activation: nn.Module = nn.LeakyReLU(negative_slope=0.01), # TODO
+ norm_type: Optional[str] = 'BN',
+    ) -> None:
+ """
+ Overview:
+ Decoder network used in UniZero in MemoryEnv. Decode the latent state into vector obs.
+ Arguments:
+            - embedding_dim (:obj:`int`): The dimension of the input latent state.
+            - output_shape (:obj:`SequenceType`): The shape (length) of the decoded vector observation, e.g. N = 25.
+            - hidden_channels (:obj:`int`): The channel of the hidden layers.
+            - layer_num (:obj:`int`): The number of layers in the MLP.
+            - activation (:obj:`nn.Module`): The activation function used in network, defaults to \
+                nn.LeakyReLU(negative_slope=0.01).
+            - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super().__init__()
+ self.fc_representation = MLP(
+ in_channels=embedding_dim,
+ hidden_channels=hidden_channels,
+ out_channels=output_shape,
+ layer_num=layer_num,
+ activation=activation,
+ norm_type=norm_type,
+            # Not using activation and norm in the last layer of the representation network is important for convergence.
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=True,
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size, N is the length of vector observation.
+ - output (:obj:`torch.Tensor`): :math:`(B, hidden_channels)`, where B is batch size.
+ """
+ x = self.fc_representation(x)
+ return x
+
+
+class PredictionNetwork(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType,
+ action_space_size: int,
+ num_res_blocks: int,
+ num_channels: int,
+ value_head_channels: int,
+ policy_head_channels: int,
+ fc_value_layers: int,
+ fc_policy_layers: int,
+ output_support_size: int,
+ flatten_output_size_for_value_head: int,
+ flatten_output_size_for_policy_head: int,
+ downsample: bool = False,
+ last_linear_layer_init_zero: bool = True,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ ) -> None:
+ """
+ Overview:
+ The definition of policy and value prediction network, which is used to predict value and policy by the
+ given latent state.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image.
+ - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space.
+ - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model.
+ - num_channels (:obj:`int`): The channels of hidden states.
+ - value_head_channels (:obj:`int`): The channels of value head.
+ - policy_head_channels (:obj:`int`): The channels of policy head.
+ - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+ - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
+ - output_support_size (:obj:`int`): The size of categorical value output.
+ - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the value head.
+ - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the policy head.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``.
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
+ dynamics/prediction mlp, default sets it to True.
+ - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \
+ operation to speedup, e.g. ReLU(inplace=True).
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super(PredictionNetwork, self).__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
+
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+
+ self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
+ self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
+
+        if observation_shape[1] == 96:
+            latent_shape = (observation_shape[1] // 16, observation_shape[2] // 16)
+        elif observation_shape[1] == 64:
+            latent_shape = (observation_shape[1] // 8, observation_shape[2] // 8)
+
+ if norm_type == 'BN':
+ self.norm_value = nn.BatchNorm2d(value_head_channels)
+ self.norm_policy = nn.BatchNorm2d(policy_head_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm_value = nn.LayerNorm(
+ [value_head_channels, *latent_shape],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, *latent_shape], eps=1e-5)
+ else:
+ self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+
+ self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
+ self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
+
+ self.activation = activation
+
+ self.fc_value = MLP(
+ in_channels=self.flatten_output_size_for_value_head,
+ hidden_channels=fc_value_layers[0],
+ out_channels=output_support_size,
+ layer_num=len(fc_value_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+ self.fc_policy = MLP(
+ in_channels=self.flatten_output_size_for_policy_head,
+ hidden_channels=fc_policy_layers[0],
+ out_channels=action_space_size,
+ layer_num=len(fc_policy_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+
+ def forward(self, latent_state: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Forward computation of the prediction network.
+ Arguments:
+ - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim).
+ Returns:
+ - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
+ - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
+ """
+ for res_block in self.resblocks:
+ latent_state = res_block(latent_state)
+
+ value = self.conv1x1_value(latent_state)
+ value = self.norm_value(value)
+ value = self.activation(value)
+
+ policy = self.conv1x1_policy(latent_state)
+ policy = self.norm_policy(policy)
+ policy = self.activation(policy)
+
+ value = value.reshape(-1, self.flatten_output_size_for_value_head)
+ policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)
+
+ value = self.fc_value(value)
+ policy = self.fc_policy(policy)
+ return policy, value
+
+
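+# ----------------------------------------------------------------------------------------------------------------
+# Editor-added illustrative sketch (not part of the library's tested code): it shows how the flattened head sizes
+# of ``PredictionNetwork`` relate to the latent feature plane. For an Atari-style latent plane of shape
+# (num_channels, 6, 6), i.e. 96x96 observations downsampled by 16, the flatten sizes are
+# value_head_channels * 6 * 6 and policy_head_channels * 6 * 6. The numbers below are example values, not a
+# tested configuration.
+# ----------------------------------------------------------------------------------------------------------------
+def _prediction_network_usage_example() -> None:
+    network = PredictionNetwork(
+        observation_shape=(3, 96, 96),
+        action_space_size=6,
+        num_res_blocks=1,
+        num_channels=64,
+        value_head_channels=16,
+        policy_head_channels=16,
+        fc_value_layers=[32],
+        fc_policy_layers=[32],
+        output_support_size=601,
+        flatten_output_size_for_value_head=16 * 6 * 6,
+        flatten_output_size_for_policy_head=16 * 6 * 6,
+        downsample=True,
+        norm_type='BN',
+    )
+    latent_state = torch.randn(4, 64, 6, 6)
+    policy_logits, value_logits = network(latent_state)
+    print(policy_logits.shape, value_logits.shape)  # torch.Size([4, 6]) torch.Size([4, 601])
+
+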
+class PredictionNetworkMLP(nn.Module):
+
+ def __init__(
+ self,
+ action_space_size,
+ num_channels,
+ common_layer_num: int = 2,
+ fc_value_layers: SequenceType = [32],
+ fc_policy_layers: SequenceType = [32],
+ output_support_size: int = 601,
+ last_linear_layer_init_zero: bool = True,
+ activation: Optional[nn.Module] = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ ):
+ """
+ Overview:
+ The definition of policy and value prediction network with Multi-Layer Perceptron (MLP),
+ which is used to predict value and policy by the given latent state.
+ Arguments:
+ - action_space_size: (:obj:`int`): Action space size, usually an integer number. For discrete action \
+ space, it is the number of discrete actions.
+ - num_channels (:obj:`int`): The channels of latent states.
+ - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+ - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
+ - output_support_size (:obj:`int`): The size of categorical value output.
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
+ dynamics/prediction mlp, default sets it to True.
+ - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \
+ operation to speedup, e.g. ReLU(inplace=True).
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super().__init__()
+ self.num_channels = num_channels
+
+ # ******* common backbone ******
+ self.fc_prediction_common = MLP(
+ in_channels=self.num_channels,
+ hidden_channels=self.num_channels,
+ out_channels=self.num_channels,
+ layer_num=common_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ output_activation=True,
+ output_norm=True,
+ # last_linear_layer_init_zero=False is important for convergence
+ last_linear_layer_init_zero=False,
+ )
+
+ # ******* value and policy head ******
+ self.fc_value_head = MLP(
+ in_channels=self.num_channels,
+ hidden_channels=fc_value_layers[0],
+ out_channels=output_support_size,
+ layer_num=len(fc_value_layers) + 1,
+ activation=activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+ self.fc_policy_head = MLP(
+ in_channels=self.num_channels,
+ hidden_channels=fc_policy_layers[0],
+ out_channels=action_space_size,
+ layer_num=len(fc_policy_layers) + 1,
+ activation=activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+
+ def forward(self, latent_state: torch.Tensor):
+ """
+ Overview:
+ Forward computation of the prediction network.
+ Arguments:
+ - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim).
+ Returns:
+ - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
+ - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
+ """
+ x_prediction_common = self.fc_prediction_common(latent_state)
+
+ value = self.fc_value_head(x_prediction_common)
+ policy = self.fc_policy_head(x_prediction_common)
+ return policy, value
+
+
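+# ----------------------------------------------------------------------------------------------------------------
+# Editor-added illustrative sketch (not part of the library's tested code): the MLP variant of the prediction
+# head operating directly on a vector latent state.
+# ----------------------------------------------------------------------------------------------------------------
+def _prediction_network_mlp_usage_example() -> None:
+    network = PredictionNetworkMLP(action_space_size=4, num_channels=64, output_support_size=601)
+    latent_state = torch.randn(8, 64)
+    policy_logits, value_logits = network(latent_state)
+    print(policy_logits.shape, value_logits.shape)  # torch.Size([8, 4]) torch.Size([8, 601])
+
+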
+class PredictionHiddenNetwork(nn.Module):
+
+ def __init__(
+ self,
+ observation_shape: SequenceType,
+ action_space_size: int,
+ num_res_blocks: int,
+ num_channels: int,
+ value_head_channels: int,
+ policy_head_channels: int,
+ fc_value_layers: int,
+ fc_policy_layers: int,
+ output_support_size: int,
+ flatten_output_size_for_value_head: int,
+ flatten_output_size_for_policy_head: int,
+ downsample: bool = False,
+ last_linear_layer_init_zero: bool = True,
+ activation: nn.Module = nn.ReLU(inplace=True),
+ norm_type: Optional[str] = 'BN',
+ gru_hidden_size: int = 512,
+ ) -> None:
+ """
+ Overview:
+ The definition of policy and value prediction network, which is used to predict value and policy by the
+ given latent state.
+ Arguments:
+ - observation_shape (:obj:`SequenceType`): The shape of observation space, e.g. (C, H, W) for image.
+ - action_space_size: (:obj:`int`): Action space size, usually an integer number for discrete action space.
+ - num_res_blocks (:obj:`int`): The number of res blocks in AlphaZero model.
+ - num_channels (:obj:`int`): The channels of hidden states.
+ - value_head_channels (:obj:`int`): The channels of value head.
+ - policy_head_channels (:obj:`int`): The channels of policy head.
+ - fc_value_layers (:obj:`SequenceType`): The number of hidden layers used in value head (MLP head).
+ - fc_policy_layers (:obj:`SequenceType`): The number of hidden layers used in policy head (MLP head).
+ - output_support_size (:obj:`int`): The size of categorical value output.
+            - gru_hidden_size (:obj:`int`): The hidden size of the world model's GRU latent history that is \
+                concatenated to the value/policy head inputs, defaults to 512.
+ - flatten_output_size_for_value_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the value head.
+ - flatten_output_size_for_policy_head (:obj:`int`): The size of flatten hidden states, i.e. the input size \
+ of the policy head.
+ - downsample (:obj:`bool`): Whether to do downsampling for observations in ``representation_network``.
+ - last_linear_layer_init_zero (:obj:`bool`): Whether to use zero initializations for the last layer of \
+ dynamics/prediction mlp, default sets it to True.
+ - activation (:obj:`Optional[nn.Module]`): Activation function used in network, which often use in-place \
+ operation to speedup, e.g. ReLU(inplace=True).
+ - norm_type (:obj:`str`): The type of normalization in networks. defaults to 'BN'.
+ """
+ super(PredictionHiddenNetwork, self).__init__()
+        assert norm_type in ['BN', 'LN'], "norm_type must be in ['BN', 'LN']"
+
+ self.observation_shape = observation_shape
+ self.gru_hidden_size = gru_hidden_size
+ self.resblocks = nn.ModuleList(
+ [
+ ResBlock(
+ in_channels=num_channels, activation=activation, norm_type=norm_type, res_type='basic', bias=False
+ ) for _ in range(num_res_blocks)
+ ]
+ )
+
+ self.conv1x1_value = nn.Conv2d(num_channels, value_head_channels, 1)
+ self.conv1x1_policy = nn.Conv2d(num_channels, policy_head_channels, 1)
+
+ if norm_type == 'BN':
+ self.norm_value = nn.BatchNorm2d(value_head_channels)
+ self.norm_policy = nn.BatchNorm2d(policy_head_channels)
+ elif norm_type == 'LN':
+ if downsample:
+ self.norm_value = nn.LayerNorm(
+ [value_head_channels, math.ceil(observation_shape[-2] / 16), math.ceil(observation_shape[-1] / 16)],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, math.ceil(observation_shape[-2] / 16),
+ math.ceil(observation_shape[-1] / 16)], eps=1e-5)
+ else:
+ self.norm_value = nn.LayerNorm([value_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+ self.norm_policy = nn.LayerNorm([policy_head_channels, observation_shape[-2], observation_shape[-1]],
+ eps=1e-5)
+
+ self.flatten_output_size_for_value_head = flatten_output_size_for_value_head
+ self.flatten_output_size_for_policy_head = flatten_output_size_for_policy_head
+
+ self.activation = activation
+
+ self.fc_value = MLP(
+ in_channels=self.flatten_output_size_for_value_head + self.gru_hidden_size,
+ hidden_channels=fc_value_layers[0],
+ out_channels=output_support_size,
+ layer_num=len(fc_value_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+ self.fc_policy = MLP(
+ in_channels=self.flatten_output_size_for_policy_head + self.gru_hidden_size,
+ hidden_channels=fc_policy_layers[0],
+ out_channels=action_space_size,
+ layer_num=len(fc_policy_layers) + 1,
+ activation=self.activation,
+ norm_type=norm_type,
+ output_activation=False,
+ output_norm=False,
+ # last_linear_layer_init_zero=True is beneficial for convergence speed.
+ last_linear_layer_init_zero=last_linear_layer_init_zero
+ )
+
+ def forward(self, latent_state: torch.Tensor, world_model_latent_history: torch.Tensor) -> Tuple[
+ torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Forward computation of the prediction network.
+ Arguments:
+ - latent_state (:obj:`torch.Tensor`): input tensor with shape (B, latent_state_dim).
+ Returns:
+ - policy (:obj:`torch.Tensor`): policy tensor with shape (B, action_space_size).
+ - value (:obj:`torch.Tensor`): value tensor with shape (B, output_support_size).
+ """
+ for res_block in self.resblocks:
+ latent_state = res_block(latent_state)
+
+ value = self.conv1x1_value(latent_state)
+ value = self.norm_value(value)
+ value = self.activation(value)
+
+ policy = self.conv1x1_policy(latent_state)
+ policy = self.norm_policy(policy)
+ policy = self.activation(policy)
+
+ latent_state_value = value.reshape(-1, self.flatten_output_size_for_value_head)
+ latent_state_policy = policy.reshape(-1, self.flatten_output_size_for_policy_head)
+
+        # NOTE: world_model_latent_history has shape (num_layers * num_directions, batch_size, hidden_size); squeeze(0) yields (batch_size, hidden_size).
+ latent_history_value = torch.cat([latent_state_value, world_model_latent_history.squeeze(0)], dim=1)
+ latent_history_policy = torch.cat([latent_state_policy, world_model_latent_history.squeeze(0)], dim=1)
+
+ value = self.fc_value(latent_history_value)
+ policy = self.fc_policy(latent_history_policy)
+ return policy, value
+
+
diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py
index e28322215..239e7efb3 100644
--- a/lzero/model/unizero_model.py
+++ b/lzero/model/unizero_model.py
@@ -6,7 +6,8 @@
from easydict import EasyDict
from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \
- VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook
+ VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \
+ HFLanguageRepresentationNetwork
from .unizero_world_models.tokenizer import Tokenizer
from .unizero_world_models.world_model import WorldModel
@@ -87,6 +88,15 @@ def __init__(
print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer')
print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
print('==' * 20)
+ elif world_model_cfg.obs_type == 'text':
+ self.representation_network = HFLanguageRepresentationNetwork(url=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim)
+ self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False,)
+ self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
+ print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
+ print('==' * 20)
+ print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer')
+ print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
+ print('==' * 20)
elif world_model_cfg.obs_type == 'image':
self.representation_network = RepresentationNetworkUniZero(
observation_shape,
diff --git a/lzero/model/unizero_world_models/hf_transformer.py b/lzero/model/unizero_world_models/hf_transformer.py
new file mode 100644
index 000000000..d0ef762be
--- /dev/null
+++ b/lzero/model/unizero_world_models/hf_transformer.py
@@ -0,0 +1,110 @@
+from typing import Optional
+
+import torch
+from transformers import LlamaForCausalLM
+from transformers.cache_utils import DynamicCache
+
+from .kv_caching import KeysValues
+
+
+def kv2dc(cache: KeysValues):
+ res = DynamicCache()
+ for kv_cache in cache:
+ k_tensor = kv_cache._k_cache.get()
+ v_tensor = kv_cache._v_cache.get()
+ res.key_cache.append(k_tensor)
+ res.value_cache.append(v_tensor)
+ return res
+
+
+def update_kv(cache: KeysValues, new_cache: DynamicCache):
+    # Write the per-layer key/value tensors from the DynamicCache back into the KeysValues cache.
+    for i in range(len(new_cache.key_cache)):
+        cache[i].update(new_cache.key_cache[i], new_cache.value_cache[i])
+
+
+class HuggingfaceLlamaTransformer(LlamaForCausalLM):
+ @classmethod
+ def from_pretrained(cls, lzero_config, *args, **kwargs):
+ # Add custom logic here
+ model = super(HuggingfaceLlamaTransformer, cls).from_pretrained(*args, **kwargs)
+ model.lzero_config = lzero_config
+ return model
+
+ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
+ """
+ Generate a placeholder for keys and values.
+
+ Arguments:
+ - n (:obj:`int`): Batch size.
+ - max_tokens (:obj:`int`): Maximum number of tokens in the sequence.
+
+ Returns:
+ - KeysValues: An object containing empty keys and values.
+ """
+ device = self.lzero_config.device # Assumption: All submodules are on the same device
+ return KeysValues(n, self.lzero_config.num_heads, max_tokens,
+ self.lzero_config.embed_dim, self.lzero_config.num_layers,
+ device, self.lzero_config.hidden_size)
+
+ def _get_positional_embedding(self, layer, attn_type, pos_emb) -> torch.Tensor:
+ """
+ Helper function to get positional embedding for a given layer and attention type.
+
+ Arguments:
+ - layer (:obj:`int`): Layer index.
+        - attn_type (:obj:`str`): Attention type, one of 'key', 'value' or 'query'.
+
+ Returns:
+ - torch.Tensor: The positional embedding tensor.
+ """
+ if attn_type == 'key':
+ module_name = 'k_proj'
+ elif attn_type == 'value':
+ module_name = 'v_proj'
+ elif attn_type == 'query':
+ module_name = 'q_proj'
+        else:
+            raise ValueError(f"Unsupported attn_type: {attn_type}")
+ attn_func = getattr(self.model.layers[layer].self_attn, module_name)
+ return attn_func(pos_emb.weight)
+
+ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
+ valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Forward pass of the Transformer model.
+
+ Arguments:
+ - sequences (:obj:`torch.Tensor`): Input tensor of shape (batch_size, seq_length, embed_dim).
+ - past_keys_values (:obj:`Optional[KeysValues]`): Precomputed keys and values for faster generation (default: None).
+ - valid_context_lengths (:obj:`Optional[torch.Tensor]`): Valid lengths of context for masking (default: None).
+
+ Returns:
+ - torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
+ """
+ assert past_keys_values is None or len(past_keys_values) == len(self.model.layers)
+ if past_keys_values is not None:
+ kv_cache = kv2dc(past_keys_values)
+ use_cache = True
+ else:
+ kv_cache = None
+ use_cache = False
+
+        B, T, _ = sequences.shape
+        if valid_context_lengths is not None:
+            # Mask out the positions that fall outside the valid context window of each sample.
+            attention_mask = (torch.arange(T, device=sequences.device).expand(B, T)
+                              >= (T - valid_context_lengths.unsqueeze(1))).long()
+        else:
+            # Attend to all positions: the mask must be of shape (B, T), not the shape of the embeddings.
+            attention_mask = torch.ones(B, T, dtype=torch.long, device=sequences.device)
+
+ output = self.model.forward(
+ attention_mask=attention_mask,
+ past_key_values=kv_cache,
+ inputs_embeds=sequences,
+ use_cache=use_cache
+ )
+
+        if past_keys_values is not None:
+            update_kv(past_keys_values, kv_cache)
+        return output.logits[:, -1, :]
diff --git a/lzero/model/unizero_world_models/kv_caching.py b/lzero/model/unizero_world_models/kv_caching.py
index 28b7b0ba2..7ede62197 100644
--- a/lzero/model/unizero_world_models/kv_caching.py
+++ b/lzero/model/unizero_world_models/kv_caching.py
@@ -1,13 +1,14 @@
# Modified from https://github.com/eloialonso/iris/blob/main/src/models/kv_caching.py
-from typing import Tuple
+from typing import Tuple, Optional
import numpy as np
import torch
class Cache:
- def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
+ def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device,
+ hidden_size: Optional[int]) -> None:
"""
Overview:
Cache for storing intermediate results in a transformer model.
@@ -20,7 +21,9 @@ def __init__(self, num_samples: int, num_heads: int, max_tokens: int, embed_dim:
"""
assert embed_dim % num_heads == 0
self._num_samples, self._cache, self._size = num_samples, None, None
- self._reset = lambda n: torch.empty(n, num_heads, max_tokens, embed_dim // num_heads, device=device) # (B, nh, T, hs)
+        if hidden_size is None:
+ hidden_size = embed_dim // num_heads
+ self._reset = lambda n: torch.empty(n, num_heads, max_tokens, hidden_size, device=device) # (B, nh, T, hs)
self.reset()
@property
@@ -77,7 +80,8 @@ def update(self, x: torch.Tensor, tokens: int) -> None:
class KVCache:
- def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device) -> None:
+ def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, device: torch.device,
+ hidden_size: Optional[int]) -> None:
"""
Overview:
Cache for storing key and value tensors in a transformer model.
@@ -88,8 +92,8 @@ def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, devi
- embed_dim (:obj:`int`): The dimension of the embeddings.
- device (:obj:`torch.device`): The device on which to store the cache.
"""
- self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device)
- self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device)
+ self._k_cache = Cache(n, num_heads, max_tokens, embed_dim, device, hidden_size)
+ self._v_cache = Cache(n, num_heads, max_tokens, embed_dim, device, hidden_size)
@property
def shape(self) -> Tuple[int, int, int, int]:
@@ -142,7 +146,8 @@ def update(self, k: torch.Tensor, v: torch.Tensor):
class KeysValues:
- def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device) -> None:
+ def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_layers: int, device: torch.device,
+ hidden_size: Optional[int] = None) -> None:
"""
Overview:
Class for managing multiple layers of key and value caches in a transformer model.
@@ -154,7 +159,8 @@ def __init__(self, n: int, num_heads: int, max_tokens: int, embed_dim: int, num_
- num_layers (:obj:`int`): The number of layers in the transformer model.
- device (:obj:`torch.device`): The device on which to store the caches.
"""
- self._keys_values = tuple([KVCache(n, num_heads, max_tokens, embed_dim, device) for _ in range(num_layers)])
+ self._keys_values = tuple([KVCache(n, num_heads, max_tokens, embed_dim, device, hidden_size)
+ for _ in range(num_layers)])
def __getitem__(self, index: int) -> KVCache:
"""
diff --git a/lzero/model/unizero_world_models/tokenizer.py b/lzero/model/unizero_world_models/tokenizer.py
index bd066ccec..1a237610e 100644
--- a/lzero/model/unizero_world_models/tokenizer.py
+++ b/lzero/model/unizero_world_models/tokenizer.py
@@ -64,7 +64,11 @@ def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: Encoded embeddings of shape (B, 1, E).
"""
+        # NOTE: only for the Jericho env
+        # x = x.long()  # ensure the input indices are of long (integer) type
shape = x.shape
+ # print(f"Max index in x: {x.max()}")
+ # print(f"Min index in x: {x.min()}")
# Process input tensor based on its dimensionality
if len(shape) == 2:
# Case when input is 2D (B, E)
@@ -72,7 +76,14 @@ def encode_to_obs_embeddings(self, x: torch.Tensor) -> torch.Tensor:
obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
elif len(shape) == 3:
# Case when input is 3D (B, T, E)
- x = x.contiguous().view(-1, shape[-1]) # Flatten the last two dimensions (B * T, E)
+ # print(f"x:{x}")
+ try:
+ x = x.contiguous().view(-1, shape[-1]) # Flatten the last two dimensions (B * T, E)
+ except Exception as e:
+                # Log the offending tensor and re-raise instead of silently continuing with an un-reshaped input.
+                print(f"x: {x}, error: {e}")
+                raise
+
obs_embeddings = self.encoder(x)
obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
elif len(shape) == 4:
diff --git a/lzero/model/unizero_world_models/transformer.py b/lzero/model/unizero_world_models/transformer.py
index 62536c892..3fa0a3f65 100644
--- a/lzero/model/unizero_world_models/transformer.py
+++ b/lzero/model/unizero_world_models/transformer.py
@@ -69,6 +69,20 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
device = self.ln_f.weight.device # Assumption: All submodules are on the same device
return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device)
+ def _get_positional_embedding(self, layer, attn_type, pos_emb) -> torch.Tensor:
+ """
+ Helper function to get positional embedding for a given layer and attention type.
+
+ Arguments:
+ - layer (:obj:`int`): Layer index.
+ - attn_type (:obj:`str`): Attention type, either 'key' or 'value'.
+
+ Returns:
+ - torch.Tensor: The positional embedding tensor.
+ """
+ attn_func = getattr(self.blocks[layer].attn, attn_type)
+ return attn_func(pos_emb.weight)
+
def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py
index 37d4cd3ec..0a0163b44 100644
--- a/lzero/model/unizero_world_models/world_model.py
+++ b/lzero/model/unizero_world_models/world_model.py
@@ -19,6 +19,7 @@
from .transformer import Transformer, TransformerConfig
from .utils import LossWithIntermediateLosses, init_weights
from .utils import WorldModelOutput, hash_state
+from .hf_transformer import HuggingfaceLlamaTransformer
logging.getLogger().setLevel(logging.DEBUG)
@@ -45,7 +46,10 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
super().__init__()
self.tokenizer = tokenizer
self.config = config
- self.transformer = Transformer(self.config)
+ if self.config.use_hf:
+ self.transformer = HuggingfaceLlamaTransformer.from_pretrained(self.config, self.config.pretrained_path)
+ else:
+ self.transformer = Transformer(self.config)
if self.config.device == 'cpu':
self.device = torch.device('cpu')
@@ -61,7 +65,8 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
# Initialize patterns for block masks
self._initialize_patterns()
- self.hidden_size = config.embed_dim // config.num_heads
+ self.hidden_size = config['hidden_size'] if "hidden_size" in config else config.embed_dim // config.num_heads
+ config['hidden_size'] = self.hidden_size
# Position embedding
self.pos_emb = nn.Embedding(config.max_tokens, config.embed_dim, device=self.device)
@@ -403,15 +408,13 @@ def _get_positional_embedding(self, layer, attn_type) -> torch.Tensor:
Returns:
- torch.Tensor: The positional embedding tensor.
"""
- attn_func = getattr(self.transformer.blocks[layer].attn, attn_type)
+ tmp = self.transformer._get_positional_embedding(layer, attn_type, self.pos_emb).view(
+ 1, self.config.max_tokens, -1, self.embed_dim // self.num_heads
+ )
if torch.cuda.is_available():
- return attn_func(self.pos_emb.weight).view(
- 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads
- ).transpose(1, 2).to(self.device).detach()
+ return tmp.transpose(1, 2).to(self.device).detach()
else:
- return attn_func(self.pos_emb.weight).view(
- 1, self.config.max_tokens, self.num_heads, self.embed_dim // self.num_heads
- ).transpose(1, 2).detach()
+ return tmp.transpose(1, 2).detach()
def forward(self, obs_embeddings_or_act_tokens: Dict[str, Union[torch.Tensor, tuple]],
past_keys_values: Optional[torch.Tensor] = None,
@@ -1168,6 +1171,18 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
# reconstructed_images)
latent_recon_loss = self.latent_recon_loss
+        elif self.obs_type == 'text':
+            # Text observations: there is no perceptual loss, and latent reconstruction
+            # is not implemented yet, so the default (zero) reconstruction loss is reused.
+            perceptual_loss = torch.tensor(0., device=batch['observations'].device,
+                                           dtype=torch.float32)
+            latent_recon_loss = self.latent_recon_loss
+
elif self.obs_type == 'image_memory':
# Reconstruct observations from latent state representations
# reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)
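
For reference, the `use_hf` switch introduced above is driven entirely by `world_model_cfg`. A sketch of the two backbone modes, using the values that appear in the Jericho configs later in this PR (the `pretrained_path` is the example path from `jericho_unizero_pretrained_config.py`; adjust it locally):

```python
from easydict import EasyDict

# Built-in Transformer backend (use_hf=False): head dim defaults to embed_dim // num_heads.
native_cfg = EasyDict(dict(
    use_hf=False,
    num_layers=2,
    num_heads=8,
    embed_dim=768,
))

# Pretrained HuggingFace Llama backend (use_hf=True): the per-head dim is given
# explicitly via hidden_size, which WorldModel now reads and writes back into the config.
llama_cfg = EasyDict(dict(
    use_hf=True,
    pretrained_path='/data/share/Llama-3.2-1B',  # example path from the Jericho config
    num_layers=16,
    num_heads=8,          # with GQA this matches the number of key-value heads
    embed_dim=2048,
    hidden_size=64,       # per-head dimension
))
```
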
diff --git a/lzero/policy/unizero.py b/lzero/policy/unizero.py
index cf95c46d9..90af62ec0 100644
--- a/lzero/policy/unizero.py
+++ b/lzero/policy/unizero.py
@@ -383,7 +383,19 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
# Convert to categorical distributions
target_reward_categorical = phi_transform(self.reward_support, transformed_target_reward)
- target_value_categorical = phi_transform(self.value_support, transformed_target_value)
+        try:
+            target_value_categorical = phi_transform(self.value_support, transformed_target_value)
+        except Exception as e:
+            # Surface the offending targets for debugging (e.g. values outside the support
+            # range), then re-raise so the error is not hidden and
+            # target_value_categorical is never left undefined.
+            print(f'phi_transform failed: {e}')
+            print(f'transformed_target_value: {transformed_target_value}')
+            print(f'value_support: {self.value_support}')
+            raise
+
# Prepare batch for GPT model
batch_for_gpt = {}
@@ -405,9 +417,15 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
batch_for_gpt['target_policy'] = target_policy[:, :-1]
# Extract valid target policy data and compute entropy
- valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']]
- target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1)
- average_target_policy_entropy = target_policy_entropy.mean()
+        try:
+            valid_target_policy = batch_for_gpt['target_policy'][batch_for_gpt['mask_padding']]
+            target_policy_entropy = -torch.sum(valid_target_policy * torch.log(valid_target_policy + 1e-9), dim=-1)
+            average_target_policy_entropy = target_policy_entropy.mean()
+        except Exception as e:
+            # Fall back to zero entropy (e.g. when the padding mask selects no valid entries)
+            # instead of aborting the update; the error is still printed for debugging.
+            print(f'target policy entropy computation failed: {e}')
+            average_target_policy_entropy = 0.
+
# Update world model
losses = self._learn_model.world_model.compute_loss(
@@ -433,8 +451,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, in
dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model']
latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms']
- assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
- assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
+ # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
+ # assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
# Core learn model update step
self._optimizer_world_model.zero_grad()
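
The try/except around `phi_transform` above is there to catch value targets that fall outside the value support. For intuition, a minimal sketch of the two-hot projection such a transform performs is shown below; this is an illustration of the general technique, not LightZero's actual `phi_transform` implementation.

```python
import torch


def two_hot(scalar: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
    """Project scalars onto an evenly spaced 1D support with two-hot weights.

    Illustration only; values outside the support are clamped here, whereas
    unclamped or NaN targets are exactly the kind of input that makes the
    real transform fail.
    """
    scalar = scalar.clamp(float(support.min()), float(support.max()))
    step = support[1] - support[0]
    pos = (scalar - support[0]) / step            # fractional index into the support
    low = pos.floor().long()
    high = (low + 1).clamp(max=len(support) - 1)
    frac = (pos - low.float()).unsqueeze(-1)
    out = torch.zeros(*scalar.shape, len(support))
    out.scatter_(-1, low.unsqueeze(-1), 1 - frac)
    out.scatter_add_(-1, high.unsqueeze(-1), frac)
    return out


support = torch.arange(-5., 6.)                    # 11 atoms: -5, ..., 5
print(two_hot(torch.tensor([0.3, 4.9]), support))  # mass split between neighbouring atoms
```
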
diff --git a/lzero/worker/muzero_collector.py b/lzero/worker/muzero_collector.py
index eff413df6..c94220d69 100644
--- a/lzero/worker/muzero_collector.py
+++ b/lzero/worker/muzero_collector.py
@@ -11,6 +11,7 @@
allreduce_data
from ding.worker.collector.base_serial_collector import ISerialCollector
from torch.nn import L1Loss
+import torch.distributed as dist
from lzero.mcts.buffer.game_segment import GameSegment
from lzero.mcts.utils import prepare_observation
@@ -327,6 +328,9 @@ def collect(self,
Returns:
- return_data (:obj:`List[Any]`): Collected data in the form of a list.
"""
+ # Before starting collection
+ # self._logger.info(f"Rank {self._rank} starting collection for {n_episode} episodes.")
+
# TODO: collect_with_pure_policy as a separate collector
if n_episode is None:
if self._default_n_episode is None:
@@ -716,11 +720,21 @@ def collect(self,
break
collected_duration = sum([d['time'] for d in self._episode_info])
+
+ # Before allreduce
+ self._logger.info(f"Rank {self._rank} before allreduce: collected_step={collected_step}, collected_episode={collected_episode}")
+
# reduce data when enables DDP
if self._world_size > 1:
+            # Make sure all ranks have finished collection before aggregating.
+            dist.barrier()
collected_step = allreduce_data(collected_step, 'sum')
collected_episode = allreduce_data(collected_episode, 'sum')
collected_duration = allreduce_data(collected_duration, 'sum')
+
+ # After allreduce
+ self._logger.info(f"Rank {self._rank} after allreduce: collected_step={collected_step}, collected_episode={collected_episode}")
+
self._total_envstep_count += collected_step
self._total_episode_count += collected_episode
self._total_duration += collected_duration
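
The barrier/allreduce added above aggregates per-rank collection statistics under DDP. A minimal sketch of that pattern with `torch.distributed` directly (assuming the default process group is already initialized, as it is when the collector runs under DDP):

```python
import torch
import torch.distributed as dist


def sum_across_ranks(value: float) -> float:
    """Sum a Python scalar over all ranks (sketch of the allreduce_data 'sum' path)."""
    tensor = torch.tensor([value], dtype=torch.float64)
    if torch.cuda.is_available():
        tensor = tensor.cuda()
    dist.barrier()                                  # ensure every rank reached this point
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)   # in-place sum across ranks
    return tensor.item()


# e.g. collected_step = sum_across_ranks(collected_step) on every rank
```
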
diff --git a/requirements.txt b/requirements.txt
index 831ae67c5..ffb8dd582 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -8,3 +8,5 @@ pytest
line_profiler
xxhash
einops
+openai==1.57.1
+jericho
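
The two new entries pin `openai` and add `jericho`, which the new Jericho environment imports (`from jericho import FrotzEnv`). A quick sanity check that both are importable after installation; note that the game files themselves (the z-machine-games suite referenced in the configs) still have to be downloaded separately.

```python
# Sanity check that the new requirements are importable.
import importlib

for pkg in ("jericho", "openai"):
    module = importlib.import_module(pkg)
    print(f"{pkg} {getattr(module, '__version__', 'unknown version')} imported OK")
```
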
diff --git a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py b/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py
index 56e89d1a3..ed3121bc3 100644
--- a/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py
+++ b/zoo/dmc2gym/config/dmc2gym_state_suz_segment_config.py
@@ -15,22 +15,56 @@ def main(env_id, seed):
continuous_action_space = True
K = 20 # num_of_sampled_actions
-    collector_env_num = 8
-    n_episode = 8
-    num_segments = 8
-    game_segment_length = 100
+    collector_env_num = 16
+    n_episode = 16
+    num_segments = 16
+    game_segment_length = 125
evaluator_env_num = 3
- num_simulations = 50
- replay_ratio = 0.1
- max_env_step = int(5e5)
+    num_simulations = 50
+    max_env_step = int(1e6)
+    reanalyze_ratio = 0
batch_size = 64
num_layers = 2
    num_unroll_steps = 5
    infer_context_length = 2
norm_type = 'LN'
# Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
- buffer_reanalyze_freq = 1/100000
+    buffer_reanalyze_freq = 1 / 1000000000  # effectively disables periodic buffer reanalyze
+    replay_ratio = 0.25
+
# Each reanalyze process will reanalyze sequences ( transitions per sequence)
reanalyze_batch_size = 160
# The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
@@ -54,7 +88,9 @@ def main(env_id, seed):
domain_name=domain_name,
task_name=task_name,
from_pixels=False, # vector/state obs
- frame_skip=2,
+            frame_skip=8,
continuous=True,
save_replay_gif=False,
replay_path_gif='./replay_gif',
@@ -75,13 +111,18 @@ def main(env_id, seed):
num_of_sampled_actions=K,
model_type='mlp',
world_model_cfg=dict(
- policy_loss_type='kl',
+                    num_simulations=num_simulations,
+                    policy_loss_type='kl',  # alternative: 'simple'
                    obs_type='vector',
                    num_unroll_steps=num_unroll_steps,
+                    policy_entropy_weight=5e-2,
                    continuous_action_space=continuous_action_space,
                    num_of_sampled_actions=K,
                    sigma_type='conditioned',
fixed_sigma_value=0.5,
bound_type=None,
model_type='mlp',
@@ -93,7 +134,8 @@ def main(env_id, seed):
action_space_size=action_space_size,
num_layers=num_layers,
num_heads=8,
                    embed_dim=768,
env_num=max(collector_env_num, evaluator_env_num),
),
),
@@ -108,21 +150,31 @@ def main(env_id, seed):
replay_ratio=replay_ratio,
batch_size=batch_size,
discount_factor=0.99,
- td_steps=5,
- piecewise_decay_lr_scheduler=False,
+            td_steps=game_segment_length,  # n-step TD horizon equals the segment length
+            lr_piecewise_constant_decay=False,
            learning_rate=1e-4,
            grad_clip_value=5,
            manual_temperature_decay=True,
            threshold_training_steps_for_final_temperature=int(2.5e4),
-            cos_lr_scheduler=True,
+            cos_lr_scheduler=False,
num_segments=num_segments,
train_start_after_envsteps=2000,
game_segment_length=game_segment_length,
num_simulations=num_simulations,
- reanalyze_ratio=0,
+ reanalyze_ratio=reanalyze_ratio,
n_episode=n_episode,
eval_freq=int(5e3),
replay_buffer_size=int(1e6),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
# ============= The key different params for ReZero =============
@@ -151,7 +203,8 @@ def main(env_id, seed):
# ============ use muzero_segment_collector instead of muzero_collector =============
from lzero.entry import train_unizero_segment
- main_config.exp_name=f'data_sampled_unizero/dmc2gym_{env_id}_brf{buffer_reanalyze_freq}_state_cont_suz_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_K{K}_ns{num_simulations}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_{norm_type}_seed{seed}_learnsigma'
+ main_config.exp_name=f'data_suz_1216/dmc2gym_{env_id}_state_cont_suz_fs8_act-simnorm_td{game_segment_length}_dc099_learn-sigma_gcv5_rbs1e6_no-corlr_embed768_temp2.5e4_pew5e-2_19prior1flatten_obs10value01_clamp4_brf{buffer_reanalyze_freq}_nlayer{num_layers}_numsegments-{num_segments}_gsl{game_segment_length}_K{K}_ns{num_simulations}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_{norm_type}_seed{seed}'
+
train_unizero_segment([main_config, create_config], model_path=main_config.policy.model_path, seed=seed, max_env_step=max_env_step)
@@ -161,6 +214,15 @@ def main(env_id, seed):
parser.add_argument('--env', type=str, help='The environment to use', default='cartpole-swingup')
parser.add_argument('--seed', type=int, help='The seed to use', default=0)
+
args = parser.parse_args()
+
+    # Other DMC tasks to try: 'cheetah-run', 'walker-walk', 'finger-spin',
+    # 'pendulum-swingup', 'hopper-hop', 'acrobot-swingup'.
main(args.env, args.seed)
\ No newline at end of file
diff --git a/zoo/jericho/__init__.py b/zoo/jericho/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py
new file mode 100644
index 000000000..8dd791303
--- /dev/null
+++ b/zoo/jericho/configs/jericho_unizero_config.py
@@ -0,0 +1,162 @@
+import os
+
+from easydict import EasyDict
+
+# TODO: point this at your local HuggingFace cache if needed.
+os.environ["HF_HOME"] = "/mnt/afs/zhangshenghan/.cache/huggingface/hub"
+
+def main(env_id='detective.z5', seed=0):
+ action_space_size = 50
+ max_steps = 51
+
+ # ==============================================================
+ # begin of the most frequently changed config specified by the user
+ # ==============================================================
+ collector_env_num = 2
+ num_segments = 2
+ game_segment_length = 20
+ evaluator_env_num = 2
+ num_simulations = 50
+ max_env_step = int(10e6)
+ batch_size = 32
+ num_unroll_steps = 10
+ infer_context_length = 4
+ num_layers = 2
+ replay_ratio = 0.25
+ update_per_collect = 20 # NOTE: very important for ddp
+ embed_dim = 768
+ # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
+ # buffer_reanalyze_freq = 1/10
+ buffer_reanalyze_freq = 1/100000
+ # Each reanalyze process will reanalyze sequences ( transitions per sequence)
+ reanalyze_batch_size = 160
+ # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
+ reanalyze_partition = 0.75
+ # model_name = 'BAAI/bge-base-en-v1.5'
+ model_name = 'google-bert/bert-base-uncased'
+ # =========== TODO: only for debug ===========
+ # collector_env_num = 2
+ # num_segments = 2
+ # game_segment_length = 20
+ # evaluator_env_num = 2
+ # max_env_step = int(5e5)
+ # batch_size = 10
+ # num_simulations = 5
+ # num_unroll_steps = 5
+ # infer_context_length = 2
+ # max_steps = 10
+ # num_layers = 1
+ # replay_ratio = 0.05
+ # embed_dim = 768
+    # TODO: The action space inside MCTS is restricted to the legal actions of the root node.
+
+ # ==============================================================
+ # end of the most frequently changed config specified by the user
+ # ==============================================================
+ jericho_unizero_config = dict(
+ env=dict(
+ stop_value=int(1e6),
+ observation_shape=512,
+ max_steps=max_steps,
+ max_action_num=action_space_size,
+ tokenizer_path=model_name,
+ # tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594",
+ max_seq_len=512,
+ # game_path="z-machine-games-master/jericho-game-suite/" + env_id,
+ game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/"+ env_id,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ multi_gpu=False, # ======== Very important for ddp =============
+ # multi_gpu=True, # ======== Very important for ddp =============
+ # default is 10000
+ use_wandb=False,
+ learn=dict(learner=dict(
+ hook=dict(save_ckpt_after_iter=1000000, ), ), ),
+ model=dict(
+ observation_shape=512,
+ action_space_size=action_space_size,
+ encoder_url=model_name,
+ # encoder_url='google-bert/bert-base-uncased',
+ # encoder_url='/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594',
+                # The model input is tokenized text; its encoded shape matches what the mlp model type expects.
+ model_type='mlp',
+ continuous_action_space=False,
+ world_model_cfg=dict(
+ policy_entropy_weight=5e-3,
+ continuous_action_space=False,
+ max_blocks=num_unroll_steps,
+ # NOTE: each timestep has 2 tokens: obs and action
+ max_tokens=2 * num_unroll_steps,
+ context_length=2 * infer_context_length,
+ device='cuda',
+ action_space_size=action_space_size,
+ use_hf=False,
+ num_layers=num_layers,
+ num_heads=8,
+ embed_dim=embed_dim,
+ obs_type='text', # TODO: Change it.
+ env_num=max(collector_env_num, evaluator_env_num),
+ ),
+ ),
+ update_per_collect=update_per_collect,
+ action_type='varied_action_space',
+ model_path=None,
+ num_unroll_steps=num_unroll_steps,
+ reanalyze_ratio=0,
+ replay_ratio=replay_ratio,
+ batch_size=batch_size,
+ learning_rate=0.0001,
+ num_simulations=num_simulations,
+ num_segments=num_segments,
+ train_start_after_envsteps=0, # TODO
+ game_segment_length=game_segment_length,
+ replay_buffer_size=int(1e6),
+ eval_freq=int(5e3),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ # ============= The key different params for reanalyze =============
+ # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
+ buffer_reanalyze_freq=buffer_reanalyze_freq,
+ # Each reanalyze process will reanalyze sequences ( transitions per sequence)
+ reanalyze_batch_size=reanalyze_batch_size,
+ # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
+ reanalyze_partition=reanalyze_partition,
+ ),
+ )
+ jericho_unizero_config = EasyDict(jericho_unizero_config)
+
+ jericho_unizero_create_config = dict(
+ env=dict(
+ type='jericho',
+ import_names=['zoo.jericho.envs.jericho_env'],
+ ),
+ # NOTE: use base env manager to avoid the bug of subprocess env manager.
+ env_manager=dict(type='base'),
+ # env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='unizero',
+ import_names=['lzero.policy.unizero'],
+ ),
+ )
+ jericho_unizero_create_config = EasyDict(jericho_unizero_create_config)
+ main_config = jericho_unizero_config
+ create_config = jericho_unizero_create_config
+
+ main_config.exp_name = f'data_unizero_detective_20241220/{env_id[:8]}_ms{max_steps}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+ from lzero.entry import train_unizero
+ train_unizero([main_config, create_config], seed=seed,
+ model_path=main_config.policy.model_path, max_env_step=max_env_step)
+
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser(description='Process some environment.')
+ parser.add_argument('--env', type=str,
+ help='The environment to use', default='detective.z5') # 'detective.z5' 'zork1.z5'
+ parser.add_argument('--seed', type=int, help='The seed to use', default=0)
+ args = parser.parse_args()
+
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+ main(args.env, args.seed)
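
A small worked example of the token bookkeeping that the `world_model_cfg` above relies on (each timestep contributes one observation token and one action token, per the NOTE in the config):

```python
# Token bookkeeping for the UniZero world model, using the values from this config.
num_unroll_steps = 10      # training horizon in timesteps
infer_context_length = 4   # inference context in timesteps

max_blocks = num_unroll_steps                # timesteps per training block
max_tokens = 2 * num_unroll_steps            # obs token + action token per timestep -> 20
context_length = 2 * infer_context_length    # inference context measured in tokens -> 8

assert (max_blocks, max_tokens, context_length) == (10, 20, 8)
```
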
diff --git a/zoo/jericho/configs/jericho_unizero_pretrained_config.py b/zoo/jericho/configs/jericho_unizero_pretrained_config.py
new file mode 100644
index 000000000..c02d9072e
--- /dev/null
+++ b/zoo/jericho/configs/jericho_unizero_pretrained_config.py
@@ -0,0 +1,166 @@
+import os
+
+from easydict import EasyDict
+
+# TODO: point this at your local HuggingFace cache if needed.
+os.environ["HF_HOME"] = "/mnt/afs/zhangshenghan/.cache/huggingface/hub"
+
+def main(env_id='detective.z5', seed=0):
+ action_space_size = 50
+ max_steps = 51
+
+ # ==============================================================
+ # begin of the most frequently changed config specified by the user
+ # ==============================================================
+ collector_env_num = 2
+ num_segments = 2
+ game_segment_length = 20
+ evaluator_env_num = 2
+ num_simulations = 50
+ max_env_step = int(10e6)
+ batch_size = 32
+ num_unroll_steps = 10
+ infer_context_length = 4
+ replay_ratio = 0.25
+ update_per_collect = 20 # NOTE: very important for ddp
+ embed_dim = 2048
+ num_layers = 16
+ # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
+ # buffer_reanalyze_freq = 1/10
+ buffer_reanalyze_freq = 1/100000
+ # Each reanalyze process will reanalyze sequences ( transitions per sequence)
+ reanalyze_batch_size = 160
+ # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
+ reanalyze_partition = 0.75
+ # model_name = 'BAAI/bge-base-en-v1.5'
+ model_name = 'google-bert/bert-base-uncased'
+ # =========== TODO: only for debug ===========
+ # collector_env_num = 2
+ # num_segments = 2
+ # game_segment_length = 20
+ # evaluator_env_num = 2
+ # max_env_step = int(5e5)
+ # batch_size = 10
+ # num_simulations = 5
+ # num_unroll_steps = 5
+ # infer_context_length = 2
+ # max_steps = 10
+ # num_layers = 1
+ # replay_ratio = 0.05
+ # embed_dim = 768
+    # TODO: The action space inside MCTS is restricted to the legal actions of the root node.
+
+ # ==============================================================
+ # end of the most frequently changed config specified by the user
+ # ==============================================================
+ jericho_unizero_config = dict(
+ env=dict(
+ stop_value=int(1e6),
+ observation_shape=512,
+ max_steps=max_steps,
+ max_action_num=action_space_size,
+ tokenizer_path=model_name,
+ # tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594",
+ max_seq_len=512,
+ game_path="z-machine-games-master/jericho-game-suite/" + env_id,
+ # game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/"+ env_id,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ multi_gpu=False, # ======== Very important for ddp =============
+ # multi_gpu=True, # ======== Very important for ddp =============
+ # default is 10000
+ use_wandb=False,
+ learn=dict(learner=dict(
+ hook=dict(save_ckpt_after_iter=1000000, ), ), ),
+ model=dict(
+ observation_shape=512,
+ action_space_size=action_space_size,
+ encoder_url=model_name,
+ # encoder_url='google-bert/bert-base-uncased',
+ # encoder_url='/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594',
+                # The model input is tokenized text; its encoded shape matches what the mlp model type expects.
+ model_type='mlp',
+ continuous_action_space=False,
+ world_model_cfg=dict(
+ policy_entropy_weight=5e-3,
+ continuous_action_space=False,
+ max_blocks=num_unroll_steps,
+ # NOTE: each timestep has 2 tokens: obs and action
+ max_tokens=2 * num_unroll_steps,
+ context_length=2 * infer_context_length,
+ device='cuda',
+ action_space_size=action_space_size,
+ use_hf=True,
+ pretrained_path='/data/share/Llama-3.2-1B',
+ # These parameters should be the same as the config file of original model.
+ num_layers=num_layers,
+                    # Note: Llama uses GQA, so num_heads here should equal the model's number of key-value heads.
+ num_heads=8,
+ embed_dim=embed_dim,
+ hidden_size=64, # The dim for each head is 64 in this case.
+ obs_type='text', # TODO: Change it.
+ env_num=max(collector_env_num, evaluator_env_num),
+ ),
+ ),
+ update_per_collect=update_per_collect,
+ action_type='varied_action_space',
+ model_path=None,
+ num_unroll_steps=num_unroll_steps,
+ reanalyze_ratio=0,
+ replay_ratio=replay_ratio,
+ batch_size=batch_size,
+ learning_rate=0.0001,
+ num_simulations=num_simulations,
+ num_segments=num_segments,
+ train_start_after_envsteps=0, # TODO
+ game_segment_length=game_segment_length,
+ replay_buffer_size=int(1e6),
+ eval_freq=int(5e3),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ # ============= The key different params for reanalyze =============
+ # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
+ buffer_reanalyze_freq=buffer_reanalyze_freq,
+ # Each reanalyze process will reanalyze sequences ( transitions per sequence)
+ reanalyze_batch_size=reanalyze_batch_size,
+ # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
+ reanalyze_partition=reanalyze_partition,
+ ),
+ )
+ jericho_unizero_config = EasyDict(jericho_unizero_config)
+
+ jericho_unizero_create_config = dict(
+ env=dict(
+ type='jericho',
+ import_names=['zoo.jericho.envs.jericho_env'],
+ ),
+ # NOTE: use base env manager to avoid the bug of subprocess env manager.
+ env_manager=dict(type='base'),
+ # env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='unizero',
+ import_names=['lzero.policy.unizero'],
+ ),
+ )
+ jericho_unizero_create_config = EasyDict(jericho_unizero_create_config)
+ main_config = jericho_unizero_config
+ create_config = jericho_unizero_create_config
+
+ main_config.exp_name = f'data_unizero_detective_20241220/{env_id[:8]}_ms{max_steps}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}-upc{update_per_collect}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+ from lzero.entry import train_unizero
+ train_unizero([main_config, create_config], seed=seed,
+ model_path=main_config.policy.model_path, max_env_step=max_env_step)
+
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser(description='Process some environment.')
+ parser.add_argument('--env', type=str,
+ help='The environment to use', default='detective.z5') # 'detective.z5' 'zork1.z5'
+ parser.add_argument('--seed', type=int, help='The seed to use', default=0)
+ args = parser.parse_args()
+
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+ main(args.env, args.seed)
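
The head bookkeeping behind `num_heads=8` / `hidden_size=64` above, assuming the usual Llama-3.2-1B geometry (2048 hidden size, 32 query heads, 8 key-value heads, 64-dim heads); verify these against the checkpoint's `config.json` before relying on them:

```python
# Assumed Llama-3.2-1B attention geometry (check against the model's config.json).
embed_dim = 2048            # model hidden size
num_query_heads = 32
num_key_value_heads = 8     # GQA: fewer KV heads than query heads
head_dim = embed_dim // num_query_heads    # 64

# The world-model config therefore sets num_heads to the KV head count and
# hidden_size to the per-head dim, instead of the default embed_dim // num_heads.
num_heads = num_key_value_heads
hidden_size = head_dim
assert (num_heads, hidden_size) == (8, 64)
assert embed_dim // num_heads == 256       # why the explicit hidden_size override is needed
```
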
diff --git a/zoo/jericho/configs/jericho_unizero_segment_config.py b/zoo/jericho/configs/jericho_unizero_segment_config.py
new file mode 100644
index 000000000..4b6fb4da5
--- /dev/null
+++ b/zoo/jericho/configs/jericho_unizero_segment_config.py
@@ -0,0 +1,153 @@
+import os
+from easydict import EasyDict
+
+
+def main(env_id='detective.z5', seed=0):
+ action_space_size = 50
+
+ # ==============================================================
+ # begin of the most frequently changed config specified by the user
+ # ==============================================================
+ collector_env_num = 8
+ game_segment_length = 20
+ evaluator_env_num = 5
+ num_segments = 8
+ num_simulations = 50
+ max_env_step = int(5e5)
+ batch_size = 64
+ num_unroll_steps = 10
+ infer_context_length = 4
+ num_layers = 2
+ replay_ratio = 0.25
+ embed_dim = 768
+ max_steps = 100
+ # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
+ # buffer_reanalyze_freq = 1/10
+ buffer_reanalyze_freq = 1/100000
+ # Each reanalyze process will reanalyze sequences ( transitions per sequence)
+ reanalyze_batch_size = 160
+ # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
+ reanalyze_partition = 0.75
+
+ # =========== TODO: only for debug ===========
+ collector_env_num = 2
+ num_segments = 2
+ max_steps=20
+ game_segment_length = 20
+ evaluator_env_num = 2
+ num_simulations = 5
+ max_env_step = int(5e5)
+ batch_size = 10
+ num_unroll_steps = 5
+ infer_context_length = 2
+ num_layers = 1
+ replay_ratio = 0.05
+ embed_dim = 32
+
+ # ==============================================================
+ # end of the most frequently changed config specified by the user
+ # ==============================================================
+ jericho_unizero_config = dict(
+ env=dict(
+ stop_value=int(1e6),
+ max_steps=max_steps,
+ observation_shape=512,
+ max_action_num=action_space_size,
+ # tokenizer_path="google-bert/bert-base-uncased",
+ tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594",
+ max_seq_len=512,
+ # game_path="z-machine-games-master/jericho-game-suite/" + env_id,
+ game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/"+ env_id,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ # default is 10000
+ learn=dict(learner=dict(
+ hook=dict(save_ckpt_after_iter=1000000, ), ), ),
+ model=dict(
+ observation_shape=512,
+ action_space_size=action_space_size,
+ # encoder_url='google-bert/bert-base-uncased',
+ encoder_url='/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594',
+                # The model input is tokenized text; its encoded shape matches what the mlp model type expects.
+ model_type='mlp',
+ world_model_cfg=dict(
+ policy_entropy_weight=5e-3,
+ continuous_action_space=False,
+ max_blocks=num_unroll_steps,
+ # NOTE: each timestep has 2 tokens: obs and action
+ max_tokens=2 * num_unroll_steps,
+ context_length=2 * infer_context_length,
+ device='cuda',
+ action_space_size=action_space_size,
+ use_hf=False,
+ num_layers=num_layers,
+ num_heads=8,
+ embed_dim=embed_dim,
+ obs_type='text', # TODO: Change it.
+ env_num=max(collector_env_num, evaluator_env_num),
+ ),
+ ),
+ action_type='varied_action_space',
+ model_path=None,
+ num_unroll_steps=num_unroll_steps,
+ reanalyze_ratio=0,
+ replay_ratio=replay_ratio,
+ batch_size=batch_size,
+ learning_rate=0.0001,
+ num_simulations=num_simulations,
+ num_segments=num_segments,
+ # train_start_after_envsteps=2000,
+ train_start_after_envsteps=0, # TODO: only for debug
+ game_segment_length=game_segment_length,
+ replay_buffer_size=int(1e6),
+ eval_freq=int(5e3),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ # ============= The key different params for reanalyze =============
+ # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
+ buffer_reanalyze_freq=buffer_reanalyze_freq,
+ # Each reanalyze process will reanalyze sequences ( transitions per sequence)
+ reanalyze_batch_size=reanalyze_batch_size,
+ # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
+ reanalyze_partition=reanalyze_partition,
+ ),
+ )
+ jericho_unizero_config = EasyDict(jericho_unizero_config)
+
+ jericho_unizero_create_config = dict(
+ env=dict(
+ type='jericho',
+ import_names=['zoo.jericho.envs.jericho_env'],
+ ),
+ # NOTE: use base env manager to avoid the bug of subprocess env manager.
+ env_manager=dict(type='base'),
+ # env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='unizero',
+ import_names=['lzero.policy.unizero'],
+ ),
+ )
+ jericho_unizero_create_config = EasyDict(jericho_unizero_create_config)
+ main_config = jericho_unizero_config
+ create_config = jericho_unizero_create_config
+
+    main_config.exp_name = f'data_unizero/{env_id[:8]}/{env_id[:8]}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
+ from lzero.entry import train_unizero_segment
+ train_unizero_segment([main_config, create_config], seed=seed,
+ model_path=main_config.policy.model_path, max_env_step=max_env_step)
+
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser(description='Process some environment.')
+ parser.add_argument('--env', type=str,
+ help='The environment to use', default='detective.z5') # 'zork1.z5'
+ parser.add_argument('--seed', type=int, help='The seed to use', default=0)
+ args = parser.parse_args()
+
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+ main(args.env, args.seed)
diff --git a/zoo/jericho/envs/__init__.py b/zoo/jericho/envs/__init__.py
new file mode 100644
index 000000000..740dab512
--- /dev/null
+++ b/zoo/jericho/envs/__init__.py
@@ -0,0 +1 @@
+from .jericho_env import JerichoEnv
diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py
new file mode 100644
index 000000000..4fdd7c853
--- /dev/null
+++ b/zoo/jericho/envs/jericho_env.py
@@ -0,0 +1,168 @@
+"""
+env返回的obs是id 不是string
+"""
+import copy
+from typing import List
+
+import gym
+import numpy as np
+from transformers import AutoTokenizer
+from ding.utils import ENV_REGISTRY
+from ding.envs import BaseEnv, BaseEnvTimestep
+from jericho import FrotzEnv
+from ding.utils import set_pkg_seed, get_rank, get_world_size
+
+
+@ENV_REGISTRY.register('jericho')
+class JerichoEnv(BaseEnv):
+ """
+ Overview:
+ The environment for Jericho games. For more details about the game, please refer to the \
+ `Jericho `.
+ """
+ tokenizer = None
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.max_steps = cfg.max_steps
+ self.game_path = cfg.game_path
+ self.max_action_num = cfg.max_action_num
+ self.max_seq_len = cfg.max_seq_len
+
+ if JerichoEnv.tokenizer is None:
+ JerichoEnv.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path)
+
+ self._env = FrotzEnv(self.game_path)
+ self._action_list = None
+ self.finished = False
+ self._init_flag = False
+ self.episode_return = 0
+ self.env_step = 0
+
+ self.observation_space = gym.spaces.Dict()
+ self.action_space = gym.spaces.Discrete(self.max_action_num)
+ self.reward_space = gym.spaces.Box(
+ low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32)
+
+ def prepare_obs(self, obs, return_str: bool = False):
+ if self._action_list is None:
+ self._action_list = self._env.get_valid_actions()
+ full_obs = obs + "\nValid actions: " + str(self._action_list)
+ if not return_str:
+ full_obs = JerichoEnv.tokenizer(
+ [full_obs], truncation=True, padding="max_length", max_length=self.max_seq_len)
+ obs_attn_mask = full_obs['attention_mask']
+ full_obs = np.array(full_obs['input_ids'][0], dtype=np.int32) # TODO: attn_mask
+ if len(self._action_list) <= self.max_action_num:
+ action_mask = [1] * len(self._action_list) + [0] * \
+ (self.max_action_num - len(self._action_list))
+        else:
+            # More valid actions than the fixed action space; keep only the first max_action_num.
+            action_mask = [1] * self.max_action_num
+
+ action_mask = np.array(action_mask, dtype=np.int8)
+ if return_str:
+ return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}
+ else:
+ return {'observation': full_obs, 'obs_attn_mask': obs_attn_mask, 'action_mask': action_mask, 'to_play': -1}
+
+ def reset(self, return_str: bool = False):
+ initial_observation, info = self._env.reset()
+ self.finished = False
+ self._init_flag = True
+ self._action_list = None
+ self.episode_return = 0
+ self.env_step = 0
+
+        # Get the current world_size and rank for distributed training.
+ self.world_size = get_world_size()
+ self.rank = get_rank()
+
+ return self.prepare_obs(initial_observation, return_str)
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ """
+ Set the seed for the environment.
+ """
+ self._seed = seed
+ self._env.seed(seed)
+
+ def close(self) -> None:
+ self._init_flag = False
+
+ def __repr__(self) -> str:
+ return "LightZero Jericho Env"
+
+ def step(self, action: int, return_str: bool = False):
+ try:
+ action_str = self._action_list[action]
+ except Exception as e:
+            # TODO: figure out why an illegal action index can be selected here.
+            print(f'[rank {self.rank}] {e}: action {action} is illegal; '
+                  f'randomly choosing a legal action from {self._action_list} instead.')
+ action = np.random.choice(len(self._action_list))
+ action_str = self._action_list[action]
+
+ observation, reward, done, info = self._env.step(action_str)
+ self.env_step += 1
+ self.episode_return += reward
+ self._action_list = None
+ observation = self.prepare_obs(observation, return_str)
+
+ # print(f'rank {self.rank}, step: {self.env_step}')
+ # print(f'self._action_list:{self._action_list}')
+ # print(f'rank {self.rank}, step: {self.env_step}, observation:{observation}, action:{action}, reward:{reward}')
+
+ if self.env_step >= self.max_steps:
+ done = True
+
+ if done:
+ print('='*20)
+ print(f'rank {self.rank} one episode done!')
+ # print(f'self._action_list:{self._action_list}, action:{action}, reward:{reward}')
+ self.finished = True
+ info['eval_episode_return'] = self.episode_return
+
+ return BaseEnvTimestep(observation, reward, done, info)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_env_num = cfg.pop('collector_env_num')
+ cfg = copy.deepcopy(cfg)
+ # when in collect phase, sometimes we need to normalize the reward
+ # reward_normalize is determined by the config.
+ cfg.is_collect = True
+ return [cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_env_num = cfg.pop('evaluator_env_num')
+ cfg = copy.deepcopy(cfg)
+ # when in evaluate phase, we don't need to normalize the reward.
+ cfg.reward_normalize = False
+ cfg.is_collect = False
+ return [cfg for _ in range(evaluator_env_num)]
+
+
+if __name__ == '__main__':
+ from easydict import EasyDict
+ env_cfg = EasyDict(
+ dict(
+ max_steps=100,
+ # game_path="z-machine-games-master/jericho-game-suite/zork1.z5",
+ game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/detective.z5",
+ # game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/905.z5",
+ max_action_num=50,
+ max_env_step=100,
+ tokenizer_path="google-bert/bert-base-uncased",
+ max_seq_len=512
+ )
+ )
+ env = JerichoEnv(env_cfg)
+ obs = env.reset(return_str=True)
+ print(f'[OBS]:\n{obs["observation"]}')
+ while True:
+ action_id = int(input('Please input the action id:'))
+ obs, reward, done, info = env.step(action_id, return_str=True)
+ print(f'[OBS]:\n{obs["observation"]}')
+ if done:
+ break
diff --git a/zoo/jericho/envs/test_jericho_env.py b/zoo/jericho/envs/test_jericho_env.py
new file mode 100644
index 000000000..28db93b53
--- /dev/null
+++ b/zoo/jericho/envs/test_jericho_env.py
@@ -0,0 +1,41 @@
+from easydict import EasyDict
+from .jericho_env import JerichoEnv
+import numpy as np
+import pytest
+
+
+@pytest.mark.unittest
+class TestJerichoEnv():
+ def setup(self) -> None:
+ # Configuration for the Jericho environment
+        cfg = EasyDict(
+            dict(
+                max_steps=100,  # required by JerichoEnv.__init__
+                game_path="z-machine-games-master/jericho-game-suite/zork1.z5",
+                max_action_num=50,
+                tokenizer_path="google-bert/bert-base-uncased",
+                max_seq_len=512,
+            )
+        )
+ # Create a Jericho environment that will be used in the following tests.
+ self.env = JerichoEnv(cfg)
+
+ # Test the initialization of the Jericho environment.
+ def test_initialization(self):
+ assert isinstance(self.env, JerichoEnv)
+
+ # Test the reset method of the Jericho environment.
+ # Ensure that the shape of the observation is as expected.
+ def test_reset(self):
+ obs = self.env.reset()
+ assert obs['observation'].shape == (512,)
+
+ # Test the step method of the Jericho environment.
+ # Ensure that the shape of the observation, the type of the reward,
+ # the type of the done flag and the type of the info are as expected.
+ def test_step_shape(self):
+ self.env.reset()
+ obs, reward, done, info = self.env.step(1)
+ assert obs['observation'].shape == (512,)
+ assert isinstance(reward, np.ndarray)
+ assert isinstance(done, bool)
+ assert isinstance(info, dict)