feature(pu): add RoPE that uses the true timestep index as pos_index #266

Open · wants to merge 5 commits into base: main
3 changes: 3 additions & 0 deletions lzero/mcts/buffer/game_buffer_muzero.py
@@ -234,6 +234,7 @@ def _prepare_reward_value_context(
game_segment_lens = []
# for board games
action_mask_segment, to_play_segment = [], []
# step_index_segment = []

td_steps_list = []
for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list):
@@ -253,6 +254,8 @@
action_mask_segment.append(game_segment.action_mask_segment)
to_play_segment.append(game_segment.to_play_segment)

# step_index_segment.append(game_segment.step_index_segment)

for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
# get the <num_unroll_steps+1> bootstrapped target obs
td_steps_list.append(td_steps)
28 changes: 21 additions & 7 deletions lzero/mcts/buffer/game_buffer_unizero.py
@@ -70,11 +70,11 @@ def sample(
batch_size, self._cfg.reanalyze_ratio
)

# current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list]
# current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list, step_index_list]

# target reward, target value
batch_rewards, batch_target_values = self._compute_target_reward_value(
reward_value_context, policy._target_model, current_batch[1] # current_batch[1] is action_batch
reward_value_context, policy._target_model, current_batch[1], current_batch[-1] # current_batch[1] is action_batch
)
# target policy
batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model,
@@ -120,13 +120,17 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
batch_size = len(batch_index_list)
obs_list, action_list, mask_list = [], [], []
step_index_list = []
# prepare the inputs of a batch
for i in range(batch_size):
game = game_segment_list[i]
pos_in_game_segment = pos_in_game_segment_list[i]

actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
self._cfg.num_unroll_steps].tolist()

step_index_tmp = game.step_index_segment[pos_in_game_segment:pos_in_game_segment +
self._cfg.num_unroll_steps].tolist()
# add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid
mask_tmp = [1. for i in range(len(actions_tmp))]
mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))]
@@ -136,6 +140,11 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
np.random.randint(0, game.action_space_size)
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
]
# TODO: out-of-trajectory steps are currently padded with step index 0
step_index_tmp += [
0
for _ in range(self._cfg.num_unroll_steps - len(step_index_tmp))
]

# obtain the input observations
# pad if length of obs in game_segment is less than stack+num_unroll_steps
@@ -147,12 +156,13 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
)
action_list.append(actions_tmp)
mask_list.append(mask_tmp)
step_index_list.append(step_index_tmp)

# formalize the input observations
obs_list = prepare_observation(obs_list, self._cfg.model.model_type)

# formalize the inputs of a batch
current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list]
current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list, step_index_list]
for i in range(len(current_batch)):
current_batch[i] = np.asarray(current_batch[i])

@@ -216,13 +226,15 @@ def _prepare_policy_reanalyzed_context(
rewards, child_visits, game_segment_lens = [], [], []
# for board games
action_mask_segment, to_play_segment = [], []
step_index_segment = []
for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list):
game_segment_len = len(game_segment)
game_segment_lens.append(game_segment_len)
rewards.append(game_segment.reward_segment)
# for board games
action_mask_segment.append(game_segment.action_mask_segment)
to_play_segment.append(game_segment.to_play_segment)
step_index_segment.append(game_segment.step_index_segment)

child_visits.append(game_segment.child_visit_segment)
# prepare the corresponding observations
@@ -241,7 +253,7 @@

policy_re_context = [
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens,
action_mask_segment, to_play_segment
action_mask_segment, to_play_segment, step_index_segment
]
return policy_re_context

@@ -260,10 +272,11 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

# for board games
policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \
to_play_segment = policy_re_context # noqa
to_play_segment, step_index_segment = policy_re_context # noqa
transition_batch_size = len(policy_obs_list)
game_segment_batch_size = len(pos_in_game_segment_list)

# TODO: step_index_segment
to_play, action_mask = self._preprocess_to_play_and_action_mask(
game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
)
@@ -289,6 +302,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:
# calculate the target value
# action_batch.shape (32, 10)
# m_obs.shape torch.Size([352, 3, 64, 64]) 32*11=352
# TODO: step_index_batch
m_output = model.initial_inference(m_obs, action_batch[:self.reanalyze_num]) # NOTE: :self.reanalyze_num
# =======================================================================

@@ -368,7 +382,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

return batch_target_policies_re

def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[
def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch, step_index_batch) -> Tuple[
Any, Any]:
"""
Overview:
@@ -410,7 +424,7 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
# =============== NOTE: The key difference with MuZero =================
# calculate the target value
# m_obs.shape torch.Size([352, 3, 64, 64]) 32*11 = 352
m_output = model.initial_inference(m_obs, action_batch)
m_output = model.initial_inference(m_obs, action_batch, start_pos=step_index_batch) # TODO: step_index
# ======================================================================

if not model.training:
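For reference, a minimal, self-contained sketch of the slice-and-pad logic that `_make_batch` now applies to the per-transition step indices, mirroring how `actions_tmp` is handled above (the segment values and unroll length below are illustrative, not taken from the repo's config):

```python
import numpy as np

# Illustrative values: a segment holding true env timesteps 37..44, sampled near its end.
num_unroll_steps = 5
step_index_segment = np.arange(37, 45)
pos_in_game_segment = 6

# Slice the true timesteps for the unrolled window, then pad out-of-trajectory
# positions with 0, mirroring the TODO in the diff above.
step_index_tmp = step_index_segment[pos_in_game_segment:pos_in_game_segment + num_unroll_steps].tolist()
step_index_tmp += [0 for _ in range(num_unroll_steps - len(step_index_tmp))]

print(step_index_tmp)  # [43, 44, 0, 0, 0]
```

The padded list is appended to `step_index_list`, stacked into the batch as `current_batch[-1]`, and handed to `_compute_target_reward_value` as `step_index_batch`.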
8 changes: 8 additions & 0 deletions lzero/mcts/buffer/game_segment.py
@@ -68,6 +68,7 @@ def __init__(self, action_space: int, game_segment_length: int = 200, config: Ea

self.action_mask_segment = []
self.to_play_segment = []
self.step_index_segment = []

self.target_values = []
self.target_rewards = []
@@ -133,6 +134,7 @@ def append(
reward: np.ndarray,
action_mask: np.ndarray = None,
to_play: int = -1,
step_index: int = 0,
chance: int = 0,
) -> None:
"""
@@ -145,6 +147,8 @@

self.action_mask_segment.append(action_mask)
self.to_play_segment.append(to_play)
self.step_index_segment.append(step_index)

if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment.append(chance)

@@ -290,6 +294,8 @@ def game_segment_to_array(self) -> None:

self.action_mask_segment = np.array(self.action_mask_segment)
self.to_play_segment = np.array(self.to_play_segment)
self.step_index_segment = np.array(self.step_index_segment)

if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = np.array(self.chance_segment)

@@ -310,6 +316,8 @@ def reset(self, init_observations: np.ndarray) -> None:

self.action_mask_segment = []
self.to_play_segment = []
self.step_index_segment = []

if self.use_ture_chance_label_in_chance_encoder:
self.chance_segment = []

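A reduced sketch of the new bookkeeping in `GameSegment`: the true environment timestep is stored per transition, parallel to `action_mask_segment`/`to_play_segment`, and converted to an array together with them. The class below is a stripped-down stand-in, not the real `GameSegment`:

```python
import numpy as np

class TinySegment:
    """Stripped-down stand-in for GameSegment, keeping only the step-index bookkeeping."""

    def __init__(self):
        self.action_segment = []
        self.step_index_segment = []

    def append(self, action, step_index=0):
        self.action_segment.append(action)
        self.step_index_segment.append(step_index)

    def game_segment_to_array(self):
        self.action_segment = np.array(self.action_segment)
        self.step_index_segment = np.array(self.step_index_segment)

seg = TinySegment()
for t in range(100, 105):  # e.g. a segment collected mid-episode, starting at env step 100
    seg.append(action=0, step_index=t)
seg.game_segment_to_array()
print(seg.step_index_segment)  # [100 101 102 103 104]
```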
4 changes: 2 additions & 2 deletions lzero/mcts/tree_search/mcts_ctree.py
@@ -74,7 +74,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m
# @profile
def search(
self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int,
List[Any]]
List[Any]], step_index
) -> None:
"""
Overview:
@@ -144,7 +144,7 @@
At the end of the simulation, the statistics along the trajectory are updated.
"""
# for UniZero
network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path)
network_output = model.recurrent_inference(state_action_history, simulation_index, latent_state_index_in_search_path, step_index)

network_output.latent_state = to_detach_cpu_numpy(network_output.latent_state)
network_output.policy_logits = to_detach_cpu_numpy(network_output.policy_logits)
8 changes: 4 additions & 4 deletions lzero/model/unizero_model.py
@@ -153,7 +153,7 @@ def __init__(
print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network')
print('==' * 20)

def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None) -> MZNetworkOutput:
def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, start_pos: int = 0) -> MZNetworkOutput:
"""
Overview:
Initial inference of UniZero model, which is the first step of the UniZero model.
@@ -177,7 +177,7 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_
"""
batch_size = obs_batch.size(0)
obs_act_dict = {'obs': obs_batch, 'action': action_batch, 'current_obs': current_obs_batch}
_, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict)
_, obs_token, logits_rewards, logits_policy, logits_value = self.world_model.forward_initial_inference(obs_act_dict, start_pos)
latent_state, reward, policy_logits, value = obs_token, logits_rewards, logits_policy, logits_value
policy_logits = policy_logits.squeeze(1)
value = value.squeeze(1)
@@ -190,7 +190,7 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_
)

def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0,
latent_state_index_in_search_path=[]) -> MZNetworkOutput:
latent_state_index_in_search_path=[], start_pos: int = 0) -> MZNetworkOutput:
"""
Overview:
Recurrent inference of UniZero model. To perform the recurrent inference, we concurrently predict the latent dynamics (reward/next_latent_state)
@@ -216,7 +216,7 @@ def recurrent_inference(self, state_action_history: torch.Tensor, simulation_ind
latent state, W_ is the width of latent state.
"""
_, logits_observations, logits_rewards, logits_policy, logits_value = self.world_model.forward_recurrent_inference(
state_action_history, simulation_index, latent_state_index_in_search_path)
state_action_history, simulation_index, latent_state_index_in_search_path, start_pos)
next_latent_state, reward, policy_logits, value = logits_observations, logits_rewards, logits_policy, logits_value
policy_logits = policy_logits.squeeze(1)
value = value.squeeze(1)
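For reference, a minimal sketch of the `start_pos` contract introduced by this signature change: call sites that omit it (collect/eval) keep the old behavior, while the reanalyze path passes the `(batch, num_unroll_steps)` array built in `_make_batch`. The stub and shapes below are illustrative, not the real model:

```python
import numpy as np

def initial_inference_stub(obs_batch, action_batch=None, current_obs_batch=None, start_pos=0):
    """Stand-in that only reports what it received as start_pos."""
    return np.asarray(start_pos).shape

batch_size, num_unroll_steps = 4, 5
step_index_batch = np.arange(batch_size * num_unroll_steps).reshape(batch_size, num_unroll_steps)

print(initial_inference_stub(obs_batch=None))                              # () -> scalar 0, old behavior
print(initial_inference_stub(obs_batch=None, start_pos=step_index_batch))  # (4, 5), reanalyze path
```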
84 changes: 76 additions & 8 deletions lzero/model/unizero_world_models/transformer.py
@@ -4,13 +4,14 @@

import math
from dataclasses import dataclass
from typing import Optional
from typing import Optional, Tuple

import torch
import torch.nn as nn
from ding.torch_utils.network import GRUGatingUnit
from einops import rearrange
from torch.nn import functional as F
import numpy as np

from .kv_caching import KeysValues

@@ -28,7 +29,10 @@ class TransformerConfig:
embed_pdrop: float
resid_pdrop: float
attn_pdrop: float


# for RoPE
rope_theta: float
max_seq_len: int
@property
def max_tokens(self):
return self.tokens_per_block * self.max_blocks
@@ -55,6 +59,12 @@ def __init__(self, config: TransformerConfig) -> None:
self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
self.ln_f = nn.LayerNorm(config.embed_dim)

self.freqs_cis = precompute_freqs_cis(
self.config.embed_dim // self.config.num_heads,
self.config.max_seq_len * 2,
self.config.rope_theta,
)

def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
"""
Generate a placeholder for keys and values.
@@ -70,7 +80,7 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device)

def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, start_pos: int = 0) -> torch.Tensor:
"""
Forward pass of the Transformer model.

@@ -82,11 +92,30 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues
Returns:
- torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
"""
seqlen = sequences.shape[1]
self.freqs_cis = self.freqs_cis.to(sequences.device)

# freqs_cis = self.freqs_cis[start_pos: start_pos + seqlen]

# If a start position is at or beyond the predefined maximum sequence length, wrap it around
start_pos = torch.tensor(np.array(start_pos))
if len(start_pos.shape) > 1:
# TODO: during training start_pos has shape (batch, num_unroll_steps); only the first step of each sample is used here
start_pos = torch.remainder(start_pos, self.config.max_seq_len)[:, 0]
else:
start_pos = torch.remainder(start_pos, self.config.max_seq_len)

start_pos_list = torch.unbind(start_pos)
# slice a per-sample window of rotary frequencies starting at its wrapped timestep
freqs_cis_slices = [self.freqs_cis[int(pos.item()): int(pos.item()) + seqlen] for pos in start_pos_list]
freqs_cis = torch.stack(freqs_cis_slices).squeeze(1)

assert past_keys_values is None or len(past_keys_values) == len(self.blocks)
x = self.drop(sequences)
for i, block in enumerate(self.blocks):
x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths)

x = block(x, None if past_keys_values is None else past_keys_values[i], valid_context_lengths, freqs_cis)
x = self.ln_f(x)
return x

@@ -129,7 +158,7 @@ def __init__(self, config: TransformerConfig) -> None:
)

def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, freqs_cis: torch.Tensor = None) -> torch.Tensor:
"""
Forward pass of the Transformer block.

@@ -141,7 +170,7 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
Returns:
- torch.Tensor: Output tensor of shape (batch_size, seq_length, embed_dim).
"""
x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths)
x_attn = self.attn(self.ln1(x), past_keys_values, valid_context_lengths, freqs_cis)
if self.gru_gating:
x = self.gate1(x, x_attn)
x = self.gate2(x, self.mlp(self.ln2(x)))
@@ -152,6 +181,42 @@ def forward(self, x: torch.Tensor, past_keys_values: Optional[KeysValues] = None
return x


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device, dtype=torch.float32)
freqs = torch.outer(t, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
# print(f"freqs_cis shape: {freqs_cis.shape}, x shape: {x.shape}")
assert 0 <= 1 < ndim
# assert freqs_cis.shape == (x.shape[2], x.shape[-1])
# shape = [d if i == 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
# TODO: check
shape = [d if i == 2 or i == ndim - 1 or i == 0 else 1 for i, d in enumerate(x.shape)]

return freqs_cis.view(*shape)


def apply_rotary_emb(
xq: torch.Tensor,
xk: torch.Tensor,
freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
try:
freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
except (AssertionError, RuntimeError):
# at the reset timestep, freqs_cis is assumed to already be broadcastable, so it is used as-is
pass
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2)
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2)
return xq_out.type_as(xq), xk_out.type_as(xk)


class SelfAttention(nn.Module):
"""
Implements self-attention mechanism for transformers.
@@ -189,7 +254,7 @@ def __init__(self, config: TransformerConfig) -> None:
self.register_buffer('mask', causal_mask)

def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
valid_context_lengths: Optional[torch.Tensor] = None, freqs_cis: torch.Tensor = None) -> torch.Tensor:
"""
Forward pass for the self-attention mechanism.

@@ -212,6 +277,9 @@ def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
q = self.query(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)
k = self.key(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)
v = self.value(x).view(B, T, self.num_heads, C // self.num_heads).transpose(1, 2) # (B, num_heads, T, head_size)

if self.config.rotary_emb:
q, k = apply_rotary_emb(q, k, freqs_cis=freqs_cis)

if kv_cache is not None:
kv_cache.update(k, v)
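The heart of the change is rotary position embedding driven by the true timestep rather than the token's position inside the current context window. The sketch below re-implements `precompute_freqs_cis` and `apply_rotary_emb` as they appear in the diff, wraps a batch of per-sample start positions with `torch.remainder` as `Transformer.forward` does, and checks the basic RoPE invariant that rotations preserve per-token norms (all sizes are illustrative assumptions, not the repo's config):

```python
import torch

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    # One complex rotation per (position, frequency) pair: shape (end, dim // 2), complex64.
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: dim // 2].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    return torch.polar(torch.ones(end, dim // 2), torch.outer(t, freqs))

def apply_rotary_emb(xq, xk, freqs_cis):
    # xq, xk: (B, num_heads, T, head_dim); freqs_cis: (B, T, head_dim // 2).
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs = freqs_cis.view(xq_.shape[0], 1, xq_.shape[2], xq_.shape[3])  # broadcast over heads
    xq_out = torch.view_as_real(xq_ * freqs).flatten(-2)
    xk_out = torch.view_as_real(xk_ * freqs).flatten(-2)
    return xq_out.type_as(xq), xk_out.type_as(xk)

# Illustrative sizes: 2 samples, 2 heads, 3 tokens, head_dim 8, max_seq_len 16.
B, H, T, D, max_seq_len = 2, 2, 3, 8, 16
freqs_cis = precompute_freqs_cis(D, max_seq_len * 2)

# True timesteps of the first token of each sample; wrap around max_seq_len as in Transformer.forward.
start_pos = torch.remainder(torch.tensor([5, 21]), max_seq_len)          # -> [5, 5]
freqs = torch.stack([freqs_cis[p: p + T] for p in start_pos.tolist()])   # (B, T, D // 2)

q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
q_rot, k_rot = apply_rotary_emb(q, k, freqs)
print(q_rot.shape, k_rot.shape)  # torch.Size([2, 2, 3, 8]) twice

# RoPE only rotates (q, k) pairs by unit complex numbers, so per-token norms are unchanged.
assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)
```

Because the rotation depends only on the absolute index fed in, attention scores depend on the offset between those indices, so feeding the wrapped true timestep keeps relative positions within a trajectory consistent, which appears to be the point of threading step_index through the buffer, MCTS, and model code above.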