Skip to content

Commit

Permalink
polish(pu): adapt muzero_multitask to segment_collector
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Oct 16, 2024
1 parent 6abef12 commit cc7cc66
Show file tree
Hide file tree
Showing 12 changed files with 747 additions and 157 deletions.
1 change: 1 addition & 0 deletions lzero/entry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,4 @@
from .train_unizero import train_unizero
from .train_unizero_segment import train_unizero_segment
from .train_unizero_multitask import train_unizero_multitask
from .train_unizero_multitask_segment import train_unizero_multitask_segment
2 changes: 1 addition & 1 deletion lzero/entry/train_unizero_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def train_unizero_multitask(

learner.call_hook('before_run')
value_priority_tasks = {}
update_per_collect = cfg.policy.update_per_collect

while True:
# Precompute positional embedding matrices for collect/eval (not training)
Expand Down Expand Up @@ -175,7 +176,6 @@ def train_unizero_multitask(
new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)

# Determine updates per collection
update_per_collect = cfg.policy.update_per_collect
if update_per_collect is None:
collected_transitions_num = sum(len(game_segment) for game_segment in new_data[0])
update_per_collect = int(collected_transitions_num * cfg.policy.replay_ratio)
Expand Down
296 changes: 296 additions & 0 deletions lzero/entry/train_unizero_multitask_segment.py

Large diffs are not rendered by default.

29 changes: 12 additions & 17 deletions lzero/model/unizero_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,21 +88,16 @@ def __init__(
print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
print('==' * 20)
elif world_model_cfg.obs_type == 'image':
self.representation_network = nn.ModuleList()
# for task_id in range(self.task_num): # N independent encoder
for task_id in range(1): # TODO: one share encoder
self.representation_network.append(RepresentationNetworkUniZero(
observation_shape,
num_res_blocks,
num_channels,
self.downsample,
activation=self.activation,
norm_type=norm_type,
embedding_dim=world_model_cfg.embed_dim,
group_size=world_model_cfg.group_size,
))
# TODO: we should change the output_shape to the real observation shape
self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64))
self.representation_network = RepresentationNetworkUniZero(
observation_shape,
num_res_blocks,
num_channels,
self.downsample,
activation=self.activation,
norm_type=norm_type,
embedding_dim=world_model_cfg.embed_dim,
group_size=world_model_cfg.group_size,
)

# ====== for analysis ======
if world_model_cfg.analysis_sim_norm:
Expand Down Expand Up @@ -151,7 +146,7 @@ def __init__(
print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network')
print('==' * 20)

def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput:
def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None) -> MZNetworkOutput:
"""
Overview:
Initial inference of UniZero model, which is the first step of the UniZero model.
Expand Down Expand Up @@ -188,7 +183,7 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_
)

def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0,
latent_state_index_in_search_path=[], task_id=None) -> MZNetworkOutput:
latent_state_index_in_search_path=[]) -> MZNetworkOutput:
"""
Overview:
Recurrent inference of UniZero model. To perform the recurrent inference, we concurrently predict the latent dynamics (reward/next_latent_state)
Expand Down
10 changes: 3 additions & 7 deletions lzero/model/unizero_model_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(
elif world_model_cfg.obs_type == 'image':
self.representation_network = nn.ModuleList()
# for task_id in range(self.task_num): # N independent encoder
for task_id in range(1): # one share encoder
for task_id in range(1): # TODO: one share encoder
self.representation_network.append(RepresentationNetworkUniZero(
observation_shape,
num_res_blocks,
Expand All @@ -110,23 +110,19 @@ def __init__(
group_size=world_model_cfg.group_size,
))
# TODO: we should change the output_shape to the real observation shape
self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64))
# self.decoder_network = LatentDecoder(embedding_dim=world_model_cfg.embed_dim, output_shape=(3, 64, 64))

# ====== for analysis ======
if world_model_cfg.analysis_sim_norm:
self.encoder_hook = FeatureAndGradientHook()
self.encoder_hook.setup_hooks(self.representation_network)

self.tokenizer = Tokenizer(encoder=self.representation_network,
decoder_network=self.decoder_network, with_lpips=True,)
self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False)
self.world_model = WorldModelMT(config=world_model_cfg, tokenizer=self.tokenizer)
print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
print(f'{sum(p.numel() for p in self.world_model.parameters()) - sum(p.numel() for p in self.tokenizer.decoder_network.parameters()) - sum(p.numel() for p in self.tokenizer.lpips.parameters())} parameters in agent.world_model - (decoder_network and lpips)')

print('==' * 20)
print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer')
print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network')
print('==' * 20)
elif world_model_cfg.obs_type == 'image_memory':
self.representation_network = LatentEncoderForMemoryEnv(
Expand Down
13 changes: 8 additions & 5 deletions lzero/model/unizero_world_models/world_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def __init__(self, config: TransformerConfig, tokenizer) -> None:
self.tokenizer = tokenizer
self.config = config
self.transformer = Transformer(self.config)

self.task_num = 1
if self.config.device == 'cpu':
self.device = torch.device('cpu')
else:
Expand Down Expand Up @@ -181,7 +181,10 @@ def custom_copy_kv_cache_to_shared_wm(self, src_kv: KeysValues) -> int:
Returns:
- index (:obj:`int`): The index of the copied KeysValues object in the shared pool.
"""
src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape
try:
src_kv_shape = src_kv._keys_values[0]._k_cache._cache.shape
except Exception as e:
print(f"src_kv_shape: {src_kv_shape}")

if self.shared_pool_wm[self.shared_pool_index_wm] is None:
self.shared_pool_wm[self.shared_pool_index_wm] = KeysValues(
Expand Down Expand Up @@ -638,18 +641,18 @@ def reset_for_initial_inference(self, obs_act_dict: torch.FloatTensor) -> torch.
current_obs_embeddings = self.tokenizer.encode_to_obs_embeddings(batch_current_obs)
# print(f"current_obs_embeddings.device: {current_obs_embeddings.device}")
self.latent_state = current_obs_embeddings
outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action,
outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action,
current_obs_embeddings)
else:
# ================ calculate the target value in Train phase ================
self.latent_state = obs_embeddings
outputs_wm = self.wm_forward_for_initial_infererence(obs_embeddings, batch_action, None)
outputs_wm = self.wm_forward_for_initial_inference(obs_embeddings, batch_action, None)

return outputs_wm, self.latent_state

#@profile
@torch.no_grad()
def wm_forward_for_initial_infererence(self, last_obs_embeddings: torch.LongTensor,
def wm_forward_for_initial_inference(self, last_obs_embeddings: torch.LongTensor,
batch_action=None,
current_obs_embeddings=None) -> torch.FloatTensor:
"""
Expand Down
Loading

0 comments on commit cc7cc66

Please sign in to comment.