polish(pu): polish atari multitask related configs
puyuan1996 committed Dec 18, 2024
1 parent 29298e6 commit ffdf4db
Showing 24 changed files with 626 additions and 3,049 deletions.
3 changes: 2 additions & 1 deletion lzero/entry/__init__.py
@@ -14,7 +14,8 @@
from .train_muzero_multitask_segment_ddp import train_muzero_multitask_segment_ddp


from .train_unizero_multitask import train_unizero_multitask
from .train_unizero_multitask_serial import train_unizero_multitask_serial
from .train_unizero_multitask_segment import train_unizero_multitask_segment
from .train_unizero_multitask_segment_serial import train_unizero_multitask_segment_serial

from .train_unizero_multitask_segment_eval import train_unizero_multitask_segment_eval
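The serial entry points are renamed in this file. A hedged usage sketch of calling the renamed segment-based serial trainer follows; the per-game config pairs are hypothetical placeholders, not part of this commit:

from lzero.entry import train_unizero_multitask_segment_serial

# `atari_cfg_pairs` stands in for per-game (cfg, create_cfg) pairs built from the user's
# Atari multitask config files; the entry point expects List[Tuple[int, Tuple[dict, dict]]].
atari_cfg_pairs = [(pong_cfg, pong_create_cfg), (mspacman_cfg, mspacman_create_cfg)]  # hypothetical configs
input_cfg_list = [(task_id, cfg_pair) for task_id, cfg_pair in enumerate(atari_cfg_pairs)]
train_unizero_multitask_segment_serial(input_cfg_list, seed=0)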
31 changes: 0 additions & 31 deletions lzero/entry/train_unizero_multitask_segment.py
@@ -27,20 +27,12 @@

# ========== TODO ==========
# Evaluation timeout in seconds
# TIMEOUT = 60000  # e.g. 1000 min

TIMEOUT = 12000  # e.g. 200 min

# TIMEOUT = 6000  # e.g. 100 min

# TIMEOUT = 3600  # e.g. 60 min

# TIMEOUT = 1800  # e.g. 30 min

# TIMEOUT = 600  # e.g. 10 min

# TIMEOUT = 300  # e.g. 5 min
# TIMEOUT = 10  # e.g. 10 s

def safe_eval(evaluator, learner, collector, rank, world_size):
try:
@@ -70,29 +62,6 @@ def safe_eval(evaluator, learner, collector, rank, world_size):
print(f"An error occurred during evaluation on Rank {rank}/{world_size}: {e}")
return None, None

# def safe_eval(evaluator, learner, collector, rank, world_size):
#     print(f"=========before eval Rank {rank}/{world_size}===========")
#     # Reset stop_event so it is cleared before every evaluation
#     evaluator.stop_event.clear()
#     with concurrent.futures.ThreadPoolExecutor() as executor:
#         # Submit the evaluator.eval task
#         future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep)

#         try:
#             stop, reward = future.result(timeout=TIMEOUT)
#         except concurrent.futures.TimeoutError:
#             # Timed out: set the evaluator's stop_event
#             evaluator.stop_event.set()
#             print(f"Eval operation timed out after {TIMEOUT} seconds on Rank {rank}/{world_size}.")

#             # future.cancel()  # For a process pool this cancel has no real effect
#             # executor.shutdown(wait=False)  # Non-blocking shutdown of the pool, but it does not seem to work
#             # print(f"after executor.shutdown(wait=False) on Rank {rank}/{world_size}.")

#             return None, None

#     print(f"======after eval Rank {rank}/{world_size}======")
#     return stop, reward
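For reference, a self-contained sketch of the timeout-guarded evaluation pattern that the removed comment block above describes; the evaluator/learner/collector attributes mirror the surrounding code and are assumed to exist, and the concrete TIMEOUT value is illustrative:

import concurrent.futures

TIMEOUT = 12000  # seconds, e.g. 200 min

def safe_eval_with_timeout(evaluator, learner, collector, rank, world_size):
    # Clear the stop flag so an earlier timeout does not abort this evaluation.
    evaluator.stop_event.clear()
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        future = executor.submit(
            evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep
        )
        try:
            stop, reward = future.result(timeout=TIMEOUT)
        except concurrent.futures.TimeoutError:
            # The worker thread cannot be killed; ask the evaluator to stop cooperatively.
            # Note that leaving the with-block still joins the worker thread, which is
            # why a cooperative stop_event is needed in the first place.
            evaluator.stop_event.set()
            print(f"Eval operation timed out after {TIMEOUT} seconds on Rank {rank}/{world_size}.")
            return None, None
    return stop, reward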


def allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=1):
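Only the signature of allocate_batch_size is visible in this hunk. Below is a hedged sketch of one plausible implementation, assuming the goal is to split a fixed total batch budget across task buffers in inverse proportion to how much data each buffer already holds; the num_of_collected_episodes attribute and total_batch_size parameter are assumptions, not taken from this commit:

import numpy as np

def allocate_batch_size_sketch(cfgs, game_buffers, alpha=1.0, clip_scale=1, total_batch_size=512):
    # Inverse-proportional weighting: tasks with fewer collected episodes get larger batches.
    episodes = np.array(
        [max(getattr(buf, 'num_of_collected_episodes', 1), 1) for buf in game_buffers],
        dtype=np.float64,
    )
    weights = (1.0 / episodes) ** alpha
    weights = weights / weights.sum()
    batch_sizes = weights * total_batch_size
    # Keep every task within clip_scale of the uniform share to avoid starving any buffer.
    avg = total_batch_size / len(game_buffers)
    batch_sizes = np.clip(batch_sizes, avg / clip_scale, avg * clip_scale)
    return [int(b) for b in batch_sizes]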
68 changes: 0 additions & 68 deletions lzero/entry/train_unizero_multitask_segment_eval.py
@@ -31,16 +31,6 @@

TIMEOUT = 12000  # e.g. 200 min

# TIMEOUT = 6000  # e.g. 100 min

# TIMEOUT = 3600  # e.g. 60 min

# TIMEOUT = 1800  # e.g. 30 min

# TIMEOUT = 600  # e.g. 10 min

# TIMEOUT = 300  # e.g. 5 min
# TIMEOUT = 10  # e.g. 10 s

def safe_eval(evaluator, learner, collector, rank, world_size):
try:
@@ -70,30 +60,6 @@ def safe_eval(evaluator, learner, collector, rank, world_size):
print(f"An error occurred during evaluation on Rank {rank}/{world_size}: {e}")
return None, None

# def safe_eval(evaluator, learner, collector, rank, world_size):
#     print(f"=========before eval Rank {rank}/{world_size}===========")
#     # Reset stop_event so it is cleared before every evaluation
#     evaluator.stop_event.clear()
#     with concurrent.futures.ThreadPoolExecutor() as executor:
#         # Submit the evaluator.eval task
#         future = executor.submit(evaluator.eval, learner.save_checkpoint, learner.train_iter, collector.envstep)

#         try:
#             stop, reward = future.result(timeout=TIMEOUT)
#         except concurrent.futures.TimeoutError:
#             # Timed out: set the evaluator's stop_event
#             evaluator.stop_event.set()
#             print(f"Eval operation timed out after {TIMEOUT} seconds on Rank {rank}/{world_size}.")

#             # future.cancel()  # For a process pool this cancel has no real effect
#             # executor.shutdown(wait=False)  # Non-blocking shutdown of the pool, but it does not seem to work
#             # print(f"after executor.shutdown(wait=False) on Rank {rank}/{world_size}.")

#             return None, None

#     print(f"======after eval Rank {rank}/{world_size}======")
#     return stop, reward


def allocate_batch_size(cfgs, game_buffers, alpha=1.0, clip_scale=1):
"""
@@ -493,40 +459,6 @@ def train_unizero_multitask_segment_eval(
sys.exit(0)
# ========== TODO: ==========

if cfg.policy.use_priority:
    for idx, (cfg, replay_buffer) in enumerate(zip(cfgs, game_buffers)):
        # Update the priorities in the task-specific replay buffer
        task_id = cfg.policy.task_id
        replay_buffer.update_priority(train_data_multi_task[idx], log_vars[0][f'value_priority_task{task_id}'])

        current_priorities = log_vars[0][f'value_priority_task{task_id}']

        mean_priority = np.mean(current_priorities)
        std_priority = np.std(current_priorities)

        alpha = 0.1  # Smoothing factor for the running mean
        if f'running_mean_priority_task{task_id}' not in value_priority_tasks:
            # Initialize the running mean if it does not exist yet
            value_priority_tasks[f'running_mean_priority_task{task_id}'] = mean_priority
        else:
            # Update the running mean
            value_priority_tasks[f'running_mean_priority_task{task_id}'] = (
                alpha * mean_priority + (1 - alpha) * value_priority_tasks[f'running_mean_priority_task{task_id}']
            )

        # Normalize the priorities using the running mean
        running_mean_priority = value_priority_tasks[f'running_mean_priority_task{task_id}']
        normalized_priorities = (current_priorities - running_mean_priority) / (std_priority + 1e-6)

        # If needed, the normalized priorities can be written back to the replay buffer
        # replay_buffer.update_priority(train_data_multi_task[idx], normalized_priorities)

        # If the print_task_priority_logs flag is set, log the statistics
        if cfg.policy.print_task_priority_logs:
            print(f"Task {task_id} - Mean Priority: {mean_priority:.8f}, "
                  f"Running Mean Priority: {running_mean_priority:.8f}, "
                  f"Standard Deviation: {std_priority:.8f}")

train_epoch += 1
policy.recompute_pos_emb_diff_and_clear_cache()

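The block removed above maintained a per-task running mean of the batch priority and normalized new priorities against it. A minimal self-contained sketch of that normalization, with illustrative numbers, is:

import numpy as np

def normalize_priorities(current_priorities, running_mean, alpha=0.1, eps=1e-6):
    # Exponential moving average of the per-batch mean priority.
    mean_priority = float(np.mean(current_priorities))
    std_priority = float(np.std(current_priorities))
    if running_mean is None:
        running_mean = mean_priority
    else:
        running_mean = alpha * mean_priority + (1 - alpha) * running_mean
    # Center on the running mean and scale by the batch standard deviation.
    normalized = (np.asarray(current_priorities) - running_mean) / (std_priority + eps)
    return normalized, running_mean

# Example: priorities sampled for one task at two successive training steps.
norm_1, rm = normalize_priorities([0.8, 1.2, 1.0], running_mean=None)
norm_2, rm = normalize_priorities([2.0, 1.5, 2.5], running_mean=rm)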
lzero/entry/train_unizero_multitask_segment_serial.py
@@ -3,27 +3,28 @@
from functools import partial
from typing import Tuple, Optional, List

import torch
import numpy as np
import torch
from ding.config import compile_config
from ding.envs import create_env_manager, get_vec_env_setting
from ding.policy import create_policy
from ding.rl_utils import get_epsilon_greedy_fn
from ding.utils import EasyTimer
from ding.utils import set_pkg_seed, get_rank
from ding.worker import BaseLearner
from tensorboardX import SummaryWriter

from lzero.entry.utils import log_buffer_memory_usage
from lzero.mcts import UniZeroGameBuffer as GameBuffer
from lzero.policy import visit_count_temperature
from lzero.worker import MuZeroEvaluator as Evaluator
from lzero.mcts import UniZeroGameBuffer as GameBuffer
from lzero.worker import MuZeroSegmentCollector as Collector
from ding.utils import EasyTimer

timer = EasyTimer()
from line_profiler import line_profiler


#@profile
def train_unizero_multitask_segment(
def train_unizero_multitask_segment_serial(
input_cfg_list: List[Tuple[int, Tuple[dict, dict]]],
seed: int = 0,
model: Optional[torch.nn.Module] = None,
lzero/entry/train_unizero_multitask_serial.py
@@ -21,7 +21,7 @@
from line_profiler import line_profiler

#@profile
def train_unizero_multitask(
def train_unizero_multitask_serial(
input_cfg_list: List[Tuple[int, Tuple[dict, dict]]],
seed: int = 0,
model: Optional[torch.nn.Module] = None,