Skip to content

Commit

Permalink
fix(pu): fix alphazero ctree unittest
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Dec 10, 2024
1 parent f5faac8 commit aa122d0
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 321 deletions.
1 change: 0 additions & 1 deletion lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -345,5 +345,4 @@ PYBIND11_MODULE(mcts_alphazero, m) {
py::arg("policy_value_func"),
py::arg("temperature"),
py::arg("sample"));

}
129 changes: 67 additions & 62 deletions lzero/mcts/ctree/ctree_alphazero/test/eval_alphazero_ctree_zh.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,7 @@
# eval_node_alphazero.py

"""
这是一个针对 C++ 实现的 MCTS(蒙特卡洛树搜索)和 Node 类的单元测试脚本。
该脚本使用 Python 的 unittest 框架,并通过 pybind11 绑定与 C++ 代码进行交互。
为测试方便,我们使用 unittest.mock 来模拟环境(simulate_env)和策略-值函数(policy_value_func)。
"""
# ./lzero/mcts/ctree/ctree_alphazero/test/eval_alphazero_ctree_zh.py

import sys
import unittest
from unittest.mock import MagicMock, patch
import numpy as np
from easydict import EasyDict

Expand All @@ -18,6 +11,43 @@
import mcts_alphazero


class MockEnv:
"""
一个简单的模拟环境类,包含必要的属性和方法。
"""

def __init__(self):
self.legal_actions = [0, 1, 2]
self.battle_mode_in_simulation_env = "self_play_mode"
self.current_player = 1
self.action_space = type('action_space', (), {'n': 3})()

def reset(self, start_player_index, init_state, katago_policy_init, katago_game_state):
"""
模拟环境的 reset 方法。
"""
pass

def step(self, action):
"""
模拟环境的 step 方法。
"""
pass

def get_done_winner(self):
"""
模拟环境的 get_done_winner 方法,返回 (done, winner)。
"""
return (False, -1)


def mock_policy_value_func(env):
"""
一个真实的 policy_value_func 函数,返回动作概率字典和叶节点值。
"""
return ({0: 0.4, 1: 0.4, 2: 0.2}, 0.9)


class TestNodeAlphaZero(unittest.TestCase):
"""
测试 Node 类的功能,包括初始化、更新、递归更新、判断叶子节点和根节点等。
Expand Down Expand Up @@ -89,7 +119,7 @@ def test_node_recursive_update_play_with_bot_mode(self):

# 检查父节点的更新
self.assertEqual(parent.visit_count, 1, "父节点的 visit_count 应为 1")
self.assertAlmostEqual(parent.value, 1.0, "父节点的 value 应为 1.0")
self.assertEqual(parent.value, 1.0, "父节点的 value 应为 1.0")

def test_node_add_child(self):
"""
Expand All @@ -103,6 +133,7 @@ def test_node_add_child(self):
self.assertIs(parent.children[3], child, "添加的子节点应与传入的 child 相同")
self.assertFalse(parent.is_leaf(), "添加子节点后,父节点不应为叶子节点")


class TestMCTSAlphaZero(unittest.TestCase):
"""
测试 MCTS 类的功能,包括初始化、UCB 评分计算、选择子节点、添加探索噪声、
Expand All @@ -127,38 +158,14 @@ def setUp(self):
# 创建一个根节点
self.root = mcts_alphazero.Node()

# 模拟环境
self.mock_env = MagicMock()
# 创建模拟环境
self.mock_env = MockEnv()

# 定义合法动作
self.legal_actions = [0, 1, 2]

# 定义 side_effect 函数,根据属性名称返回不同的值
def attr_side_effect(name):
if name == "legal_actions":
mock_list = MagicMock()
mock_list.cast.return_value = self.legal_actions
return mock_list
elif name == "battle_mode_in_simulation_env":
return "self_play_mode"
elif name == "current_player":
return 1
elif name == "action_space":
mock_action_space = MagicMock()
mock_action_space.attr.return_value = 3 # 假设 action_space.n = 3
return mock_action_space
else:
return MagicMock()

self.mock_env.attr.side_effect = attr_side_effect

# 模拟 policy_value_func
self.policy_value_func = MagicMock()
# 假设有三个合法动作,分别对应不同的 prior_p
self.policy_value_func.return_value = (
{0: 0.4, 1: 0.4, 2: 0.2}, # action_probs_dict
0.9 # leaf_value
)
# 定义 policy_value_func
self.policy_value_func = mock_policy_value_func

def test_ucb_score(self):
"""
Expand All @@ -184,7 +191,7 @@ def test_ucb_score(self):
expected_pb_c *= np.sqrt(parent.visit_count) / (child.visit_count + 1)
expected_score = expected_pb_c * child.prior_p + child.value # 使用 'value' 属性

