Skip to content

Commit

Permalink
fix(pu): fix root_sampled_actions_tmp shape bug in sez ptree
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Oct 17, 2023
1 parent 9e82968 commit e32e495
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 37 deletions.
38 changes: 18 additions & 20 deletions lzero/mcts/buffer/game_buffer_sampled_efficientzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree
from lzero.mcts.tree_search.mcts_ptree_sampled import SampledEfficientZeroMCTSPtree as MCTSPtree
from lzero.mcts.utils import prepare_observation
from lzero.mcts.utils import prepare_observation, generate_random_actions_discrete
from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
from .game_buffer_efficientzero import EfficientZeroGameBuffer

Expand Down Expand Up @@ -161,27 +161,25 @@ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp))
]
else:
actions_tmp += [
np.random.randint(0, self._cfg.model.action_space_size, 1).item()
for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
]
if len(root_sampled_actions_tmp[0].shape) == 1:
root_sampled_actions_tmp += [
np.arange(self._cfg.model.action_space_size)
# NOTE: self._cfg.num_unroll_steps + 1
for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp))
]
else:
root_sampled_actions_tmp += [
np.random.randint(0, self._cfg.model.action_space_size,
self._cfg.model.num_of_sampled_actions).reshape(
self._cfg.model.num_of_sampled_actions, 1
) # NOTE: self._cfg.num_unroll_steps + 1
for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp))
]
# generate random padded `actions_tmp`
actions_tmp += generate_random_actions_discrete(
self._cfg.num_unroll_steps - len(actions_tmp),
self._cfg.model.action_space_size,
1 # Number of sampled actions for actions_tmp is 1
)

# generate random padded `root_sampled_actions_tmp`
# `root_sampled_actions` has a different shape under mcts_ctree than under mcts_ptree, so the padding must be shaped accordingly
reshape = True if self._cfg.mcts_ctree else False
root_sampled_actions_tmp += generate_random_actions_discrete(
self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp),
self._cfg.model.action_space_size,
self._cfg.model.num_of_sampled_actions,
reshape=reshape
)

# obtain the input observations
# stack+num_unroll_steps 4+5
# stack+num_unroll_steps = 4+5
# pad if length of obs in game_segment is less than stack+num_unroll_steps
obs_list.append(
game_lst[i].get_unroll_obs(
Expand Down
29 changes: 29 additions & 0 deletions lzero/mcts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,35 @@
from graphviz import Digraph


def generate_random_actions_discrete(num_actions: int, action_space_size: int, num_of_sampled_actions: int,
                                     reshape=False):
    """
    Overview:
        Build a list of uniformly random discrete actions, one entry per requested action.
        Each entry is drawn with ``np.random.randint(0, action_space_size, num_of_sampled_actions)``.
    Arguments:
        - num_actions (:obj:`int`): How many list entries to produce.
        - action_space_size (:obj:`int`): Exclusive upper bound of the sampled action indices.
        - num_of_sampled_actions (:obj:`int`): How many action indices are drawn for each entry.
        - reshape (:obj:`bool`): If True and ``num_of_sampled_actions > 1``, each entry is returned \
            with shape ``(num_of_sampled_actions, 1)`` instead of ``(num_of_sampled_actions,)``.
    Returns:
        - actions (:obj:`list`): ``num_actions`` entries. When ``num_of_sampled_actions == 1`` every \
            entry is the single drawn number (``reshape`` is ignored in that case); otherwise every \
            entry is a numpy array shaped as described above.
    """
    # One independent draw per requested action, always flattened to 1-D first.
    draws = []
    for _ in range(num_actions):
        sample = np.random.randint(0, action_space_size, num_of_sampled_actions)
        draws.append(sample.reshape(-1))

    # A single sampled action per entry collapses to a plain list of numbers.
    if num_of_sampled_actions == 1:
        return [sample[0] for sample in draws]

    # Column-vector layout is only produced on request and only for multi-action entries.
    if reshape and num_of_sampled_actions > 1:
        return [sample.reshape(num_of_sampled_actions, 1) for sample in draws]

    return draws


@dataclass
class BufferedData:
data: Any
Expand Down
23 changes: 6 additions & 17 deletions zoo/atari/config/atari_sampled_efficientzero_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,14 @@
# ==============================================================
# begin of the most frequently changed config specified by the user
# ==============================================================
# continuous_action_space = False
# K = 5 # num_of_sampled_actions
# collector_env_num = 8
# n_episode = 8
# evaluator_env_num = 3
# num_simulations = 50
# update_per_collect = 1000
# batch_size = 256
# max_env_step = int(1e6)
# reanalyze_ratio = 0.
continuous_action_space = False
K = 5 # num_of_sampled_actions
collector_env_num = 1
n_episode = 1
evaluator_env_num = 1
num_simulations = 5
update_per_collect = 1
batch_size = 2
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 50
update_per_collect = 1000
batch_size = 256
max_env_step = int(1e6)
reanalyze_ratio = 0.
# ==============================================================
Expand All @@ -53,7 +43,6 @@
manager=dict(shared_memory=False, ),
),
policy=dict(
mcts_ctree=False,
model=dict(
observation_shape=(4, 96, 96),
frame_stack_num=4,
Expand Down

0 comments on commit e32e495

Please sign in to comment.