TMP feature(whl): add pretrained llm for unizero #310

Open · wants to merge 19 commits into base: main
File renamed without changes.
155 changes: 105 additions & 50 deletions lzero/entry/train_unizero.py
@@ -21,6 +21,8 @@
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.worker import MuZeroCollector as Collector
from .utils import random_collect
import torch.distributed as dist
from ding.utils import set_pkg_seed, get_rank, get_world_size


def train_unizero(
@@ -52,106 +54,143 @@ def train_unizero(

cfg, create_cfg = input_cfg

# Ensure the specified policy type is supported
assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero entry now only supports the following algo.: 'unizero', 'sampled_unizero'"
logging.info("===== 开始训练 UniZero =====")

# 检查是否支持指定的 policy 类型
assert create_cfg.policy.type in ['unizero', 'sampled_unizero'], "train_unizero 仅支持以下算法: 'unizero', 'sampled_unizero'"
logging.info(f"使用的 policy 类型为: {create_cfg.policy.type}")

# Import the correct GameBuffer class based on the policy type
# Import the corresponding GameBuffer class according to the policy type
game_buffer_classes = {'unizero': 'UniZeroGameBuffer', 'sampled_unizero': 'SampledUniZeroGameBuffer'}

GameBuffer = getattr(__import__('lzero.mcts', fromlist=[game_buffer_classes[create_cfg.policy.type]]),
game_buffer_classes[create_cfg.policy.type])
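# For readers unfamiliar with __import__ plus fromlist, the dynamic import above is equivalent to the
# following importlib-based sketch (illustrative only, not part of this PR's change):
#   import importlib
#   GameBuffer = getattr(importlib.import_module('lzero.mcts'), game_buffer_classes[create_cfg.policy.type])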

# Set device based on CUDA availability
# Set the device according to GPU availability
cfg.policy.device = cfg.policy.model.world_model_cfg.device if torch.cuda.is_available() else 'cpu'
logging.info(f'cfg.policy.device: {cfg.policy.device}')
logging.info(f"设备已设置为: {cfg.policy.device}")

# Compile the configuration
# Compile the configuration
logging.info("Compiling the configuration...")
cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
logging.info("Configuration compiled!")

# Create main components: env, policy
# Create the environment managers
logging.info("Creating environment managers...")
env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
logging.info("环境管理器创建完成!")

# Initialize the environments and random seeds
logging.info("Initializing environments and random seeds...")
collector_env.seed(cfg.seed)
evaluator_env.seed(cfg.seed, dynamic_seed=False)
set_pkg_seed(cfg.seed, use_cuda=torch.cuda.is_available())
logging.info("环境和随机种子初始化完成!")

# Initialize wandb if it is enabled
if cfg.policy.use_wandb:
# Initialize wandb
logging.info("正在初始化 wandb...")
wandb.init(
project="LightZero",
config=cfg,
sync_tensorboard=False,
monitor_gym=False,
save_code=True,
)
logging.info("wandb 初始化完成!")

# Create the policy
logging.info("Creating the policy...")
policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
logging.info("policy 创建完成!")

# Load pretrained model if specified
# Load the pretrained model if a model path is specified
if model_path is not None:
logging.info(f'Loading model from {model_path} begin...')
logging.info(f"正在从 {model_path} 加载预训练模型...")
policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
logging.info(f'Loading model from {model_path} end!')
logging.info("预训练模型加载完成!")

# Create worker components: learner, collector, evaluator, replay buffer, commander
# Create the core training components
logging.info("Creating the core training components...")
tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)

# MCTS+RL algorithms related core code
policy_config = cfg.policy
replay_buffer = GameBuffer(policy_config)
replay_buffer = GameBuffer(cfg.policy)
collector = Collector(env=collector_env, policy=policy.collect_mode, tb_logger=tb_logger, exp_name=cfg.exp_name,
policy_config=policy_config)
policy_config=cfg.policy)
evaluator = Evaluator(eval_freq=cfg.policy.eval_freq, n_evaluator_episode=cfg.env.n_evaluator_episode,
stop_value=cfg.env.stop_value, env=evaluator_env, policy=policy.eval_mode,
tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=policy_config)
tb_logger=tb_logger, exp_name=cfg.exp_name, policy_config=cfg.policy)
logging.info("训练核心组件创建完成!")

# Learner's before_run hook
# Learner's before_run hook
logging.info("Executing the Learner's before_run hook...")
learner.call_hook('before_run')
if policy_config.use_wandb:
logging.info("Learner 的 before_run hook 执行完成!")

if cfg.policy.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

# Collect random data before training
# Randomly collect data
if cfg.policy.random_collect_episode_num > 0:
logging.info("正在进行随机数据收集...")
random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
logging.info("随机数据收集完成!")

batch_size = policy._cfg.batch_size

if cfg.policy.multi_gpu:
# Get the current world_size and rank
world_size = get_world_size()
rank = get_rank()
else:
world_size = 1
rank = 0

while True:
# Log buffer memory usage
# torch.cuda.empty_cache()

# Log the replay buffer's memory usage
# logging.info(f"Train iter {learner.train_iter}: logging replay buffer memory usage...")
log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
# logging.info(f"训练迭代 {learner.train_iter}: 内存使用记录完成!")

# Set temperature for visit count distributions
# Set the temperature parameter
collect_kwargs = {
'temperature': visit_count_temperature(
policy_config.manual_temperature_decay,
policy_config.fixed_temperature_value,
policy_config.threshold_training_steps_for_final_temperature,
cfg.policy.manual_temperature_decay,
cfg.policy.fixed_temperature_value,
cfg.policy.threshold_training_steps_for_final_temperature,
trained_steps=learner.train_iter
),
'epsilon': 0.0 # Default epsilon value
'epsilon': 0.0  # Default epsilon
}
# logging.info(f"训练迭代 {learner.train_iter}: 温度设置完成,值为 {collect_kwargs['temperature']}")

# Configure epsilon for epsilon-greedy exploration
if policy_config.eps.eps_greedy_exploration_in_collect:
# Configure epsilon-greedy exploration
if cfg.policy.eps.eps_greedy_exploration_in_collect:
epsilon_greedy_fn = get_epsilon_greedy_fn(
start=policy_config.eps.start,
end=policy_config.eps.end,
decay=policy_config.eps.decay,
type_=policy_config.eps.type
start=cfg.policy.eps.start,
end=cfg.policy.eps.end,
decay=cfg.policy.eps.decay,
type_=cfg.policy.eps.type
)
collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
# logging.info(f"训练迭代 {learner.train_iter}: epsilon 设置完成,值为 {collect_kwargs['epsilon']}")

# Evaluate policy performance
# Evaluate the policy's performance
if evaluator.should_eval(learner.train_iter):
logging.info(f"训练迭代 {learner.train_iter}: 开始评估...")
stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
logging.info(f"训练迭代 {learner.train_iter}: 评估完成,是否停止: {stop}, 当前奖励: {reward}")
if stop:
logging.info("满足停止条件,训练结束!")
break

# Collect new data
# Collect new data
# logging.info(f"Rank {rank}, train iter {learner.train_iter}: starting new data collection...")
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
logging.info(f"Rank {rank}, 训练迭代 {learner.train_iter}: 新数据收集完成!")

# Determine updates per collection
update_per_collect = cfg.policy.update_per_collect
@@ -162,44 +201,60 @@
collected_transitions_num = sum(min(len(game_segment), cfg.policy.game_segment_length) for game_segment in new_data[0])
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
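# Worked example (illustrative numbers): if the new game segments contain 400 transitions in total
# and cfg.policy.replay_ratio is 0.25, update_per_collect becomes int(400 * 0.25) = 100 gradient updates.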

# Update replay buffer
# Update the replay buffer
# logging.info(f"Train iter {learner.train_iter}: updating the replay buffer...")
replay_buffer.push_game_segments(new_data)
replay_buffer.remove_oldest_data_to_fit()
# logging.info(f"训练迭代 {learner.train_iter}: replay buffer 更新完成!")

if world_size > 1:
# Synchronize the readiness of all ranks before training
try:
dist.barrier()
# logging.info(f'Rank {rank}: passed the pre-training synchronization barrier')
except Exception as e:
logging.error(f'Rank {rank}: synchronization barrier failed, error: {e}')
break

# Train the policy if sufficient data is available
# Check whether there is enough data for training
if collector.envstep > cfg.policy.train_start_after_envsteps:
if cfg.policy.sample_type == 'episode':
data_sufficient = replay_buffer.get_num_of_game_segments() > batch_size
else:
data_sufficient = replay_buffer.get_num_of_transitions() > batch_size

if not data_sufficient:
logging.warning(
f'The data in replay_buffer is not sufficient to sample a mini-batch: '
f'batch_size: {batch_size}, replay_buffer: {replay_buffer}. Continue to collect now ....'
)
# NOTE: In DDP training, some ranks may not have enough data in their replay buffer, so they skip the training phase while the other ranks wait and time out; make sure all ranks enter the training phase together (see the commented sketch after this check).
logging.warning(f"Rank {rank}: train iter {learner.train_iter}: insufficient data in the replay buffer, continuing data collection...")
continue
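# A minimal sketch of one way to keep ranks in lockstep here (an assumption, not part of this PR's
# logic): aggregate data sufficiency across ranks with an all-reduce so that either every rank
# trains or every rank keeps collecting, avoiding the timeout described in the NOTE above.
#   if world_size > 1:
#       sufficient_flag = torch.tensor([int(data_sufficient)], device=cfg.policy.device)
#       dist.all_reduce(sufficient_flag, op=dist.ReduceOp.MIN)
#       data_sufficient = bool(sufficient_flag.item())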

logging.info(f"Rank {rank}, 训练迭代 {learner.train_iter}: 开始训练!")

# Run multiple training updates
for i in range(update_per_collect):
train_data = replay_buffer.sample(batch_size, policy)
if cfg.policy.reanalyze_ratio > 0 and i % 20 == 0:
# Clear caches and precompute positional embedding matrices
policy.recompute_pos_emb_diff_and_clear_cache() # TODO

if policy_config.use_wandb:
policy.recompute_pos_emb_diff_and_clear_cache()

if cfg.policy.use_wandb:
policy.set_train_iter_env_step(learner.train_iter, collector.envstep)

train_data.append({'train_which_component': 'transformer'})
log_vars = learner.train(train_data, collector.envstep)

if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

logging.info(f"Rank {rank}, 训练迭代 {learner.train_iter}: 训练完成!")

policy.recompute_pos_emb_diff_and_clear_cache()

# Check stopping criteria
# Check the stopping criteria
if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
logging.info("满足停止条件,训练结束!")
break

learner.call_hook('after_run')
wandb.finish()
return policy
if cfg.policy.use_wandb:
wandb.finish()
logging.info("===== 训练完成 =====")
return policy
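# A minimal usage sketch (config names and values are illustrative assumptions): the entry expects
# the (main_config, create_config) pair used by LightZero config scripts.
#   from lzero.entry import train_unizero
#   policy = train_unizero([main_config, create_config], seed=0, model_path=None, max_env_step=int(1e6))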