self.assertAlmostEqual(ucb, expected_score, places=5, msg="UCB 分数计算不正确")
self.assertEqual(ucb, expected_score, msg="UCB 分数计算不正确")

def test_add_exploration_noise(self):
"""
Expand All @@ -207,14 +214,16 @@ def test_get_next_action(self):
"""
测试 MCTS 的 get_next_action 方法是否正确返回动作和概率分布。
"""
# 配置模拟环境的行为
# self.mcts.set_simulate_env(self.mock_env)
# 配置 MCTS 对象的 simulate_env
self.mcts.simulate_env = self.mock_env

state_config_for_simulation_env_reset = EasyDict(dict(start_player_index=0,
init_state=None,
katago_policy_init=False,
katago_game_state=None))
state_config_for_simulation_env_reset = EasyDict({
'start_player_index': 0,
'init_state': None,
'katago_policy_init': False,
'katago_game_state': None
})

# 执行 get_next_action
action, action_probs = self.mcts.get_next_action(
state_config_for_env_reset=state_config_for_simulation_env_reset, # 根据需要传入具体配置
Expand All @@ -231,44 +240,40 @@ def test_get_next_action(self):
f"动作概率分布的长度应为 {len(self.legal_actions)}")

# 检查 action_probs 是否为有效的概率分布
self.assertAlmostEqual(sum(action_probs), 1.0, places=5, msg="动作概率分布的和应为 1.0")
self.assertEqual(sum(action_probs), 1.0, msg="动作概率分布的和应为 1.0")

def test_expand_leaf_node(self):
"""
测试 MCTS 的 _expand_leaf_node 方法是否正确扩展叶节点。
"""
# 模拟 policy_value_func 的返回值
self.policy_value_func.return_value = (
{0: 0.4, 1: 0.4, 2: 0.2},
0.9
)
# 设置 simulate_env 为 mock_env
simulate_env = self.mock_env

# 扩展叶节点
leaf_value = self.mcts._expand_leaf_node(self.root, self.mock_env, self.policy_value_func)
leaf_value = self.mcts._expand_leaf_node(self.root, simulate_env, self.policy_value_func)

# 检查返回的叶值
self.assertEqual(leaf_value, 0.9, "扩展叶节点时返回的叶值应为 0.9")

# 检查子节点是否被正确添加
for action, prior_p in self.policy_value_func.return_value[0].items():
# child = self.root.get_child(action)
for action, prior_p in ({0: 0.4, 1: 0.4, 2: 0.2}).items():
child = self.root.children.get(action, None)
self.assertIsNotNone(child, f"动作 {action} 的子节点应存在")
self.assertEqual(child.prior_p, prior_p, f"动作 {action} 的 prior_p 应为 {prior_p}")
self.assertAlmostEqual(child.prior_p, prior_p, places=5, msg=f"动作 {action} 的 prior_p 应为 {prior_p}")

@patch.object(mcts_alphazero.MCTS, '_simulate')
def test_simulate(self, mock_simulate):
def test_simulate(self):
"""
测试 MCTS 的 _simulate 方法是否能够正确执行模拟。
由于 _simulate 方法内部有许多依赖,这里主要测试是否能够调用和更新节点。
"""
# 由于 _simulate 方法内部有许多依赖,主要测试是否能够调用和更新节点
mock_simulate.return_value = None # 不关心返回值

# 执行模拟
# 调用 _simulate 方法
self.mcts._simulate(self.root, self.mock_env, self.policy_value_func)

# 检查 _simulate 是否被调用
mock_simulate.assert_called()
# 检查节点是否有更新
# 由于 simulate 调用的是 update_recursive,视具体实现,这里可以检查某些期望的值
# 例如,检查 root 的 visit_count 是否增加
self.assertGreaterEqual(self.root.visit_count, 0, "根节点的 visit_count 应大于或等于 0")
# 这里无法具体判断,因为 _simulate 的内部逻辑被忽略

def tearDown(self):
"""
Expand Down
Loading

0 comments on commit aa122d0

Please sign in to comment.