Skip to content

Commit

Permalink
feature(pu): add mistralai moe in transformer feedforward and head of…
Browse files Browse the repository at this point in the history
… unizero
  • Loading branch information
dyyoungg committed Jul 23, 2024
1 parent 5117459 commit 2495d60
Show file tree
Hide file tree
Showing 10 changed files with 431 additions and 45 deletions.
2 changes: 2 additions & 0 deletions lzero/entry/train_unizero_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
from lzero.worker import MuZeroCollector as Collector, MuZeroEvaluator as Evaluator
from lzero.mcts import UniZeroGameBuffer as GameBuffer

from line_profiler import line_profiler

#@profile
def train_unizero_multitask(
input_cfg_list: List[Tuple[int, Tuple[dict, dict]]],
seed: int = 0,
Expand Down
6 changes: 6 additions & 0 deletions lzero/mcts/buffer/game_buffer_unizero.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

if TYPE_CHECKING:
from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy
from line_profiler import line_profiler


@BUFFER_REGISTRY.register('game_buffer_unizero')
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self, cfg: dict):
self.task_id = None
print("No task_id found in configuration. Task ID is set to None.")

#@profile
def sample(
self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
) -> List[Any]:
Expand Down Expand Up @@ -103,6 +105,7 @@ def sample(
train_data = [current_batch, target_batch]
return train_data

#@profile
def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
"""
Overview:
Expand Down Expand Up @@ -198,6 +201,7 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
return context

#@profile
def _prepare_policy_reanalyzed_context(
self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str]
) -> List[Any]:
Expand Down Expand Up @@ -251,6 +255,7 @@ def _prepare_policy_reanalyzed_context(
]
return policy_re_context

#@profile
def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any, action_batch) -> np.ndarray:
"""
Overview:
Expand Down Expand Up @@ -374,6 +379,7 @@ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model:

return batch_target_policies_re

#@profile
def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any, action_batch) -> Tuple[
Any, Any]:
"""
Expand Down
7 changes: 4 additions & 3 deletions lzero/mcts/tree_search/mcts_ctree.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from lzero.mcts.ctree.ctree_muzero import mz_tree as mz_ctree
from lzero.mcts.ctree.ctree_gumbel_muzero import gmz_tree as gmz_ctree

from line_profiler import line_profiler

class UniZeroMCTSCtree(object):
"""
Expand Down Expand Up @@ -71,7 +72,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m
from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree
return ctree.Roots(active_collect_env_num, legal_actions)

# @profile
#@profile
def search(
self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int,
List[Any]], task_id=None
Expand Down Expand Up @@ -225,7 +226,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "m
from lzero.mcts.ctree.ctree_muzero import mz_tree as ctree
return ctree.Roots(active_collect_env_num, legal_actions)

# @profile
# #@profile
def search(
self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any], to_play_batch: Union[int,
List[Any]]
Expand Down Expand Up @@ -494,7 +495,7 @@ def roots(cls: int, active_collect_env_num: int, legal_actions: List[Any]) -> "e
"""
return tree_muzero.Roots(active_collect_env_num, legal_actions)

# @profile
# #@profile
def search(
self, roots: Any, model: torch.nn.Module, latent_state_roots: List[Any],
world_model_latent_history_roots: List[Any], to_play_batch: Union[int, List[Any]], ready_env_id=None,
Expand Down
4 changes: 4 additions & 0 deletions lzero/model/unizero_model_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@
from .unizero_world_models.tokenizer import Tokenizer
from .unizero_world_models.world_model_multitask import WorldModelMT

from line_profiler import line_profiler

# use ModelRegistry to register the model, for more details about ModelRegistry, please refer to DI-engine's document.
@MODEL_REGISTRY.register('UniZeroMTModel')
class UniZeroMTModel(nn.Module):

#@profile
def __init__(
self,
observation_shape: SequenceType = (4, 64, 64),
Expand Down Expand Up @@ -162,6 +164,7 @@ def __init__(
print(f'{sum(p.numel() for p in self.tokenizer.decoder_network.parameters())} parameters in agent.tokenizer.decoder_network')
print('==' * 20)

#@profile
def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_obs_batch=None, task_id=None) -> MZNetworkOutput:
"""
Overview:
Expand Down Expand Up @@ -198,6 +201,7 @@ def initial_inference(self, obs_batch: torch.Tensor, action_batch=None, current_
latent_state,
)

