Skip to content

Commit

Permalink
fix(pu): fix entry import and nparray object bug in buffer
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Sep 26, 2024
1 parent d5fff6d commit 29197d2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 11 deletions.
4 changes: 2 additions & 2 deletions lzero/entry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
from .train_alphazero import train_alphazero
from .train_muzero import train_muzero
from .train_muzero_rer import train_muzero_rer
from .train_muzero_reanalyze import train_muzero_reanalyze
from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_muzero_with_gym_env import train_muzero_with_gym_env
from .train_muzero_with_reward_model import train_muzero_with_reward_model
from .train_rezero import train_rezero
from .train_rezero_uz import train_rezero_uz
from .train_unizero import train_unizero
from .train_unizero_reanalyze import train_unizero_reanalyze
13 changes: 4 additions & 9 deletions lzero/mcts/buffer/game_buffer_muzero.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,20 +550,15 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
target_values.append(value_list[value_index])
target_rewards.append(reward_list[current_index])
else:
target_values.append(np.array([0.]))
target_rewards.append(np.array([0.]))
target_values.append(np.array(0.))
target_rewards.append(np.array(0.))
value_index += 1

batch_rewards.append(target_rewards)
batch_target_values.append(target_values)

# batch_rewards = np.asarray(batch_rewards)
# batch_target_values = np.asarray(batch_target_values)
# batch_rewards = np.squeeze(batch_rewards, axis=-1)
# batch_target_values = np.squeeze(batch_target_values, axis=-1)

batch_rewards = np.asarray(batch_rewards, dtype=object)
batch_target_values = np.asarray(batch_target_values, dtype=object)
batch_rewards = np.asarray(batch_rewards)
batch_target_values = np.asarray(batch_target_values)

return batch_rewards, batch_target_values

Expand Down

0 comments on commit 29197d2

Please sign in to comment.