polish(pu): polish jericho config and pipeline
dyyoungg committed Dec 3, 2024
1 parent 4994c29 commit ecc3d35
Showing 8 changed files with 253 additions and 19 deletions.
2 changes: 1 addition & 1 deletion lzero/mcts/buffer/game_buffer_muzero.py
@@ -730,7 +730,7 @@ def _compute_target_policy_non_reanalyzed(
                     # for atari/classic_control/box2d environments that only have one player.
                     target_policies.append(distributions)
                 else:
-                    # for board games that have two players.
+                    # for board games that have two players, or envs that have a varied action space.
                     policy_tmp = [0 for _ in range(policy_shape)]
                     for index, legal_action in enumerate(legal_actions[policy_index]):
                         # only the actions in ``legal_action`` receive nonzero policy logits
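The touched loop pads the MCTS visit distribution out to the full action space, leaving illegal actions at zero probability; this is what lets the same code path serve board games and Jericho's varied action sets. A standalone sketch of that masking step (the values below are illustrative, not taken from the repository):

policy_shape = 5                   # fixed size of the padded action space
legal_actions = [0, 2, 3]          # indices of the currently legal actions
distributions = [0.5, 0.3, 0.2]    # MCTS visit distribution over the legal actions

policy_tmp = [0.0 for _ in range(policy_shape)]
for index, legal_action in enumerate(legal_actions):
    # Only entries listed in ``legal_actions`` receive probability mass.
    policy_tmp[legal_action] = distributions[index]

print(policy_tmp)  # [0.5, 0.0, 0.3, 0.2, 0.0]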
19 changes: 17 additions & 2 deletions lzero/model/common.py
@@ -275,16 +275,31 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class HFLanguageRepresentationNetwork(nn.Module):
-    def __init__(self, url: str = 'google-bert/bert-base-uncased'):
+    def __init__(self, url: str = 'google-bert/bert-base-uncased', embedding_size: int = 768):
         super().__init__()
         from transformers import AutoModel
         self.model = AutoModel.from_pretrained(url)
+
+        self.embedding_size = embedding_size
+        if self.embedding_size != 768:
+            # Project BERT's native 768-d output down to the requested width.
+            self.embed_head = nn.Linear(768, self.embedding_size)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = x.long()
         outputs = self.model(x)
-        # [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size]
-        return outputs.last_hidden_state[:, 0, :]
+        # Take the [CLS] token: [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size].
+        if self.embedding_size == 768:
+            return outputs.last_hidden_state[:, 0, :]
+        else:
+            return self.embed_head(outputs.last_hidden_state[:, 0, :])


 class RepresentationNetworkUniZero(nn.Module):
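With the new embedding_size argument, HFLanguageRepresentationNetwork can emit representations narrower than BERT's native 768 dimensions by routing the [CLS] vector through a linear head. A minimal usage sketch (assumes the lzero package and transformers are importable and the BERT weights are reachable; batch and sequence sizes are illustrative):

import torch
from lzero.model.common import HFLanguageRepresentationNetwork

# embedding_size != 768 attaches a Linear(768, 32) projection head.
encoder = HFLanguageRepresentationNetwork(
    url='google-bert/bert-base-uncased',
    embedding_size=32,
)
token_ids = torch.randint(0, 30522, (4, 512))  # [batch_size, max_seq_len] BERT token ids
with torch.no_grad():
    embeddings = encoder(token_ids)
print(embeddings.shape)  # torch.Size([4, 32])

This is what allows the debug configs below to run the world model at embed_dim=32 without modifying BERT itself.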
2 changes: 1 addition & 1 deletion lzero/model/unizero_model.py
@@ -89,7 +89,7 @@ def __init__(
             print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
             print('==' * 20)
         elif world_model_cfg.obs_type == 'text':
-            self.representation_network = HFLanguageRepresentationNetwork(url=kwargs['encoder_url'])
+            self.representation_network = HFLanguageRepresentationNetwork(url=kwargs['encoder_url'], embedding_size=world_model_cfg.embed_dim)
             self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False)
             self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
             print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
14 changes: 13 additions & 1 deletion lzero/model/unizero_world_models/world_model.py
@@ -1156,7 +1156,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
             latent_recon_loss = self.latent_recon_loss
             perceptual_loss = self.perceptual_loss

-        elif self.obs_type == 'vector' or self.obs_type == 'text':
+        elif self.obs_type == 'vector':
             perceptual_loss = torch.tensor(0., device=batch['observations'].device,
                                            dtype=batch['observations'].dtype)
@@ -1168,6 +1168,18 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
             #                                                        reconstructed_images)
             latent_recon_loss = self.latent_recon_loss

+        elif self.obs_type == 'text':
+            # Text observations are integer token ids, so pin the loss dtype to float32
+            # rather than inheriting the (integer) observation dtype.
+            perceptual_loss = torch.tensor(0., device=batch['observations'].device,
+                                           dtype=torch.float32)
+            # No decoder is attached for text, so no reconstruction is computed here.
+            latent_recon_loss = self.latent_recon_loss
+
         elif self.obs_type == 'image_memory':
             # Reconstruct observations from latent state representations
             # reconstructed_images = self.tokenizer.decode_to_obs(obs_embeddings)
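The dedicated 'text' branch exists because text observations arrive as integer token ids: reusing the vector branch's dtype=batch['observations'].dtype would have produced an integer zero tensor for the loss. A standalone sketch of the dispatch (the helper name and the zero placeholders are assumptions mirroring the hunk above, not the repository API):

import torch

def placeholder_losses(observations: torch.Tensor, obs_type: str):
    """Zero perceptual/reconstruction losses for obs types without a decoder."""
    if obs_type == 'vector':
        # Float observations: inherit device and dtype directly.
        perceptual_loss = torch.tensor(0., device=observations.device,
                                       dtype=observations.dtype)
    elif obs_type == 'text':
        # Token ids are integers, so the loss dtype must be pinned to float32.
        perceptual_loss = torch.tensor(0., device=observations.device,
                                       dtype=torch.float32)
    else:
        raise ValueError(f'unsupported obs_type: {obs_type}')
    latent_recon_loss = torch.zeros_like(perceptual_loss)
    return perceptual_loss, latent_recon_loss

tokens = torch.randint(0, 30522, (4, 512))  # a batch of token ids (dtype long)
print(placeholder_losses(tokens, 'text'))   # two float32 zero scalars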
4 changes: 2 additions & 2 deletions lzero/policy/unizero.py
@@ -433,8 +433,8 @@ def _forward_learn(self, data: Tuple[torch.Tensor]) -> Dict[str, Union[float, int]]:
         dormant_ratio_world_model = self.intermediate_losses['dormant_ratio_world_model']
         latent_state_l2_norms = self.intermediate_losses['latent_state_l2_norms']

-        assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
-        assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"
+        # assert not torch.isnan(losses.loss_total).any(), "Loss contains NaN values"
+        # assert not torch.isinf(losses.loss_total).any(), "Loss contains Inf values"

         # Core learn model update step
         self._optimizer_world_model.zero_grad()
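Commenting the asserts out keeps long Jericho runs from dying on a transient non-finite loss, but it also silences the signal entirely. A middle ground (a sketch, not part of this commit) is to log and skip the offending update instead of crashing:

import logging
import torch

def loss_is_finite(loss_total: torch.Tensor) -> bool:
    """Return True when it is safe to backpropagate ``loss_total``."""
    if torch.isnan(loss_total).any():
        logging.warning('Loss contains NaN values; skipping this update step.')
        return False
    if torch.isinf(loss_total).any():
        logging.warning('Loss contains Inf values; skipping this update step.')
        return False
    return True

print(loss_is_finite(torch.tensor([0.5, 1.2])))      # True
print(loss_is_finite(torch.tensor([float('nan')])))  # False, with a warning

Inside _forward_learn this would guard the optimizer step, e.g. if not loss_is_finite(losses.loss_total): return {}.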
58 changes: 50 additions & 8 deletions zoo/jericho/configs/jericho_unizero_config.py
@@ -2,7 +2,7 @@
 from easydict import EasyDict


-def main(env_id='zork1.z5', seed=0):
+def main(env_id='detective.z5', seed=0):
     action_space_size = 50

     # ==============================================================
@@ -11,24 +11,54 @@ def main(env_id='zork1.z5', seed=0):
     collector_env_num = 8
     game_segment_length = 20
     evaluator_env_num = 5
+    num_segments = 8
     num_simulations = 50
     max_env_step = int(5e5)
     batch_size = 64
     num_unroll_steps = 10
     infer_context_length = 4
     num_layers = 2
     replay_ratio = 0.25
+    embed_dim = 768
+    max_steps = 100
+    # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
+    # buffer_reanalyze_freq = 1/10
+    buffer_reanalyze_freq = 1/100000
+    # Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence).
+    reanalyze_batch_size = 160
+    # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
+    reanalyze_partition = 0.75
+
+    # =========== TODO: only for debug ===========
+    collector_env_num = 2
+    num_segments = 2
+    max_steps = 20
+    game_segment_length = 20
+    evaluator_env_num = 2
+    num_simulations = 5
+    max_env_step = int(5e5)
+    batch_size = 10
+    num_unroll_steps = 5
+    infer_context_length = 2
+    num_layers = 1
+    replay_ratio = 0.05
+    # embed_dim = 768
+    embed_dim = 32

     # ==============================================================
     # end of the most frequently changed config specified by the user
     # ==============================================================
     jericho_unizero_config = dict(
         env=dict(
             stop_value=int(1e6),
-            max_steps=100,
             observation_shape=512,
+            max_steps=max_steps,
             max_action_num=action_space_size,
-            tokenizer_path="google-bert/bert-base-uncased",
+            # tokenizer_path="google-bert/bert-base-uncased",
+            tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594",
             max_seq_len=512,
-            game_path="z-machine-games-master/jericho-game-suite/" + env_id,
+            # game_path="z-machine-games-master/jericho-game-suite/" + env_id,
+            game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/" + env_id,
             collector_env_num=collector_env_num,
             evaluator_env_num=evaluator_env_num,
             n_evaluator_episode=evaluator_env_num,
@@ -41,7 +71,8 @@ def main(env_id='zork1.z5', seed=0):
         model=dict(
             observation_shape=512,
             action_space_size=action_space_size,
-            encoder_url='google-bert/bert-base-uncased',
+            # encoder_url='google-bert/bert-base-uncased',
+            encoder_url='/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594',
             # The input of the model is text, whose shape is identical to the mlp model.
             model_type='mlp',
             world_model_cfg=dict(
@@ -55,23 +86,33 @@
                 action_space_size=action_space_size,
                 num_layers=num_layers,
                 num_heads=8,
-                embed_dim=768,
+                embed_dim=embed_dim,
                 obs_type='text',  # TODO: Change it.
                 env_num=max(collector_env_num, evaluator_env_num),
             ),
         ),
+        action_type='varied_action_space',
         model_path=None,
         num_unroll_steps=num_unroll_steps,
         reanalyze_ratio=0,
         replay_ratio=replay_ratio,
         batch_size=batch_size,
         learning_rate=0.0001,
         num_simulations=num_simulations,
-        train_start_after_envsteps=2000,
+        num_segments=num_segments,
+        train_start_after_envsteps=0,  # TODO
         game_segment_length=game_segment_length,
         replay_buffer_size=int(1e6),
         eval_freq=int(5e3),
         collector_env_num=collector_env_num,
         evaluator_env_num=evaluator_env_num,
+        # ============= The key different params for reanalyze =============
+        # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
+        buffer_reanalyze_freq=buffer_reanalyze_freq,
+        # Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence).
+        reanalyze_batch_size=reanalyze_batch_size,
+        # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
+        reanalyze_partition=reanalyze_partition,
     ),
 )
 jericho_unizero_config = EasyDict(jericho_unizero_config)
@@ -83,6 +124,7 @@
     ),
     # NOTE: use base env manager to avoid the bug of subprocess env manager.
     env_manager=dict(type='base'),
+    # env_manager=dict(type='subprocess'),
     policy=dict(
         type='unizero',
         import_names=['lzero.policy.unizero'],
@@ -102,7 +144,7 @@
     import argparse
     parser = argparse.ArgumentParser(description='Process some environment.')
     parser.add_argument('--env', type=str,
-                        help='The environment to use', default='zork1.z5')
+                        help='The environment to use', default='detective.z5')  # 'zork1.z5'
     parser.add_argument('--seed', type=int, help='The seed to use', default=0)
     args = parser.parse_args()
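Note that the "TODO: only for debug" block in this config unconditionally shadows the production settings, including embed_dim = 32, which in turn shrinks the encoder's projection head through world_model_cfg.embed_dim. A gated variant (a sketch with a hypothetical debug switch, not part of this commit) would keep the production values intact by default:

# Hypothetical gate for the debug overrides (not in this commit).
debug = False  # flip to True to reproduce the quick-debug settings above

if debug:
    collector_env_num, evaluator_env_num, num_segments = 2, 2, 2
    max_steps, game_segment_length = 20, 20
    num_simulations, batch_size = 5, 10
    num_unroll_steps, infer_context_length = 5, 2
    num_layers, replay_ratio, embed_dim = 1, 0.05, 32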
152 changes: 152 additions & 0 deletions zoo/jericho/configs/jericho_unizero_segment_config.py
@@ -0,0 +1,152 @@
import os
from easydict import EasyDict


def main(env_id='detective.z5', seed=0):
    action_space_size = 50

    # ==============================================================
    # begin of the most frequently changed config specified by the user
    # ==============================================================
    collector_env_num = 8
    game_segment_length = 20
    evaluator_env_num = 5
    num_segments = 8
    num_simulations = 50
    max_env_step = int(5e5)
    batch_size = 64
    num_unroll_steps = 10
    infer_context_length = 4
    num_layers = 2
    replay_ratio = 0.25
    embed_dim = 768
    max_steps = 100
    # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
    # buffer_reanalyze_freq = 1/10
    buffer_reanalyze_freq = 1/100000
    # Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence).
    reanalyze_batch_size = 160
    # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
    reanalyze_partition = 0.75

    # =========== TODO: only for debug ===========
    collector_env_num = 2
    num_segments = 2
    max_steps = 20
    game_segment_length = 20
    evaluator_env_num = 2
    num_simulations = 5
    max_env_step = int(5e5)
    batch_size = 10
    num_unroll_steps = 5
    infer_context_length = 2
    num_layers = 1
    replay_ratio = 0.05
    embed_dim = 32

    # ==============================================================
    # end of the most frequently changed config specified by the user
    # ==============================================================
    jericho_unizero_config = dict(
        env=dict(
            stop_value=int(1e6),
            max_steps=max_steps,
            observation_shape=512,
            max_action_num=action_space_size,
            # tokenizer_path="google-bert/bert-base-uncased",
            tokenizer_path="/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594",
            max_seq_len=512,
            # game_path="z-machine-games-master/jericho-game-suite/" + env_id,
            game_path="/mnt/afs/niuyazhe/code/LightZero/zoo/jericho/envs/z-machine-games-master/jericho-game-suite/" + env_id,
            collector_env_num=collector_env_num,
            evaluator_env_num=evaluator_env_num,
            n_evaluator_episode=evaluator_env_num,
            manager=dict(shared_memory=False, )
        ),
        policy=dict(
            # default is 10000
            learn=dict(learner=dict(
                hook=dict(save_ckpt_after_iter=1000000, ), ), ),
            model=dict(
                observation_shape=512,
                action_space_size=action_space_size,
                # encoder_url='google-bert/bert-base-uncased',
                encoder_url='/mnt/afs/zhangshenghan/.cache/huggingface/hub/models--google-bert--bert-base-uncased/snapshots/86b5e0934494bd15c9632b12f734a8a67f723594',
                # The input of the model is text, whose shape is identical to the mlp model.
                model_type='mlp',
                world_model_cfg=dict(
                    policy_entropy_weight=5e-3,
                    continuous_action_space=False,
                    max_blocks=num_unroll_steps,
                    # NOTE: each timestep has 2 tokens: obs and action
                    max_tokens=2 * num_unroll_steps,
                    context_length=2 * infer_context_length,
                    device='cuda',
                    action_space_size=action_space_size,
                    num_layers=num_layers,
                    num_heads=8,
                    embed_dim=embed_dim,
                    obs_type='text',  # TODO: Change it.
                    env_num=max(collector_env_num, evaluator_env_num),
                ),
            ),
            action_type='varied_action_space',
            model_path=None,
            num_unroll_steps=num_unroll_steps,
            reanalyze_ratio=0,
            replay_ratio=replay_ratio,
            batch_size=batch_size,
            learning_rate=0.0001,
            num_simulations=num_simulations,
            num_segments=num_segments,
            # train_start_after_envsteps=2000,
            train_start_after_envsteps=0,  # TODO: only for debug
            game_segment_length=game_segment_length,
            replay_buffer_size=int(1e6),
            eval_freq=int(5e3),
            collector_env_num=collector_env_num,
            evaluator_env_num=evaluator_env_num,
            # ============= The key different params for reanalyze =============
            # Defines the frequency of reanalysis. E.g., 1 means reanalyze once per epoch, 2 means reanalyze once every two epochs.
            buffer_reanalyze_freq=buffer_reanalyze_freq,
            # Each reanalyze process will reanalyze <reanalyze_batch_size> sequences (<cfg.policy.num_unroll_steps> transitions per sequence).
            reanalyze_batch_size=reanalyze_batch_size,
            # The partition of reanalyze. E.g., 1 means reanalyze_batch samples from the whole buffer, 0.5 means samples from the first half of the buffer.
            reanalyze_partition=reanalyze_partition,
        ),
    )
    jericho_unizero_config = EasyDict(jericho_unizero_config)

    jericho_unizero_create_config = dict(
        env=dict(
            type='jericho',
            import_names=['zoo.jericho.envs.jericho_env'],
        ),
        # NOTE: use base env manager to avoid the bug of subprocess env manager.
        env_manager=dict(type='base'),
        # env_manager=dict(type='subprocess'),
        policy=dict(
            type='unizero',
            import_names=['lzero.policy.unizero'],
        ),
    )
    jericho_unizero_create_config = EasyDict(jericho_unizero_create_config)
    main_config = jericho_unizero_config
    create_config = jericho_unizero_create_config

    main_config.exp_name = f'data_unizero/{env_id[:-14]}/{env_id[:-14]}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
    from lzero.entry import train_unizero_segment
    train_unizero_segment([main_config, create_config], seed=seed,
                          model_path=main_config.policy.model_path, max_env_step=max_env_step)


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Process some environment.')
    parser.add_argument('--env', type=str,
                        help='The environment to use', default='detective.z5')  # 'zork1.z5'
    parser.add_argument('--seed', type=int, help='The seed to use', default=0)
    args = parser.parse_args()

    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    main(args.env, args.seed)
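One caveat in the exp_name template above: the slice env_id[:-14] drops the last 14 characters, which collapses short game IDs such as 'detective.z5' (12 characters) to an empty string, so the experiment path loses the game name. A quick demonstration, plus a safer alternative (a suggestion, not part of this commit):

import os

env_id = 'detective.z5'
print(repr(env_id[:-14]))           # '' -- the slice consumes the entire ID
print(os.path.splitext(env_id)[0])  # 'detective' -- strips only the extension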