From 060c815bba8bd0c2727ea139cb08436196a7071d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sun, 1 Dec 2024 10:31:53 +0800 Subject: [PATCH 1/7] init commit --- .../_templates/{layout.html => layout.html} | 0 zoo/jericho/__init__.py | 0 zoo/jericho/configs/jericho_unizero_config.py | 103 ++++++++++++++++ zoo/jericho/envs/__init__.py | 1 + zoo/jericho/envs/jericho_env.py | 112 ++++++++++++++++++ zoo/jericho/envs/test_jericho_env.py | 41 +++++++ 6 files changed, 257 insertions(+) rename docs/source/_templates/{layout.html => layout.html} (100%) create mode 100644 zoo/jericho/__init__.py create mode 100644 zoo/jericho/configs/jericho_unizero_config.py create mode 100644 zoo/jericho/envs/__init__.py create mode 100644 zoo/jericho/envs/jericho_env.py create mode 100644 zoo/jericho/envs/test_jericho_env.py diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html similarity index 100% rename from docs/source/_templates/layout.html rename to docs/source/_templates/layout.html diff --git a/zoo/jericho/__init__.py b/zoo/jericho/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py new file mode 100644 index 000000000..3e63756fd --- /dev/null +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -0,0 +1,103 @@ +from easydict import EasyDict + + +def main(env_id='zork1.z5', seed=0): + action_space_size = 50 + + # ============================================================== + # begin of the most frequently changed config specified by the user + # ============================================================== + collector_env_num = 8 + game_segment_length = 20 + evaluator_env_num = 5 + num_simulations = 50 + max_env_step = int(5e5) + batch_size = 64 + num_unroll_steps = 10 + infer_context_length = 4 + num_layers = 2 + replay_ratio = 0.25 + # ============================================================== + # end of the most frequently changed config specified by the user + # ============================================================== + jericho_unizero_config = dict( + env=dict( + stop_value=int(1e6), + max_action_num=action_space_size, + tokenizer_path="google-bert/bert-base-uncased", + max_seq_len=512, + game_path="z-machine-games-master/jericho-game-suite/" + env_id, + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + n_evaluator_episode=evaluator_env_num, + manager=dict(shared_memory=False, ) + ), + policy=dict( + # default is 10000 + learn=dict(learner=dict( + hook=dict(save_ckpt_after_iter=1000000, ), ), ), + model=dict( + observation_shape=512, + action_space_size=action_space_size, + world_model_cfg=dict( + policy_entropy_weight=5e-3, + continuous_action_space=False, + max_blocks=num_unroll_steps, + # NOTE: each timestep has 2 tokens: obs and action + max_tokens=2 * num_unroll_steps, + context_length=2 * infer_context_length, + device='cuda', + action_space_size=action_space_size, + num_layers=num_layers, + num_heads=8, + embed_dim=768, + obs_type='vector', # TODO: Change it. + env_num=max(collector_env_num, evaluator_env_num), + ), + ), + model_path=None, + num_unroll_steps=num_unroll_steps, + replay_ratio=replay_ratio, + batch_size=batch_size, + learning_rate=0.0001, + num_simulations=num_simulations, + train_start_after_envsteps=2000, + game_segment_length=game_segment_length, + replay_buffer_size=int(1e6), + eval_freq=int(5e3), + collector_env_num=collector_env_num, + evaluator_env_num=evaluator_env_num, + ), + ) + jericho_unizero_config = EasyDict(jericho_unizero_config) + + jericho_unizero_create_config = dict( + env=dict( + type='jericho_lightzero', + import_names=['zoo.jericho.envs.jericho_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict( + type='unizero', + import_names=['lzero.policy.unizero'], + ), + ) + jericho_unizero_create_config = EasyDict(jericho_unizero_create_config) + main_config = jericho_unizero_config + create_config = jericho_unizero_create_config + + main_config.exp_name = f'data_unizero/{env_id[:-14]}/{env_id[:-14]}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}' + from lzero.entry import train_unizero + train_unizero([main_config, create_config], seed=seed, + model_path=main_config.policy.model_path, max_env_step=max_env_step) + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description='Process some environment.') + parser.add_argument('--env', type=str, + help='The environment to use', default='zork1.z5') + parser.add_argument('--seed', type=int, help='The seed to use', default=0) + args = parser.parse_args() + + main(args.env, args.seed) diff --git a/zoo/jericho/envs/__init__.py b/zoo/jericho/envs/__init__.py new file mode 100644 index 000000000..740dab512 --- /dev/null +++ b/zoo/jericho/envs/__init__.py @@ -0,0 +1 @@ +from .jericho_env import JerichoEnv diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py new file mode 100644 index 000000000..aa38a3d3a --- /dev/null +++ b/zoo/jericho/envs/jericho_env.py @@ -0,0 +1,112 @@ +import copy +from typing import List + +import gym +import numpy as np +from transformers import AutoTokenizer +from ding.utils import ENV_REGISTRY +from ding.envs import BaseEnv, BaseEnvTimestep +from jericho import FrotzEnv + + +@ENV_REGISTRY.register('jericho') +class JerichoEnv(BaseEnv): + + def __init__(self, cfg): + self.cfg = cfg + self.game_path = cfg.game_path + self.max_action_num = cfg.max_action_num + self.max_seq_len = cfg.max_seq_len + self.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path) + + self._env = FrotzEnv(self.game_path) + self._action_list = None + + self.finished = False + self._init_flag = False + self.episode_return = 0 + + self.observation_space = gym.spaces.Dict() + self.action_space = gym.spaces.Discrete(self.max_action_num) + self.reward_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32) + + def prepare_obs(self, obs, return_str: bool = False): + if self._action_list is None: + self._action_list = self._env.get_valid_actions() + full_obs = obs + "\nValid actions: " + str(self._action_list) + if not return_str: + full_obs = self.tokenizer([full_obs], truncation=True, padding=True, max_length=self.max_seq_len) + full_obs = np.array(full_obs['input_ids'][0], dtype=np.int32) + action_mask = [1] * len(self._action_list) + [0] * (self.max_action_num - len(self._action_list)) + action_mask = np.array(action_mask, dtype=np.int8) + return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1} + + def reset(self, return_str: bool = False): + initial_observation, info = self._env.reset() + self.episode_return = 0 + return self.prepare_obs(initial_observation, return_str) + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + """ + Set the seed for the environment. + """ + self._seed = seed + self._env.seed(seed) + + def close(self) -> None: + self._init_flag = False + + def __repr__(self) -> str: + return "LightZero Jericho Env" + + def step(self, action: int, return_str: bool = False): + action_str = self._action_list[action] + observation, reward, done, info = self._env.step(action_str) + self.episode_return += reward + self._action_list = None + observation = self.prepare_obs(observation, return_str) + + if done: + self.finished = True + info['eval_episode_return'] = self.episode_return + + return BaseEnvTimestep(observation, reward, done, info) + + @staticmethod + def create_collector_env_cfg(cfg: dict) -> List[dict]: + collector_env_num = cfg.pop('collector_env_num') + cfg = copy.deepcopy(cfg) + # when in collect phase, sometimes we need to normalize the reward + # reward_normalize is determined by the config. + cfg.is_collect = True + return [cfg for _ in range(collector_env_num)] + + @staticmethod + def create_evaluator_env_cfg(cfg: dict) -> List[dict]: + evaluator_env_num = cfg.pop('evaluator_env_num') + cfg = copy.deepcopy(cfg) + # when in evaluate phase, we don't need to normalize the reward. + cfg.reward_normalize = False + cfg.is_collect = False + return [cfg for _ in range(evaluator_env_num)] + + +if __name__ == '__main__': + from easydict import EasyDict + env_cfg = EasyDict( + dict( + game_path="z-machine-games-master/jericho-game-suite/zork1.z5", + max_action_num=50, + tokenizer_path="google-bert/bert-base-uncased", + max_seq_len=512 + ) + ) + env = JerichoEnv(env_cfg) + obs = env.reset(return_str=True) + print(f'[OBS]:\n{obs["observation"]}') + while True: + action_id = int(input('Please input the action id:')) + obs, reward, done, info = env.step(action_id, return_str=True) + print(f'[OBS]:\n{obs["observation"]}') + if done: + break diff --git a/zoo/jericho/envs/test_jericho_env.py b/zoo/jericho/envs/test_jericho_env.py new file mode 100644 index 000000000..28db93b53 --- /dev/null +++ b/zoo/jericho/envs/test_jericho_env.py @@ -0,0 +1,41 @@ +from easydict import EasyDict +from .jericho_env import JerichoEnv +import numpy as np +import pytest + + +@pytest.mark.unittest +class TestJerichoEnv(): + def setup(self) -> None: + # Configuration for the Jericho environment + cfg = EasyDict( + dict( + game_path="z-machine-games-master/jericho-game-suite/zork1.z5", + max_action_num=50, + tokenizer_path="google-bert/bert-base-uncased", + max_seq_len=512 + ) + ) + # Create a Jericho environment that will be used in the following tests. + self.env = JerichoEnv(cfg) + + # Test the initialization of the Jericho environment. + def test_initialization(self): + assert isinstance(self.env, JerichoEnv) + + # Test the reset method of the Jericho environment. + # Ensure that the shape of the observation is as expected. + def test_reset(self): + obs = self.env.reset() + assert obs['observation'].shape == (512,) + + # Test the step method of the Jericho environment. + # Ensure that the shape of the observation, the type of the reward, + # the type of the done flag and the type of the info are as expected. + def test_step_shape(self): + self.env.reset() + obs, reward, done, info = self.env.step(1) + assert obs['observation'].shape == (512,) + assert isinstance(reward, np.ndarray) + assert isinstance(done, bool) + assert isinstance(info, dict) From 13521bf23e4f692d19d011169b23186eda453430 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sun, 1 Dec 2024 11:29:22 +0800 Subject: [PATCH 2/7] add bert encoding --- lzero/model/common.py | 12 ++++++++++++ lzero/model/unizero_model.py | 12 +++++++++++- lzero/model/unizero_world_models/world_model.py | 2 +- zoo/jericho/configs/jericho_unizero_config.py | 9 ++++++--- zoo/jericho/envs/jericho_env.py | 15 +++++++++------ 5 files changed, 39 insertions(+), 11 deletions(-) diff --git a/lzero/model/common.py b/lzero/model/common.py index 22afa95fe..035158b5b 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -272,6 +272,18 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: f"You should transform the observation shape to 64 or 96 in the env.") return output + + +class HFLanguageRepresentationNetwork(nn.Module): + def __init__(self, url: str = 'google-bert/bert-base-uncased'): + super().__init__() + from transformers import AutoModel + self.model = AutoModel.from_pretrained(url) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + outputs = self.model(x) + # [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size] + return outputs.last_hidden_state[:, 0, :] class RepresentationNetworkUniZero(nn.Module): diff --git a/lzero/model/unizero_model.py b/lzero/model/unizero_model.py index e28322215..2563468c3 100644 --- a/lzero/model/unizero_model.py +++ b/lzero/model/unizero_model.py @@ -6,7 +6,8 @@ from easydict import EasyDict from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \ - VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook + VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \ + HFLanguageRepresentationNetwork from .unizero_world_models.tokenizer import Tokenizer from .unizero_world_models.world_model import WorldModel @@ -87,6 +88,15 @@ def __init__( print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') print('==' * 20) + elif world_model_cfg.obs_type == 'text': + self.representation_network = HFLanguageRepresentationNetwork(url=kwargs['encoder_url']) + self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False,) + self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer) + print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model') + print('==' * 20) + print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer') + print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder') + print('==' * 20) elif world_model_cfg.obs_type == 'image': self.representation_network = RepresentationNetworkUniZero( observation_shape, diff --git a/lzero/model/unizero_world_models/world_model.py b/lzero/model/unizero_world_models/world_model.py index 37d4cd3ec..d112e4d95 100644 --- a/lzero/model/unizero_world_models/world_model.py +++ b/lzero/model/unizero_world_models/world_model.py @@ -1156,7 +1156,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar latent_recon_loss = self.latent_recon_loss perceptual_loss = self.perceptual_loss - elif self.obs_type == 'vector': + elif self.obs_type == 'vector' or self.obs_type == 'text': perceptual_loss = torch.tensor(0., device=batch['observations'].device, dtype=batch['observations'].dtype) diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py index 3e63756fd..ecbee916a 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -37,8 +37,10 @@ def main(env_id='zork1.z5', seed=0): learn=dict(learner=dict( hook=dict(save_ckpt_after_iter=1000000, ), ), ), model=dict( - observation_shape=512, + observation_shape=(512,), action_space_size=action_space_size, + model_type='text', + encoder_url='google-bert/bert-base-uncased', world_model_cfg=dict( policy_entropy_weight=5e-3, continuous_action_space=False, @@ -73,10 +75,11 @@ def main(env_id='zork1.z5', seed=0): jericho_unizero_create_config = dict( env=dict( - type='jericho_lightzero', + type='jericho', import_names=['zoo.jericho.envs.jericho_env'], ), - env_manager=dict(type='subprocess'), + # NOTE: use base env manager to avoid the bug of subprocess env manager. + env_manager=dict(type='base'), policy=dict( type='unizero', import_names=['lzero.policy.unizero'], diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py index aa38a3d3a..eb7f16c71 100644 --- a/zoo/jericho/envs/jericho_env.py +++ b/zoo/jericho/envs/jericho_env.py @@ -28,19 +28,22 @@ def __init__(self, cfg): self.observation_space = gym.spaces.Dict() self.action_space = gym.spaces.Discrete(self.max_action_num) - self.reward_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32) - + self.reward_space = gym.spaces.Box( + low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32) + def prepare_obs(self, obs, return_str: bool = False): if self._action_list is None: self._action_list = self._env.get_valid_actions() full_obs = obs + "\nValid actions: " + str(self._action_list) if not return_str: - full_obs = self.tokenizer([full_obs], truncation=True, padding=True, max_length=self.max_seq_len) + full_obs = self.tokenizer( + [full_obs], truncation=True, padding="max_length", max_length=self.max_seq_len) full_obs = np.array(full_obs['input_ids'][0], dtype=np.int32) - action_mask = [1] * len(self._action_list) + [0] * (self.max_action_num - len(self._action_list)) + action_mask = [1] * len(self._action_list) + [0] * \ + (self.max_action_num - len(self._action_list)) action_mask = np.array(action_mask, dtype=np.int8) return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1} - + def reset(self, return_str: bool = False): initial_observation, info = self._env.reset() self.episode_return = 0 @@ -71,7 +74,7 @@ def step(self, action: int, return_str: bool = False): info['eval_episode_return'] = self.episode_return return BaseEnvTimestep(observation, reward, done, info) - + @staticmethod def create_collector_env_cfg(cfg: dict) -> List[dict]: collector_env_num = cfg.pop('collector_env_num') From 783e4d2a5279510891eec36e8b2b396188e75d9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sun, 1 Dec 2024 11:34:23 +0800 Subject: [PATCH 3/7] debug --- zoo/jericho/configs/jericho_unizero_config.py | 3 +-- zoo/jericho/envs/jericho_env.py | 7 +++++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py index ecbee916a..2e92e2a43 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -39,7 +39,6 @@ def main(env_id='zork1.z5', seed=0): model=dict( observation_shape=(512,), action_space_size=action_space_size, - model_type='text', encoder_url='google-bert/bert-base-uncased', world_model_cfg=dict( policy_entropy_weight=5e-3, @@ -53,7 +52,7 @@ def main(env_id='zork1.z5', seed=0): num_layers=num_layers, num_heads=8, embed_dim=768, - obs_type='vector', # TODO: Change it. + obs_type='text', # TODO: Change it. env_num=max(collector_env_num, evaluator_env_num), ), ), diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py index eb7f16c71..ab9423c19 100644 --- a/zoo/jericho/envs/jericho_env.py +++ b/zoo/jericho/envs/jericho_env.py @@ -11,13 +11,16 @@ @ENV_REGISTRY.register('jericho') class JerichoEnv(BaseEnv): + tokenizer = None def __init__(self, cfg): self.cfg = cfg self.game_path = cfg.game_path self.max_action_num = cfg.max_action_num self.max_seq_len = cfg.max_seq_len - self.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path) + + if JerichoEnv.tokenizer is None: + JerichoEnv.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path) self._env = FrotzEnv(self.game_path) self._action_list = None @@ -36,7 +39,7 @@ def prepare_obs(self, obs, return_str: bool = False): self._action_list = self._env.get_valid_actions() full_obs = obs + "\nValid actions: " + str(self._action_list) if not return_str: - full_obs = self.tokenizer( + full_obs = JerichoEnv.tokenizer( [full_obs], truncation=True, padding="max_length", max_length=self.max_seq_len) full_obs = np.array(full_obs['input_ids'][0], dtype=np.int32) action_mask = [1] * len(self._action_list) + [0] * \ From aa3a324f5632d8608344dd1e652c2bed1eed053d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sun, 1 Dec 2024 12:28:32 +0800 Subject: [PATCH 4/7] debug and polish env --- lzero/model/common.py | 1 + zoo/jericho/configs/jericho_unizero_config.py | 6 +++++- zoo/jericho/envs/jericho_env.py | 12 +++++++++++- 3 files changed, 17 insertions(+), 2 deletions(-) diff --git a/lzero/model/common.py b/lzero/model/common.py index 035158b5b..774ab16f8 100644 --- a/lzero/model/common.py +++ b/lzero/model/common.py @@ -281,6 +281,7 @@ def __init__(self, url: str = 'google-bert/bert-base-uncased'): self.model = AutoModel.from_pretrained(url) def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.long() outputs = self.model(x) # [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size] return outputs.last_hidden_state[:, 0, :] diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py index 2e92e2a43..2ef18d023 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -1,3 +1,4 @@ +import os from easydict import EasyDict @@ -37,9 +38,11 @@ def main(env_id='zork1.z5', seed=0): learn=dict(learner=dict( hook=dict(save_ckpt_after_iter=1000000, ), ), ), model=dict( - observation_shape=(512,), + observation_shape=512, action_space_size=action_space_size, encoder_url='google-bert/bert-base-uncased', + # The input of the model is text, whose shape is identical to the mlp model. + model_type='mlp', world_model_cfg=dict( policy_entropy_weight=5e-3, continuous_action_space=False, @@ -102,4 +105,5 @@ def main(env_id='zork1.z5', seed=0): parser.add_argument('--seed', type=int, help='The seed to use', default=0) args = parser.parse_args() + os.environ['TOKENIZERS_PARALLELISM'] = 'false' main(args.env, args.seed) diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py index ab9423c19..b56ef7493 100644 --- a/zoo/jericho/envs/jericho_env.py +++ b/zoo/jericho/envs/jericho_env.py @@ -15,6 +15,7 @@ class JerichoEnv(BaseEnv): def __init__(self, cfg): self.cfg = cfg + self.max_env_num = cfg.max_env_num self.game_path = cfg.game_path self.max_action_num = cfg.max_action_num self.max_seq_len = cfg.max_seq_len @@ -24,10 +25,10 @@ def __init__(self, cfg): self._env = FrotzEnv(self.game_path) self._action_list = None - self.finished = False self._init_flag = False self.episode_return = 0 + self.env_step = 0 self.observation_space = gym.spaces.Dict() self.action_space = gym.spaces.Discrete(self.max_action_num) @@ -49,7 +50,12 @@ def prepare_obs(self, obs, return_str: bool = False): def reset(self, return_str: bool = False): initial_observation, info = self._env.reset() + self.finished = False + self._init_flag = True + self._action_list = None self.episode_return = 0 + self.env_step = 0 + return self.prepare_obs(initial_observation, return_str) def seed(self, seed: int, dynamic_seed: bool = True) -> None: @@ -68,10 +74,14 @@ def __repr__(self) -> str: def step(self, action: int, return_str: bool = False): action_str = self._action_list[action] observation, reward, done, info = self._env.step(action_str) + self.env_step += 1 self.episode_return += reward self._action_list = None observation = self.prepare_obs(observation, return_str) + if self.env_step >= self.max_env_step: + done = True + if done: self.finished = True info['eval_episode_return'] = self.episode_return From 5c984b07809b2a7104f077594e9951ad1f79468c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sun, 1 Dec 2024 12:29:17 +0800 Subject: [PATCH 5/7] polish config --- zoo/jericho/configs/jericho_unizero_config.py | 1 + zoo/jericho/envs/jericho_env.py | 1 + 2 files changed, 2 insertions(+) diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py index 2ef18d023..96ee7f2da 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -24,6 +24,7 @@ def main(env_id='zork1.z5', seed=0): jericho_unizero_config = dict( env=dict( stop_value=int(1e6), + max_env_step=100, max_action_num=action_space_size, tokenizer_path="google-bert/bert-base-uncased", max_seq_len=512, diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py index b56ef7493..2034ea64b 100644 --- a/zoo/jericho/envs/jericho_env.py +++ b/zoo/jericho/envs/jericho_env.py @@ -113,6 +113,7 @@ def create_evaluator_env_cfg(cfg: dict) -> List[dict]: dict( game_path="z-machine-games-master/jericho-game-suite/zork1.z5", max_action_num=50, + max_env_step=100, tokenizer_path="google-bert/bert-base-uncased", max_seq_len=512 ) From 4994c292ba45e58b9af3ff4240e77569f266b3c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Sun, 1 Dec 2024 12:36:25 +0800 Subject: [PATCH 6/7] update --- zoo/jericho/configs/jericho_unizero_config.py | 2 +- zoo/jericho/envs/jericho_env.py | 9 +++++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py index 96ee7f2da..923c32d25 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -24,7 +24,7 @@ def main(env_id='zork1.z5', seed=0): jericho_unizero_config = dict( env=dict( stop_value=int(1e6), - max_env_step=100, + max_steps=100, max_action_num=action_space_size, tokenizer_path="google-bert/bert-base-uncased", max_seq_len=512, diff --git a/zoo/jericho/envs/jericho_env.py b/zoo/jericho/envs/jericho_env.py index 2034ea64b..73400b59b 100644 --- a/zoo/jericho/envs/jericho_env.py +++ b/zoo/jericho/envs/jericho_env.py @@ -11,11 +11,16 @@ @ENV_REGISTRY.register('jericho') class JerichoEnv(BaseEnv): + """ + Overview: + The environment for Jericho games. For more details about the game, please refer to the \ + `Jericho `. + """ tokenizer = None def __init__(self, cfg): self.cfg = cfg - self.max_env_num = cfg.max_env_num + self.max_steps = cfg.max_steps self.game_path = cfg.game_path self.max_action_num = cfg.max_action_num self.max_seq_len = cfg.max_seq_len @@ -79,7 +84,7 @@ def step(self, action: int, return_str: bool = False): self._action_list = None observation = self.prepare_obs(observation, return_str) - if self.env_step >= self.max_env_step: + if self.env_step >= self.max_steps: done = True if done: From 49f88c4ee1771fb7bcee57735849a88bf46af3af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98whl=E2=80=99?= <‘18231213@buaa.edu.cn’> Date: Tue, 3 Dec 2024 17:24:59 +0800 Subject: [PATCH 7/7] debug --- zoo/jericho/configs/jericho_unizero_config.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/zoo/jericho/configs/jericho_unizero_config.py b/zoo/jericho/configs/jericho_unizero_config.py index 923c32d25..ef448336b 100644 --- a/zoo/jericho/configs/jericho_unizero_config.py +++ b/zoo/jericho/configs/jericho_unizero_config.py @@ -14,6 +14,7 @@ def main(env_id='zork1.z5', seed=0): num_simulations = 50 max_env_step = int(5e5) batch_size = 64 + reanalyze_ratio = 0 num_unroll_steps = 10 infer_context_length = 4 num_layers = 2 @@ -69,6 +70,7 @@ def main(env_id='zork1.z5', seed=0): train_start_after_envsteps=2000, game_segment_length=game_segment_length, replay_buffer_size=int(1e6), + reanalyze_ratio=reanalyze_ratio, eval_freq=int(5e3), collector_env_num=collector_env_num, evaluator_env_num=evaluator_env_num, @@ -102,7 +104,7 @@ def main(env_id='zork1.z5', seed=0): import argparse parser = argparse.ArgumentParser(description='Process some environment.') parser.add_argument('--env', type=str, - help='The environment to use', default='zork1.z5') + help='The environment to use', default='detective.z5') parser.add_argument('--seed', type=int, help='The seed to use', default=0) args = parser.parse_args()