From 29197d2fd88956be6cdc954471e0373ca220c34b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E8=92=B2=E6=BA=90?= <2402552459@qq.com>
Date: Thu, 26 Sep 2024 17:36:40 +0800
Subject: [PATCH] fix(pu): fix entry import and nparray object bug in buffer

---
 lzero/entry/__init__.py                 |  4 ++--
 lzero/mcts/buffer/game_buffer_muzero.py | 13 ++++---------
 2 files changed, 6 insertions(+), 11 deletions(-)

diff --git a/lzero/entry/__init__.py b/lzero/entry/__init__.py
index 5e86ca902..cb7a32e7f 100644
--- a/lzero/entry/__init__.py
+++ b/lzero/entry/__init__.py
@@ -3,10 +3,10 @@
 from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
 from .train_alphazero import train_alphazero
 from .train_muzero import train_muzero
-from .train_muzero_rer import train_muzero_rer
+from .train_muzero_reanalyze import train_muzero_reanalyze
 from .train_muzero_with_gym_env import train_muzero_with_gym_env
 from .train_muzero_with_gym_env import train_muzero_with_gym_env
 from .train_muzero_with_reward_model import train_muzero_with_reward_model
 from .train_rezero import train_rezero
-from .train_rezero_uz import train_rezero_uz
 from .train_unizero import train_unizero
+from .train_unizero_reanalyze import train_unizero_reanalyze
\ No newline at end of file
diff --git a/lzero/mcts/buffer/game_buffer_muzero.py b/lzero/mcts/buffer/game_buffer_muzero.py
index 2f0acfdca..7a07e1df9 100644
--- a/lzero/mcts/buffer/game_buffer_muzero.py
+++ b/lzero/mcts/buffer/game_buffer_muzero.py
@@ -550,20 +550,15 @@ def _compute_target_reward_value(self, reward_value_context: List[Any], model: A
                     target_values.append(value_list[value_index])
                     target_rewards.append(reward_list[current_index])
                 else:
-                    target_values.append(np.array([0.]))
-                    target_rewards.append(np.array([0.]))
+                    target_values.append(np.array(0.))
+                    target_rewards.append(np.array(0.))
                 value_index += 1
 
             batch_rewards.append(target_rewards)
             batch_target_values.append(target_values)
 
-        # batch_rewards = np.asarray(batch_rewards)
-        # batch_target_values = np.asarray(batch_target_values)
-        # batch_rewards = np.squeeze(batch_rewards, axis=-1)
-        # batch_target_values = np.squeeze(batch_target_values, axis=-1)
-
-        batch_rewards = np.asarray(batch_rewards, dtype=object)
-        batch_target_values = np.asarray(batch_target_values, dtype=object)
+        batch_rewards = np.asarray(batch_rewards)
+        batch_target_values = np.asarray(batch_target_values)
 
         return batch_rewards, batch_target_values
 
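Note: the buffer change above hinges on how NumPy packs nested lists. Mixing scalar targets with shape-(1,) padding arrays makes each per-trajectory list ragged, which is why the old code had to force dtype=object; padding with 0-d arrays keeps every entry scalar-shaped so a plain np.asarray suffices. A minimal sketch, outside the patch itself, using made-up placeholder values rather than real LZero targets:

    import numpy as np

    # Before the fix: out-of-range steps were padded with shape-(1,) arrays while
    # in-range steps stayed scalar, so the per-trajectory lists are ragged. A plain
    # np.asarray on such input warns and falls back to object dtype on older NumPy,
    # and raises ValueError on NumPy >= 1.24, hence the forced dtype=object.
    ragged = [[0.5, 0.3, np.array([0.])],
              [0.7, np.array([0.]), np.array([0.])]]
    obj_batch = np.asarray(ragged, dtype=object)
    print(obj_batch.dtype)  # object -> awkward for later numeric ops / tensor conversion

    # After the fix: padding with 0-d arrays keeps every entry scalar-shaped,
    # so np.asarray packs the batch into an ordinary float array.
    regular = [[0.5, 0.3, np.array(0.)],
               [0.7, np.array(0.), np.array(0.)]]
    float_batch = np.asarray(regular)
    print(float_batch.dtype, float_batch.shape)  # float64 (2, 3)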