#@profile
def recurrent_inference(self, state_action_history: torch.Tensor, simulation_index=0,
latent_state_index_in_search_path=[], task_id=None) -> MZNetworkOutput:
"""
Expand Down
11 changes: 11 additions & 0 deletions lzero/model/unizero_world_models/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,17 @@
from simple_parsing.helpers import Serializable
from torch import nn

# Modified from https://github.com/mistralai/mistral-inference/blob/main/src/mistral_inference/transformer.py#L108
class MultiplicationFeedForward(nn.Module):
def __init__(self, config):
super().__init__()

self.w1 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False)
self.w2 = nn.Linear(4 * config.embed_dim, config.embed_dim, bias=False)
self.w3 = nn.Linear(config.embed_dim, 4 * config.embed_dim, bias=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w2(nn.functional.silu(self.w1(x)) * self.w3(x)) # type: ignore

@dataclasses.dataclass
class MoeArgs(Serializable):
Expand Down
25 changes: 23 additions & 2 deletions lzero/model/unizero_world_models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
from torch.nn import functional as F

from .kv_caching import KeysValues
from .moe import MoeLayer
from .moe import MoeLayer, MultiplicationFeedForward
from line_profiler import line_profiler


@dataclass
class TransformerConfig:
Expand Down Expand Up @@ -69,6 +71,7 @@ def generate_empty_keys_values(self, n: int, max_tokens: int) -> KeysValues:
device = self.ln_f.weight.device # Assumption: All submodules are on the same device
return KeysValues(n, self.config.num_heads, max_tokens, self.config.embed_dim, self.config.num_layers, device)

#@profile
def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Expand All @@ -91,6 +94,8 @@ def forward(self, sequences: torch.Tensor, past_keys_values: Optional[KeysValues
return x




class Block(nn.Module):
"""
Transformer block class.
Expand Down Expand Up @@ -122,7 +127,7 @@ def __init__(self, config: TransformerConfig) -> None:
self.ln2 = nn.LayerNorm(config.embed_dim)
self.attn = SelfAttention(config)
if config.moe_in_transformer:
# 创建多个独立的 MLP 实例
# 创Create multiple independent MLP instances
self.experts = nn.ModuleList([
nn.Sequential(
nn.Linear(config.embed_dim, 4 * config.embed_dim),
Expand All @@ -141,6 +146,21 @@ def __init__(self, config: TransformerConfig) -> None:
print("="*20)
print(f'use moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}')
print("="*20)
elif config.multiplication_moe_in_transformer:
# Create multiple FeedForward instances for multiplication-based MoE
self.experts = nn.ModuleList([
MultiplicationFeedForward(config) for _ in range(config.num_experts_of_moe_in_transformer)
])

self.feed_forward = MoeLayer(
experts=self.experts,
gate=nn.Linear(config.embed_dim, config.num_experts_of_moe_in_transformer, bias=False),
num_experts_per_tok=1,
)

print("="*20)
print(f'use multiplication moe in feed_forward of transformer, num of expert: {config.num_experts_of_moe_in_transformer}')
print("="*20)
else:
self.feed_forward = nn.Sequential(
nn.Linear(config.embed_dim, 4 * config.embed_dim),
Expand Down Expand Up @@ -209,6 +229,7 @@ def __init__(self, config: TransformerConfig) -> None:
causal_mask = torch.tril(torch.ones(config.max_tokens, config.max_tokens))
self.register_buffer('mask', causal_mask)

#@profile
def forward(self, x: torch.Tensor, kv_cache: Optional[KeysValues] = None,
valid_context_lengths: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Expand Down
Loading

0 comments on commit 2495d60

Please sign in to comment.