Commit

fix(pu): fix self.past_keys_values_cache bug in mcts_ctree for muzero_gpt

puyuan1996 committed Nov 30, 2023
1 parent 26b6255 commit fd85f4d
Showing 12 changed files with 746 additions and 223 deletions.
5 changes: 2 additions & 3 deletions lzero/entry/train_muzero_gpt.py
@@ -187,9 +187,8 @@ def train_muzero_gpt(
if cfg.policy.use_priority:
replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])

# TODO
# policy._learn_model.world_model.past_keys_values_cache.clear()
# torch.cuda.empty_cache()
# NOTE: TODO
policy._learn_model.world_model.past_keys_values_cache.clear()

if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
break
13 changes: 4 additions & 9 deletions lzero/mcts/tree_search/mcts_ctree.py
@@ -281,11 +281,11 @@ def search(

state_action_history = []  # Initialize the state_action_history variable
last_latent_state = latent_state_roots
# TODO
# You may need to clear past_keys_values_cache at the start of each search to keep the cache from growing too large:
# NOTE: very important, start from the right initial key-value cache
# forward_initial_inference() has already executed the following operation:
# _ = model.world_model.refresh_keys_values_with_initial_obs_tokens(model.world_model.obs_tokens)

# model.world_model.past_keys_values_cache.clear() # clear the cache
# if len(model.world_model.past_keys_values_cache) > self._cfg.max_cache_size:
#     model.world_model.past_keys_values_cache.clear() # clear the cache
for simulation_index in range(self._cfg.num_simulations):
# In each simulation, we expanded a new node, so in one search, we have ``num_simulations`` num of nodes at most.

@@ -320,9 +320,6 @@ def search(
# state_action_history.append((last_latent_state, last_actions.detach().cpu().numpy()))
state_action_history.append((latent_states.detach().cpu().numpy(), last_actions.detach().cpu().numpy()))

# state_action_history.append(last_latent_state)
# state_action_history.append(last_actions)

"""
MCTS stage 2: Expansion
At the final time-step l of the simulation, the next_latent_state and reward/value_prefix are computed by the dynamics function.
@@ -332,10 +329,8 @@
"""
# network_output = model.recurrent_inference(latent_states, last_actions) # for classic muzero
# network_output = model.recurrent_inference(last_actions) # TODO: for muzero_gpt latent_states is not used in the model.

network_output = model.recurrent_inference(state_action_history) # TODO: latent_states is not used in the model.


network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
network_output.value = to_detach_cpu_numpy(self.inverse_scalar_transform_handle(network_output.value))
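The hunks above replace the classic MuZero call recurrent_inference(latent_states, last_actions) with a call that receives the whole state_action_history accumulated during the search. Below is a minimal sketch, not the repository's code, of how such a history of (latent_state, action) pairs builds up across simulations; search_sketch, select_action, and actions_per_step-style arguments are placeholders standing in for the objects used in mcts_ctree.py.

import numpy as np

def search_sketch(model, latent_state_roots, select_action, num_simulations):
    # One (latent_state, action) pair is appended per simulation, so after k
    # simulations the world model receives a history of length k to unroll from.
    state_action_history = []
    latent_states = latent_state_roots
    for simulation_index in range(num_simulations):
        last_actions = select_action(latent_states)  # placeholder for MCTS action selection
        # store detached numpy copies, mirroring the diff above
        state_action_history.append((np.asarray(latent_states), np.asarray(last_actions)))
        # the GPT world model resolves its key-value cache from this history
        network_output = model.recurrent_inference(state_action_history)
        latent_states = network_output.latent_state
    return state_action_history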
54 changes: 31 additions & 23 deletions lzero/model/gpt_models/cfg.py
@@ -3,39 +3,47 @@
cfg['tokenizer'] = {'_target_': 'models.tokenizer.Tokenizer',
# 'vocab_size': 512, # TODO: for atari
# 'embed_dim': 512,
# 'vocab_size': 256, # TODO: for atari debug
# 'embed_dim': 256,
'vocab_size': 128, # TODO: for atari debug
'embed_dim': 128,
# 'vocab_size': 64, # TODO: for cartpole
# 'embed_dim': 64,
'encoder':
{'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0},# TODO: for cartpole
'decoder':
{'resolution': 1, 'in_channels': 4, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0}} # TODO: for cartpole

{'resolution': 64, 'in_channels': 3, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0},# TODO:for atari debug
'decoder':
{'resolution': 64, 'in_channels': 3, 'z_channels': 128, 'ch': 64,
'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
'out_ch': 3, 'dropout': 0.0}} # TODO:for atari debug
# {'resolution': 64, 'in_channels': 1, 'z_channels': 512, 'ch': 64,
# 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
# 'out_ch': 3, 'dropout': 0.0},# TODO:for atari
# 'decoder':
# {'resolution': 64, 'in_channels': 1, 'z_channels': 512, 'ch': 64,
# 'ch_mult': [1, 1, 1, 1, 1], 'num_res_blocks': 2, 'attn_resolutions': [8, 16],
# 'out_ch': 3, 'dropout': 0.0}} # TODO:for atari
cfg['world_model'] = {
'tokens_per_block': 17,
'max_blocks': 20,
"max_tokens": 17 * 20, # TODO: horizon
'tokens_per_block': 17,
# 'max_blocks': 20,
# "max_tokens": 17 * 20, # TODO: horizon
'max_blocks': 5,
"max_tokens": 17 * 5, # TODO: horizon
'attention': 'causal',
'num_heads': 4,
# 'num_layers': 10,# TODO:for atari
'num_layers': 2, # TODO:for debug
'embed_dim': 128, # TODO: for cartpole
# 'embed_dim': 64, # TODO: for cartpole
'num_layers': 2, # TODO:for atari debug
'num_heads': 4,
'embed_dim': 128, # TODO:for atari
# 'embed_dim': 64, # TODO:for atari debug
'embed_pdrop': 0.1,
'resid_pdrop': 0.1,
'attn_pdrop': 0.1,
"device": 'cuda:0',
# "device": 'cpu',
"device": 'cuda:2',
# "device": 'cpu',
'support_size': 601,
# 'support_size': 21,
'action_shape': 2,# TODO: for cartpole
'max_cache_size': 500,
# 'max_cache_size':25,
'action_shape': 6,# TODO:for atari
# 'max_cache_size':5000,
'max_cache_size':50,

}

from easydict import EasyDict
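As a quick sanity check on the values above (and assuming, from the `from easydict import EasyDict` import at the end of the file, that the dict is later wrapped for attribute access), a minimal sketch:

from easydict import EasyDict

cfg = EasyDict({'world_model': {'tokens_per_block': 17, 'max_blocks': 5,
                                'max_tokens': 17 * 5, 'max_cache_size': 50}})
# max_tokens is the transformer context budget: tokens_per_block tokens per
# environment step (presumably the observation tokens plus one action token),
# kept for at most max_blocks steps.
assert cfg.world_model.max_tokens == cfg.world_model.tokens_per_block * cfg.world_model.max_blocks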
2 changes: 1 addition & 1 deletion lzero/model/gpt_models/cfg_atari.py
@@ -42,7 +42,7 @@
'support_size': 601,
'action_shape': 6,# TODO:for atari
# 'max_cache_size':5000,
'max_cache_size':25,
'max_cache_size':50,

}

8 changes: 5 additions & 3 deletions lzero/model/gpt_models/cfg_cartpole.py
@@ -18,8 +18,10 @@

cfg['world_model'] = {
'tokens_per_block': 17,
'max_blocks': 20,
"max_tokens": 17 * 20, # TODO: horizon
# 'max_blocks': 20,
# "max_tokens": 17 * 20, # TODO: horizon
'max_blocks': 5,
"max_tokens": 17 * 5, # TODO: horizon
'attention': 'causal',
'num_heads': 4,
# 'num_layers': 10,# TODO:for atari
@@ -34,7 +36,7 @@
'support_size': 601,
# 'support_size': 21,
'action_shape': 2,# TODO: for cartpole
# 'max_cache_size': 5000,
# 'max_cache_size': 500,
'max_cache_size':25,
}

80 changes: 34 additions & 46 deletions lzero/model/gpt_models/world_model.py
@@ -147,8 +147,8 @@ def __init__(self, obs_vocab_size: int, act_vocab_size: int, config: Transformer
break

import collections
# self.past_keys_values_cache = collections.OrderedDict()
self.past_keys_values_cache = {}
self.past_keys_values_cache = collections.OrderedDict()
# self.past_keys_values_cache = {}
# from collections import deque
# self.past_keys_values_cache = deque(maxlen=self.max_cache_size)
# self.cache_dict = {}
@@ -239,6 +239,7 @@ def refresh_keys_values_with_initial_obs_tokens(self, obs_tokens: torch.LongTens
# return outputs_wm.output_sequence # (B, K, E)
return outputs_wm

@torch.no_grad()
def forward_initial_inference(self, obs: torch.LongTensor, should_predict_next_obs: bool = True):

if len(obs[0].shape) == 3:
@@ -256,11 +257,13 @@ def forward_initial_inference(self, obs: torch.LongTensor, should_predict_next_o
outputs_wm, _, obs_tokens = self.reset_from_initial_observations(obs)

# TODO
cache_key = hash([obs_tokens.detach().cpu().numpy()])
# self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)
# self.past_keys_values_cache[cache_key] = self.keys_values_wm
# self.past_keys_values_cache.clear()

if obs_tokens.shape[0] > 1:
# TODO: minibatch
pass
else:
# collect eval only for env_num=1
cache_key = tuple(obs_tokens.squeeze(0).detach().cpu().numpy())
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)

# return outputs_wm.output_sequence, outputs_wm.logits_observations, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value
return outputs_wm.output_sequence, obs_tokens, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value
@@ -269,37 +272,26 @@ def forward_initial_inference(self, obs: torch.LongTensor, should_predict_next_o
# TODO: only for inference, not for training
@torch.no_grad()
def forward_recurrent_inference(self, state_action_history, should_predict_next_obs: bool = True):
# We know that state_action_history[0] is (root_latent_state, last_actions) and state_action_history[-1] is the latent state and action of the last step.
# There may also be many (root_latent_state, *) entries in between. Since the transformer should always unroll starting from (root_latent_state, root_action),
# we want to find the last position of such a (root_latent_state, root_action) pair, look up its corresponding value in keys_values_cache,
# and unroll from that value, which guarantees that no redundant computation occurs during inference.

# We only need to find the position of the last root_latent_state and take the action_history from that position onward; we do not
# need to find the position of the last root_action, since that position may be wrong: the root_state may have several different actions.

# Find the position of the last root_latent_state in action_history
root_latent_state, _ = state_action_history[0]
last_root_position = max(
i for i, (latent_state, _) in enumerate(state_action_history) if np.array_equal(latent_state, root_latent_state))

# state_action_history starting from position last_root_position
# [(s0,a0)] -> [s0]
# [(s0,a0),(s1,a1)] -> [(s0,a0),s1]
state_action_history_from_last_root = state_action_history[last_root_position:-1] + [state_action_history[-1][0]]
# if last_root_position>0:
# print('='*20)
# print('last_root_position>0')
# print('='*20)

# cache_key = tuple(state_action_history_from_last_root)
cache_key = hash(state_action_history_from_last_root)

if cache_key in self.past_keys_values_cache:
# self.keys_values_wm = self.past_keys_values_cache[cache_key]
self.keys_values_wm = copy.deepcopy(self.past_keys_values_cache[cache_key])

# If no corresponding cache entry is found, simply use the current self.keys_values_wm
# self.past_keys_values_cache.clear()
# In general, within one MCTS search we need to maintain a context of length H for transformer inference.
# Since the agent visits at most num_simulations distinct nodes in one search, it suffices to maintain a {state: kv_cache} mapping.
# If the environment is assumed to be an MDP, we can simply look up the current latest_state s_t in this mapping.
# TODO: if the environment is assumed to be non-MDP, do we instead need to maintain a {rootstate_action_history: kv_cache} mapping?

latest_state = state_action_history[-1][0]

matched_key = None
for key in self.past_keys_values_cache.keys():
# TODO: search from back to front, since later entries are the most recent
# Convert the key to a numpy array and compare it with latest_state
if np.allclose(np.array(key), latest_state, atol=1e-5): # atol is the tolerance for the difference
matched_key = key
break # exit the loop once a matching key is found

if matched_key is not None:
self.keys_values_wm = copy.deepcopy(self.past_keys_values_cache[matched_key])
else:
# NOTE: very important
_ = self.refresh_keys_values_with_initial_obs_tokens(torch.tensor(latest_state, dtype=torch.long).to(self.device))

assert self.keys_values_wm is not None and self.num_observations_tokens is not None

Expand All @@ -308,6 +300,7 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
output_sequence, obs_tokens = [], []

if self.keys_values_wm.size + num_passes > self.config.max_tokens:
# TODO: the impact
_ = self.refresh_keys_values_with_initial_obs_tokens(self.obs_tokens)

# TODO
@@ -344,19 +337,14 @@ def forward_recurrent_inference(self, state_action_history, should_predict_next_
self.obs_tokens = torch.cat(obs_tokens, dim=1) # (B, K)

obs = self.decode_obs_tokens() if should_predict_next_obs else None
# return outputs_wm.output_sequence, outputs_wm.logits_observations, outputs_wm.logits_rewards, outputs_wm.logits_policy, outputs_wm.logits_value

# [(s0,a0)] -> [(s0,a0),s1]
# [(s0,a0),(s1,a1)] -> [(s0,a0),(s1,a1), s2]
state_action_history_from_last_root = [state_action_history[last_root_position], self.obs_tokens.detach().cpu().numpy()]
cache_key = hash(state_action_history_from_last_root)
cache_key = tuple(self.obs_tokens.squeeze(0).detach().cpu().numpy())

# TODO: update the cache after the computation finishes. Is deepcopy needed?
self.past_keys_values_cache[cache_key] = copy.deepcopy(self.keys_values_wm)
# self.past_keys_values_cache.clear()
if len(self.past_keys_values_cache) > self.max_cache_size:
# if len(self.past_keys_values_cache) > 20:
self.past_keys_values_cache.clear()
# TODO: lru_cache
self.past_keys_values_cache.popitem(last=False) # Removes the earliest inserted item

return outputs_wm.output_sequence, self.obs_tokens, reward, outputs_wm.logits_policy, outputs_wm.logits_value

Expand Down