Skip to content

Commit

Permalink
fix(pu): fix sampled_unizero multitask ddp pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Dec 24, 2024
1 parent 563548b commit d8705e6
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 13 deletions.
3 changes: 3 additions & 0 deletions lzero/entry/train_unizero_multitask_segment_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,9 @@ def train_unizero_multitask_segment_ddp(
print('=' * 20)
print(f'Rank {rank} 评估任务_id: {cfg.policy.task_id}...')

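# ========= TODO: confirm the evaluator policy needs this reset (reset_init_data=True) for the current task_id before each evaluation =========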
evaluator._policy.reset(reset_init_data=True, task_id=cfg.policy.task_id)

# 执行安全评估
stop, reward = safe_eval(evaluator, learner, collector, rank, world_size)
# 判断评估是否成功
Expand Down
3 changes: 2 additions & 1 deletion lzero/model/unizero_world_models/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,10 @@ def encode_to_obs_embeddings(self, x: torch.Tensor, task_id = None) -> torch.Ten
elif len(shape) == 4:
# Case when input is 4D (B, C, H, W)
try:
# obs_embeddings = self.encoder[task_id](x)
obs_embeddings = self.encoder(x, task_id=task_id) # TODO: for dmc multitask
# obs_embeddings = self.encoder[task_id](x)
except Exception as e:
print(e)
obs_embeddings = self.encoder(x) # TODO: for memory env

obs_embeddings = rearrange(obs_embeddings, 'b e -> b 1 e')
Expand Down
2 changes: 1 addition & 1 deletion lzero/model/unizero_world_models/world_model_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -889,8 +889,8 @@ def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.LongTensor
# Copy and store keys_values_wm for a single environment
self.update_cache_context(current_obs_embeddings, is_init_infer=True)

# elif n > self.env_num and batch_action is not None and current_obs_embeddings is None:
elif batch_action is not None and current_obs_embeddings is None:
# elif n > self.env_num and batch_action is not None and current_obs_embeddings is None:
# ================ calculate the target value in Train phase ================
# [192, 16, 64] -> [32, 6, 16, 64]
last_obs_embeddings = last_obs_embeddings.contiguous().view(batch_action.shape[0], -1, num_observations_tokens,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def generate_configs(env_id_list: List[str],
num_segments: int,
total_batch_size: int):
configs = []
exp_name_prefix = f'data_suz_mt_ddp_20241224/8gpu_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_seed{seed}/'
exp_name_prefix = f'data_suz_mt_20241224/ddp_8gpu_{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}_seed{seed}/'
action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list]
observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list]

Expand Down
20 changes: 10 additions & 10 deletions zoo/dmc2gym/config/dmc2gym_state_suz_multitask_serial_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def create_config(env_id, action_space_size_list, observation_shape_list, collec
num_layers=2,
num_heads=8,
embed_dim=768,
env_num=len(env_id_list),
env_num=max(collector_env_num, evaluator_env_num),
task_num=len(env_id_list),
use_normal_head=True,
use_softmoe_head=False,
Expand Down Expand Up @@ -111,7 +111,7 @@ def generate_configs(env_id_list, seed, collector_env_num, evaluator_env_num, n_
Generate configurations for all DMC tasks in the environment list.
"""
configs = []
exp_name_prefix = f'data_suz_mt_debug/{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}/'
exp_name_prefix = f'data_suz_mt_20241224/{len(env_id_list)}tasks_brf{buffer_reanalyze_freq}/'
action_space_size_list = [dmc_state_env_action_space_map[env_id] for env_id in env_id_list]
observation_shape_list = [dmc_state_env_obs_space_map[env_id] for env_id in env_id_list]
for task_id, env_id in enumerate(env_id_list):
Expand Down Expand Up @@ -195,7 +195,7 @@ def create_env_manager():
num_segments = 8
n_episode = 8
num_simulations = 50
batch_size = [64, 64] # 可以根据需要调整或者设置为列表
batch_size = [64 for _ in range(len(env_id_list))]
num_unroll_steps = 5
infer_context_length = 2
norm_type = 'LN'
Expand All @@ -206,13 +206,13 @@ def create_env_manager():
update_per_collect = 100

# ========== TODO: debug config ============
collector_env_num = 2
evaluator_env_num = 2
num_segments = 2
n_episode = 2
num_simulations = 2
batch_size = [4,4] # 可以根据需要调整或者设置为列表
update_per_collect = 1
# collector_env_num = 2
# evaluator_env_num = 2
# num_segments = 2
# n_episode = 2
# num_simulations = 2
# batch_size = [4,4] # 可以根据需要调整或者设置为列表
# update_per_collect = 1

# 生成配置
configs = generate_configs(
Expand Down

0 comments on commit d8705e6

Please sign in to comment.