feature(whl): Adapt unizero to jericho environments. #301

Open · wants to merge 7 commits into base: main
13 changes: 13 additions & 0 deletions lzero/model/common.py
@@ -272,6 +272,19 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
f"You should transform the observation shape to 64 or 96 in the env.")

return output


class HFLanguageRepresentationNetwork(nn.Module):
    """
    Overview:
        Language representation network backed by a pretrained Hugging Face model.
        The final-layer embedding of the [CLS] token is used as the latent state.
    """

    def __init__(self, url: str = 'google-bert/bert-base-uncased'):
        super().__init__()
        from transformers import AutoModel
        self.model = AutoModel.from_pretrained(url)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Token ids must be int64 for the embedding layer.
        x = x.long()
        outputs = self.model(x)
        # [batch_size, seq_len, hidden_size] -> [batch_size, hidden_size]
        return outputs.last_hidden_state[:, 0, :]


class RepresentationNetworkUniZero(nn.Module):
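A quick usage sketch of the new encoder, for review purposes (a minimal example, assuming `transformers` is installed and using the BERT checkpoint named in the config below; the sample sentence is only illustrative):

```python
from transformers import AutoTokenizer
from lzero.model.common import HFLanguageRepresentationNetwork

tokenizer = AutoTokenizer.from_pretrained('google-bert/bert-base-uncased')
encoder = HFLanguageRepresentationNetwork(url='google-bert/bert-base-uncased')

# Tokenize a batch of observation strings into fixed-length input ids.
batch = tokenizer(
    ['You are standing in an open field west of a white house.'],
    truncation=True, padding='max_length', max_length=512, return_tensors='pt',
)
latent = encoder(batch['input_ids'])
print(latent.shape)  # torch.Size([1, 768]) -- one [CLS] embedding per observation
```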
12 changes: 11 additions & 1 deletion lzero/model/unizero_model.py
@@ -6,7 +6,8 @@
from easydict import EasyDict

from .common import MZNetworkOutput, RepresentationNetworkUniZero, RepresentationNetworkMLP, LatentDecoder, \
VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook
VectorDecoderForMemoryEnv, LatentEncoderForMemoryEnv, LatentDecoderForMemoryEnv, FeatureAndGradientHook, \
HFLanguageRepresentationNetwork
from .unizero_world_models.tokenizer import Tokenizer
from .unizero_world_models.world_model import WorldModel

@@ -87,6 +88,15 @@ def __init__(
print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer')
print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
print('==' * 20)
        elif world_model_cfg.obs_type == 'text':
            # Use a pretrained Hugging Face language model as the text observation encoder;
            # no decoder network or LPIPS loss is needed for text observations.
            self.representation_network = HFLanguageRepresentationNetwork(url=kwargs['encoder_url'])
            self.tokenizer = Tokenizer(encoder=self.representation_network, decoder_network=None, with_lpips=False)
self.world_model = WorldModel(config=world_model_cfg, tokenizer=self.tokenizer)
print(f'{sum(p.numel() for p in self.world_model.parameters())} parameters in agent.world_model')
print('==' * 20)
print(f'{sum(p.numel() for p in self.world_model.transformer.parameters())} parameters in agent.world_model.transformer')
print(f'{sum(p.numel() for p in self.tokenizer.encoder.parameters())} parameters in agent.tokenizer.encoder')
print('==' * 20)
elif world_model_cfg.obs_type == 'image':
self.representation_network = RepresentationNetworkUniZero(
observation_shape,
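For reference, a minimal sketch of the model config keys the `'text'` branch above consumes (`encoder_url` is read from `**kwargs`); the full version appears in `zoo/jericho/configs/jericho_unizero_config.py` below:

```python
# Trimmed to the keys relevant to the text branch; values mirror the Jericho config.
model = dict(
    observation_shape=512,  # token ids per observation (matches max_seq_len)
    model_type='mlp',       # text latents reuse the MLP model path
    encoder_url='google-bert/bert-base-uncased',  # HF checkpoint for the encoder
    world_model_cfg=dict(
        obs_type='text',  # routes model construction to the branch above
        embed_dim=768,    # BERT-base hidden size, i.e. the [CLS] embedding width
    ),
)
```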
2 changes: 1 addition & 1 deletion lzero/model/unizero_world_models/world_model.py
@@ -1156,7 +1156,7 @@ def compute_loss(self, batch, target_tokenizer: Tokenizer = None, inverse_scalar
latent_recon_loss = self.latent_recon_loss
perceptual_loss = self.perceptual_loss

elif self.obs_type == 'vector':
elif self.obs_type in ('vector', 'text'):
perceptual_loss = torch.tensor(0., device=batch['observations'].device,
dtype=batch['observations'].dtype)

Empty file added zoo/jericho/__init__.py
112 changes: 112 additions & 0 deletions zoo/jericho/configs/jericho_unizero_config.py
@@ -0,0 +1,112 @@
import os
from easydict import EasyDict


def main(env_id='zork1.z5', seed=0):
action_space_size = 50

# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
collector_env_num = 8
game_segment_length = 20
evaluator_env_num = 5
num_simulations = 50
max_env_step = int(5e5)
batch_size = 64
reanalyze_ratio = 0
num_unroll_steps = 10
infer_context_length = 4
num_layers = 2
replay_ratio = 0.25
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
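    # NOTE (sketch of how the values above combine): each timestep contributes an obs
    # token and an action token, so max_tokens = 2 * num_unroll_steps = 20, and the
    # inference context covers 2 * infer_context_length = 8 tokens (see world_model_cfg).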
jericho_unizero_config = dict(
env=dict(
stop_value=int(1e6),
max_steps=100,
max_action_num=action_space_size,
tokenizer_path="google-bert/bert-base-uncased",
max_seq_len=512,
game_path="z-machine-games-master/jericho-game-suite/" + env_id,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
manager=dict(shared_memory=False, )
),
policy=dict(
            # Save checkpoints infrequently (the default save_ckpt_after_iter is 10000).
learn=dict(learner=dict(
hook=dict(save_ckpt_after_iter=1000000, ), ), ),
model=dict(
observation_shape=512,
action_space_size=action_space_size,
encoder_url='google-bert/bert-base-uncased',
                # The model input is tokenized text; its observation shape matches that of
                # the MLP model, so model_type='mlp' is reused.
model_type='mlp',
world_model_cfg=dict(
policy_entropy_weight=5e-3,
continuous_action_space=False,
max_blocks=num_unroll_steps,
# NOTE: each timestep has 2 tokens: obs and action
max_tokens=2 * num_unroll_steps,
context_length=2 * infer_context_length,
device='cuda',
action_space_size=action_space_size,
num_layers=num_layers,
num_heads=8,
embed_dim=768,
obs_type='text', # TODO: Change it.
env_num=max(collector_env_num, evaluator_env_num),
),
),
model_path=None,
num_unroll_steps=num_unroll_steps,
replay_ratio=replay_ratio,
batch_size=batch_size,
learning_rate=0.0001,
num_simulations=num_simulations,
train_start_after_envsteps=2000,
game_segment_length=game_segment_length,
replay_buffer_size=int(1e6),
reanalyze_ratio=reanalyze_ratio,
eval_freq=int(5e3),
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
),
)
jericho_unizero_config = EasyDict(jericho_unizero_config)

jericho_unizero_create_config = dict(
env=dict(
type='jericho',
import_names=['zoo.jericho.envs.jericho_env'],
),
        # NOTE: use the base env manager to avoid a bug in the subprocess env manager.
env_manager=dict(type='base'),
policy=dict(
type='unizero',
import_names=['lzero.policy.unizero'],
),
)
jericho_unizero_create_config = EasyDict(jericho_unizero_create_config)
main_config = jericho_unizero_config
create_config = jericho_unizero_create_config

    # Strip the file extension (e.g. '.z5') from the env id for the experiment name.
    env_name = env_id.split('.')[0]
    main_config.exp_name = f'data_unizero/{env_name}/{env_name}_uz_nlayer{num_layers}_gsl{game_segment_length}_rr{replay_ratio}_Htrain{num_unroll_steps}-Hinfer{infer_context_length}_bs{batch_size}_seed{seed}'
from lzero.entry import train_unizero
train_unizero([main_config, create_config], seed=seed,
model_path=main_config.policy.model_path, max_env_step=max_env_step)


if __name__ == "__main__":
import argparse
    parser = argparse.ArgumentParser(description='Run UniZero on a Jericho environment.')
parser.add_argument('--env', type=str,
help='The environment to use', default='detective.z5')
parser.add_argument('--seed', type=int, help='The seed to use', default=0)
args = parser.parse_args()

os.environ['TOKENIZERS_PARALLELISM'] = 'false'
main(args.env, args.seed)
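Since the config doubles as an entry script, training can be launched directly, e.g. `python zoo/jericho/configs/jericho_unizero_config.py --env detective.z5 --seed 0` (this assumes the Jericho game suite is available under `z-machine-games-master/jericho-game-suite/`).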
1 change: 1 addition & 0 deletions zoo/jericho/envs/__init__.py
@@ -0,0 +1 @@
from .jericho_env import JerichoEnv
134 changes: 134 additions & 0 deletions zoo/jericho/envs/jericho_env.py
@@ -0,0 +1,134 @@
import copy
from typing import List

import gym
import numpy as np
from transformers import AutoTokenizer
from ding.utils import ENV_REGISTRY
from ding.envs import BaseEnv, BaseEnvTimestep
from jericho import FrotzEnv


@ENV_REGISTRY.register('jericho')
class JerichoEnv(BaseEnv):
"""
Overview:
        The environment for Jericho text-adventure games. For more details about Jericho, \
        please refer to `microsoft/jericho <https://github.com/microsoft/jericho>`.
"""
tokenizer = None

def __init__(self, cfg):
self.cfg = cfg
self.max_steps = cfg.max_steps
self.game_path = cfg.game_path
self.max_action_num = cfg.max_action_num
self.max_seq_len = cfg.max_seq_len

if JerichoEnv.tokenizer is None:
JerichoEnv.tokenizer = AutoTokenizer.from_pretrained(cfg.tokenizer_path)

self._env = FrotzEnv(self.game_path)
self._action_list = None
self.finished = False
self._init_flag = False
self.episode_return = 0
self.env_step = 0

self.observation_space = gym.spaces.Dict()
self.action_space = gym.spaces.Discrete(self.max_action_num)
self.reward_space = gym.spaces.Box(
low=-np.inf, high=np.inf, shape=(1, ), dtype=np.float32)

def prepare_obs(self, obs, return_str: bool = False):
        if self._action_list is None:
            self._action_list = self._env.get_valid_actions()
        # Clip the valid-action list so the action mask below never exceeds max_action_num.
        self._action_list = self._action_list[:self.max_action_num]
        full_obs = obs + "\nValid actions: " + str(self._action_list)
if not return_str:
full_obs = JerichoEnv.tokenizer(
[full_obs], truncation=True, padding="max_length", max_length=self.max_seq_len)
full_obs = np.array(full_obs['input_ids'][0], dtype=np.int32)
action_mask = [1] * len(self._action_list) + [0] * \
(self.max_action_num - len(self._action_list))
action_mask = np.array(action_mask, dtype=np.int8)
return {'observation': full_obs, 'action_mask': action_mask, 'to_play': -1}

def reset(self, return_str: bool = False):
initial_observation, info = self._env.reset()
self.finished = False
self._init_flag = True
self._action_list = None
self.episode_return = 0
self.env_step = 0

return self.prepare_obs(initial_observation, return_str)

def seed(self, seed: int, dynamic_seed: bool = True) -> None:
"""
Set the seed for the environment.
"""
self._seed = seed
self._env.seed(seed)

def close(self) -> None:
self._init_flag = False

def __repr__(self) -> str:
return "LightZero Jericho Env"

def step(self, action: int, return_str: bool = False):
action_str = self._action_list[action]
observation, reward, done, info = self._env.step(action_str)
self.env_step += 1
self.episode_return += reward
self._action_list = None
observation = self.prepare_obs(observation, return_str)

if self.env_step >= self.max_steps:
done = True

if done:
self.finished = True
info['eval_episode_return'] = self.episode_return

        # Wrap the scalar reward as a (1,)-shaped float32 array to match self.reward_space.
        return BaseEnvTimestep(observation, np.array([float(reward)], dtype=np.float32), done, info)

@staticmethod
def create_collector_env_cfg(cfg: dict) -> List[dict]:
collector_env_num = cfg.pop('collector_env_num')
cfg = copy.deepcopy(cfg)
        # In the collect phase, reward normalization (if any) is determined by the config.
cfg.is_collect = True
return [cfg for _ in range(collector_env_num)]

@staticmethod
def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
evaluator_env_num = cfg.pop('evaluator_env_num')
cfg = copy.deepcopy(cfg)
        # In the evaluate phase, the reward is not normalized.
cfg.reward_normalize = False
cfg.is_collect = False
return [cfg for _ in range(evaluator_env_num)]


if __name__ == '__main__':
from easydict import EasyDict
env_cfg = EasyDict(
dict(
            game_path="z-machine-games-master/jericho-game-suite/zork1.z5",
            max_action_num=50,
            max_steps=100,
            tokenizer_path="google-bert/bert-base-uncased",
            max_seq_len=512
)
)
env = JerichoEnv(env_cfg)
obs = env.reset(return_str=True)
print(f'[OBS]:\n{obs["observation"]}')
while True:
action_id = int(input('Please input the action id:'))
obs, reward, done, info = env.step(action_id, return_str=True)
print(f'[OBS]:\n{obs["observation"]}')
if done:
break
41 changes: 41 additions & 0 deletions zoo/jericho/envs/test_jericho_env.py
@@ -0,0 +1,41 @@
from easydict import EasyDict
from .jericho_env import JerichoEnv
import numpy as np
import pytest


@pytest.mark.unittest
class TestJerichoEnv():
def setup(self) -> None:
# Configuration for the Jericho environment
cfg = EasyDict(
dict(
                game_path="z-machine-games-master/jericho-game-suite/zork1.z5",
                max_action_num=50,
                max_steps=100,
                tokenizer_path="google-bert/bert-base-uncased",
                max_seq_len=512
)
)
# Create a Jericho environment that will be used in the following tests.
self.env = JerichoEnv(cfg)

# Test the initialization of the Jericho environment.
def test_initialization(self):
assert isinstance(self.env, JerichoEnv)

# Test the reset method of the Jericho environment.
# Ensure that the shape of the observation is as expected.
def test_reset(self):
obs = self.env.reset()
assert obs['observation'].shape == (512,)

# Test the step method of the Jericho environment.
# Ensure that the shape of the observation, the type of the reward,
# the type of the done flag and the type of the info are as expected.
def test_step_shape(self):
self.env.reset()
obs, reward, done, info = self.env.step(1)
assert obs['observation'].shape == (512,)
assert isinstance(reward, np.ndarray)
assert isinstance(done, bool)
assert isinstance(info, dict)
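Note: these tests assume the `zork1.z5` game file exists at the configured `game_path` and that the BERT tokenizer can be downloaded; they can be run with, e.g., `pytest zoo/jericho/envs/test_jericho_env.py`.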