
Commit

fix(pu): fix init latent state, fix past_keys_values_cache in mcts search
puyuan1996 committed Nov 28, 2023
1 parent 36adc43 commit e39d421
Showing 4 changed files with 49 additions and 14 deletions.
4 changes: 3 additions & 1 deletion lzero/mcts/tree_search/mcts_ctree.py
@@ -315,7 +315,9 @@ def search(

# TODO
# Update state_action_history after each simulation
state_action_history.append((last_latent_state, last_actions))
state_action_history.append((last_latent_state, last_actions.cpu().numpy()))
# state_action_history.append(last_latent_state)
# state_action_history.append(last_actions)

"""
MCTS stage 2: Expansion
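For context, a minimal self-contained sketch of the pattern this hunk adopts: the selected actions are moved to NumPy before being appended next to the latent state. The loop and shapes below are illustrative stand-ins, not the repository's search loop.

import numpy as np
import torch

num_simulations = 3
batch_size = 2
latent_dim = 4

state_action_history = []  # one (latent_state, actions) entry per simulation

for _ in range(num_simulations):
    last_latent_state = np.zeros((batch_size, latent_dim), dtype=np.float32)
    last_actions = torch.randint(low=0, high=2, size=(batch_size,))
    # Store the actions as a NumPy array so the history stays on the CPU
    # and can later be compared or hashed without touching CUDA tensors.
    state_action_history.append((last_latent_state, last_actions.cpu().numpy()))

print(len(state_action_history))  # 3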
40 changes: 32 additions & 8 deletions lzero/model/gpt_models/world_model.py
@@ -179,6 +179,7 @@ def forward(self, tokens: torch.LongTensor, past_keys_values: Optional[KeysValue
logits_policy = self.head_policy(x, num_steps=num_steps, prev_steps=prev_steps)
logits_value = self.head_value(x, num_steps=num_steps, prev_steps=prev_steps)

# TODO: root reward value
return WorldModelOutput(x, logits_observations, logits_rewards, logits_ends, logits_policy, logits_value)

@torch.no_grad()
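The return statement above suggests the output container bundles the backbone hidden states with one tensor per prediction head. A minimal sketch, assuming WorldModelOutput is a plain dataclass whose field names match the attributes used later in this diff (output_sequence, logits_rewards, logits_policy, logits_value, ...); this is an illustration, not the repository's definition.

from dataclasses import dataclass
import torch

@dataclass
class WorldModelOutput:
    output_sequence: torch.FloatTensor      # x: hidden states from the transformer backbone
    logits_observations: torch.FloatTensor  # next observation-token logits
    logits_rewards: torch.FloatTensor       # reward head
    logits_ends: torch.FloatTensor          # episode-termination head
    logits_policy: torch.FloatTensor        # policy head used by MCTS
    logits_value: torch.FloatTensor         # value head used by MCTS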
@@ -216,7 +217,7 @@ def reset_from_initial_observations(self, observations: torch.FloatTensor) -> to
outputs_wm = self.refresh_keys_values_with_initial_obs_tokens(obs_tokens)
self.obs_tokens = obs_tokens

return outputs_wm, self.decode_obs_tokens()
return outputs_wm, self.decode_obs_tokens(), self.obs_tokens

@torch.no_grad()
def refresh_keys_values_with_initial_obs_tokens(self, obs_tokens: torch.LongTensor) -> torch.FloatTensor:
@@ -244,9 +245,22 @@ def forward_initial_inference(self, obs: torch.LongTensor, should_predict_next_o
expanded_observations = expanded_observations.expand(*desired_shape)
obs = expanded_observations

outputs_wm, _ = self.reset_from_initial_observations(obs)
outputs_wm, _, obs_tokens = self.reset_from_initial_observations(obs)

# TODO
# If the cache is full, remove the oldest item from the dictionary
if len(self.past_keys_values_cache) == self.max_cache_size:
oldest_key = self.past_keys_values_cache.popleft()
del self.cache_dict[oldest_key]

cache_key = hash([obs_tokens.cpu().numpy()])
# Add the new item to the deque and dictionary
self.past_keys_values_cache.append(cache_key)
self.cache_dict[cache_key] = copy.deepcopy(self.keys_values_wm)

# return outputs_wm.output_sequence, outputs_wm.logits_observations, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value
return outputs_wm.output_sequence, obs_tokens, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value

return outputs_wm.output_sequence, outputs_wm.logits_observations, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value

def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True):

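The added cache logic pairs a deque of keys (for eviction order) with a dict of deep-copied keys/values (for lookup). A self-contained sketch of that pattern follows; the bytes-based cache key and the store_in_cache helper are illustrative assumptions, not the repository's exact code.

import copy
from collections import deque

import numpy as np

max_cache_size = 3
past_keys_values_cache = deque()  # insertion-ordered keys; oldest on the left
cache_dict = {}                   # key -> deep-copied transformer keys/values

def store_in_cache(obs_tokens, keys_values_wm):
    # If the cache is full, evict the oldest entry from both structures.
    if len(past_keys_values_cache) == max_cache_size:
        oldest_key = past_keys_values_cache.popleft()
        del cache_dict[oldest_key]
    # A hashable key derived from the token contents (illustrative choice).
    cache_key = hash(obs_tokens.tobytes())
    past_keys_values_cache.append(cache_key)
    cache_dict[cache_key] = copy.deepcopy(keys_values_wm)

for i in range(5):
    store_in_cache(np.array([[i, i + 1]]), {"step": i})

print(len(cache_dict))  # 3 -- only the most recent entries survive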
@@ -270,9 +284,13 @@ def forward_recurrent_inference(state_action_history, should_predict_next_
i for i, (latent_state, _) in enumerate(state_action_history) if np.array_equal(latent_state, root_latent_state))

# The action history starting from this position
action_history_from_last_root = state_action_history[last_root_position:]
# cache_key = tuple(action_history_from_last_root)
cache_key = hash(action_history_from_last_root)
# state_action_history_from_last_root = state_action_history[last_root_position:]
# [(s0,a0)] -> [s0]
# [(s0,a0),(s1,a1)] -> [(s0,a0),s1]
state_action_history_from_last_root = state_action_history[last_root_position:-1] + [state_action_history[-1][0]]

# cache_key = tuple(state_action_history_from_last_root)
cache_key = hash(state_action_history_from_last_root)

# if cache_key in self.past_keys_values_cache:
# self.keys_values_wm = self.past_keys_values_cache[cache_key]
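Because the history slice mixes NumPy arrays and (state, action) tuples, it has to be reduced to something hashable before it can act as a cache key. One possible reduction, sketched under the assumption that flattening each element to bytes is acceptable; make_cache_key is a hypothetical helper, not part of the repository.

import numpy as np

def make_cache_key(history):
    # history items are either np.ndarray latent states or
    # (latent_state, action) tuples whose members are np.ndarray.
    parts = []
    for item in history:
        if isinstance(item, tuple):
            parts.append(tuple(np.asarray(x).tobytes() for x in item))
        else:
            parts.append(np.asarray(item).tobytes())
    return hash(tuple(parts))

s0 = np.zeros((1, 4), dtype=np.float32)
a0 = np.array([1])
s1 = np.ones((1, 4), dtype=np.float32)

key = make_cache_key([(s0, a0), s1])   # mirrors the [(s0,a0)] -> [(s0,a0), s1] shapes in the comments above
print(isinstance(key, int))            # True: usable as a dictionary key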
@@ -305,12 +323,12 @@ def forward_recurrent_inference(state_action_history, should_predict_next_

# obs is in token level
# num_steps=1, prev_steps=16
outputs_wm = self.forward(token, past_keys_values=self.keys_values_wm,
is_root=False) # TODO: inference=False
outputs_wm = self.forward(token, past_keys_values=self.keys_values_wm, is_root=False) # TODO: inference=False
# if k==0, action_token self.head_observations 1,...,0,1
output_sequence.append(outputs_wm.output_sequence)

if k == 0:
# if k == 0, the token is an action token, so outputs_wm.logits_rewards holds meaningful values
# reward = Categorical(logits=outputs_wm.logits_rewards).sample().float().cpu().numpy().reshape(-1) - 1 # (B,)
done = Categorical(logits=outputs_wm.logits_ends).sample().cpu().numpy().astype(bool).reshape(-1) # (B,)
reward = outputs_wm.logits_rewards # (B,)
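For reference, a small self-contained example of how the termination flag is sampled from the ends head with a Categorical over logits; the shapes and the three-way reward support are illustrative, and the - 1 shift mirrors the commented-out reward line above.

import torch
from torch.distributions import Categorical

batch_size = 8
logits_ends = torch.randn(batch_size, 1, 2)     # two classes: continue / terminate
logits_rewards = torch.randn(batch_size, 1, 3)  # assumed three-way support {-1, 0, 1}

done = Categorical(logits=logits_ends).sample().cpu().numpy().astype(bool).reshape(-1)
reward = Categorical(logits=logits_rewards).sample().float().cpu().numpy().reshape(-1) - 1

print(done.shape, reward.shape)  # (8,) (8,)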
@@ -336,6 +354,12 @@ def forward_recurrent_inference(state_action_history, should_predict_next_
# # This keeps the cache within a fixed memory budget and automatically cleans up old cache entries.
# self.past_keys_values_cache.append((cache_key, copy.deepcopy(self.keys_values_wm)))

# [(s0,a0)] -> [(s0,a0),s1]
# [(s0,a0),(s1,a1)] -> [(s0,a0),(s1,a1), s2]
state_action_history_from_last_root = [state_action_history[last_root_position], self.obs_tokens.cpu().numpy()]
# cache_key = tuple(state_action_history_from_last_root)
cache_key = hash(state_action_history_from_last_root)

# If the cache is full, remove the oldest item from the dictionary
if len(self.past_keys_values_cache) == self.max_cache_size:
oldest_key = self.past_keys_values_cache.popleft()
17 changes: 13 additions & 4 deletions lzero/model/muzero_gpt_model_vector_obs.py
@@ -201,22 +201,31 @@ def initial_inference(self, obs: torch.Tensor) -> MZNetworkOutput:
# latent_state,
# )

x, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(
# x, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(
# obs)
# logits_observations, reward, policy_logits, value = logits_observations, logits_rewards, logits_policy, logits_value

# # obs discrete distribution to one_hot latent state?
# # torch.Size([8, 16, 512]) -> torch.Size([8, 16])
# latent_state = torch.argmax(logits_observations, dim=2, keepdim=False)

x, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(
obs)
logits_observations, reward, policy_logits, value = logits_observations, logits_rewards, logits_policy, logits_value
reward, policy_logits, value = logits_rewards, logits_policy, logits_value

# obs discrete distribution to one_hot latent state?
# torch.Size([8, 16, 512]) -> torch.Size([8, 16])
latent_state = torch.argmax(logits_observations, dim=2, keepdim=False)
latent_state = obs_token

# TODO: root value policy_logit
# torch.Size([8, 1, 2]) - > torch.Size([8, 2])
policy_logits = policy_logits.squeeze(1)
# torch.Size([8, 1, 601]) - > torch.Size([8, 601])
value = value.squeeze(1)

return MZNetworkOutput(
value,
[0. for _ in range(batch_size)],
[0. for _ in range(batch_size)], # reward
policy_logits,
latent_state,
)
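To make the shape bookkeeping concrete, a short sketch using the sizes quoted in the comments above (batch 8, 16 observation tokens, 2 actions, 601 value bins); the namedtuple here is a stand-in for lzero's MZNetworkOutput, used only to show the field order.

from collections import namedtuple

import torch

# Stand-in for lzero's MZNetworkOutput; field order matches the return above.
MZNetworkOutput = namedtuple("MZNetworkOutput", ["value", "reward", "policy_logits", "latent_state"])

batch_size = 8
obs_token = torch.randint(0, 512, (batch_size, 16))        # token ids used as the latent state
policy_logits = torch.randn(batch_size, 1, 2).squeeze(1)   # torch.Size([8, 1, 2]) -> torch.Size([8, 2])
value = torch.randn(batch_size, 1, 601).squeeze(1)         # torch.Size([8, 1, 601]) -> torch.Size([8, 601])

output = MZNetworkOutput(
    value,
    [0. for _ in range(batch_size)],  # reward placeholder at the root
    policy_logits,
    obs_token,
)
print(output.policy_logits.shape, output.value.shape)  # torch.Size([8, 2]) torch.Size([8, 601])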
@@ -51,7 +51,7 @@
# ==============================================================

cartpole_muzero_gpt_config = dict(
exp_name=f'data_mz_gpt_ctree_1127_debug/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd128_mediumnet_bs{batch_size}_mcs5e3_fixtokenizer_seed0',
exp_name=f'data_mz_gpt_ctree_1128/cartpole_muzero_gpt_ns{num_simulations}_upc{update_per_collect}-mur{model_update_ratio}_rr{reanalyze_ratio}_H{num_unroll_steps}_nlayers2_emd128_mediumnet_bs{batch_size}_mcs5e3_fixedtokenizer_fixinitlatent_seed0',
env=dict(
env_name='CartPole-v0',
continuous=False,
