Skip to content

Commit

Permalink
feature(pu): add ctree version of mcts in alphazero
Browse files Browse the repository at this point in the history
  • Loading branch information
puyuan1996 committed Nov 7, 2023
1 parent 6b8d853 commit 460ab9f
Show file tree
Hide file tree
Showing 11 changed files with 165 additions and 7 deletions.
32 changes: 32 additions & 0 deletions lzero/mcts/ctree/ctree_alphazero/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Declare the minimum version of CMake that can be used
# To understand and build the project
cmake_minimum_required(VERSION 3.4...3.18)

# Set the project name to mcts_alphazero and set the version to 1.0
project(mcts_alphazero VERSION 1.0)

# Find and get the details of Python package
# This is required for embedding Python in the project
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)

# Add pybind11 as a subdirectory,
# so that its build files are generated alongside the current project.
# This is necessary because the current project depends on pybind11
add_subdirectory(pybind11)

# Add two .cpp files to the mcts_alphazero module
# These files are compiled and linked into the module
pybind11_add_module(mcts_alphazero mcts_alphazero.cpp node_alphazero.cpp)

# Add the Python header file paths to the include paths
# of the mcts_alphazero library. This is necessary for the
# project to find the Python header files it needs to include
target_include_directories(mcts_alphazero PRIVATE ${Python3_INCLUDE_DIRS})

# Link the mcts_alphazero library with the pybind11::module target.
# This is necessary for the mcts_alphazero library to use the functions and classes defined by pybind11
target_link_libraries(mcts_alphazero PRIVATE pybind11::module)

# Set the Python standard to the version of Python found by find_package(Python3)
# This ensures that the code will be compiled against the correct version of Python
set_target_properties(mcts_alphazero PROPERTIES PYTHON_STANDARD ${Python3_VERSION})
29 changes: 29 additions & 0 deletions lzero/mcts/ctree/ctree_alphazero/CMakeLists_mcts.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Declare the minimum version of CMake that can be used
# To understand and build the project
cmake_minimum_required(VERSION 3.4...3.18)

# Set the project name to mcts_alphazero and set the version to 1.0
project(mcts_alphazero VERSION 1.0)

# Find and get the details of Python package
# This is required for embedding Python in the project
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)

# Add pybind11 as a subdirectory,
# so that its build files are generated alongside the current project.
# This is necessary because the current project depends on pybind11
add_subdirectory(pybind11)
pybind11_add_module(mcts_alphazero mcts_alphazero.cpp)

# Add the Python header file paths to the include paths
# of the mcts_alphazero library. This is necessary for the
# project to find the Python header files it needs to include
target_include_directories(mcts_alphazero PRIVATE ${Python3_INCLUDE_DIRS})

# Link the mcts_alphazero library with the pybind11::module target.
# This is necessary for the mcts_alphazero library to use the functions and classes defined by pybind11
target_link_libraries(mcts_alphazero PRIVATE pybind11::module)

# Set the Python standard to the version of Python found by find_package(Python3)
# This ensures that the code will be compiled against the correct version of Python
set_target_properties(mcts_alphazero PROPERTIES PYTHON_STANDARD ${Python3_VERSION})
29 changes: 29 additions & 0 deletions lzero/mcts/ctree/ctree_alphazero/CMakeLists_node.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Declare the minimum version of CMake that can be used
# To understand and build the project
cmake_minimum_required(VERSION 3.4...3.18)

# Set the project name to node_alphazero and set the version to 1.0
project(node_alphazero VERSION 1.0)

# Find and get the details of Python package
# This is required for embedding Python in the project
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)

# Add pybind11 as a subdirectory,
# so that its build files are generated alongside the current project.
# This is necessary because the current project depends on pybind11
add_subdirectory(pybind11)
pybind11_add_module(node_alphazero node_alphazero.cpp)

# Add the Python header file paths to the include paths
# of the node_alphazero library. This is necessary for the
# project to find the Python header files it needs to include
target_include_directories(node_alphazero PRIVATE ${Python3_INCLUDE_DIRS})

# Link the node_alphazero library with the pybind11::module target.
# This is necessary for the node_alphazero library to use the functions and classes defined by pybind11
target_link_libraries(node_alphazero PRIVATE pybind11::module)

# Set the Python standard to the version of Python found by find_package(Python3)
# This ensures that the code will be compiled against the correct version of Python
set_target_properties(node_alphazero PROPERTIES PYTHON_STANDARD ${Python3_VERSION})
15 changes: 15 additions & 0 deletions lzero/mcts/ctree/ctree_alphazero/make.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Navigate to the project directory
cd /Users/puyuan/code/LightZero/lzero/mcts/ctree/ctree_alphazero/

# Create a new directory named "build." The build directory is where the compiled files will be stored.
mkdir build

# Navigate into the "build" directory
cd build

# Run cmake on the parent directory. The ".." refers to the parent directory of the current directory.
# The -DCMAKE_OSX_ARCHITECTURES="arm64" flag specifies that the generated build files should be suitable for the arm64 architecture.
cmake .. -DCMAKE_OSX_ARCHITECTURES="arm64"

# Run the "make" command. This command uses the files generated by cmake to compile the project.
make
1 change: 1 addition & 0 deletions lzero/mcts/ctree/ctree_alphazero/pybind11
Submodule pybind11 added at f26069
34 changes: 34 additions & 0 deletions lzero/mcts/ctree/ctree_alphazero/test/test_mcts_alphazero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
"""
This is an illustrative example of Python interfacing with a MCTS (Monte Carlo Tree Search) object implemented in C++.
Please note that this code is not designed for actual execution.
It serves as a conceptual demonstration, providing an understanding of how Python can interact with C++ objects,
specifically within the context of MCTS.
"""
import sys

import torch

sys.path.append('/Users/puyuan/code/LightZero/lzero/mcts/ctree/ctree_alphazero/build')

import mcts_alphazero
mcts_alphazero = mcts_alphazero.MCTS()

def _policy_value_fn(self, env: 'Env') -> Tuple[Dict[int, np.ndarray], float]: # noqa
legal_actions = env.legal_actions
current_state, current_state_scale = env.current_state()
current_state_scale = torch.from_numpy(current_state_scale).to(
device=self._device, dtype=torch.float
).unsqueeze(0)
with torch.no_grad():
action_probs, value = self._policy_model.compute_policy_value(current_state_scale)
action_probs_dict = dict(zip(legal_actions, action_probs.squeeze(0)[legal_actions].detach().cpu().numpy()))
return action_probs_dict, value.item()

action, mcts_probs = mcts_alphazero.get_next_action(
simulate_env=simulate_env,
policy_value_func=_policy_value_fn,
temperature=1,
sample=True,
)

print(action, mcts_probs)
9 changes: 9 additions & 0 deletions lzero/mcts/ctree/ctree_alphazero/test/test_node_alphazero.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import sys
sys.path.append('/Users/puyuan/code/LightZero/lzero/mcts/ctree/ctree_alphazero/build')

import mcts_alphazero
n = mcts_alphazero.Node()
print(n.is_leaf())
print(n.update(5.0))
# print(n.value())
print(n)
2 changes: 1 addition & 1 deletion lzero/policy/alphazero.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def _get_simulation_env(self):
self.simulate_env = TicTacToeEnv(tictactoe_alphazero_config.env)

elif self._cfg.simulation_env_name == 'gomoku':
from zoo.board_games.gomoku.envs.gomoku_env_ui import GomokuEnv
from zoo.board_games.gomoku.envs.gomoku_env import GomokuEnv
if self._cfg.simulation_env_config_type == 'play_with_bot':
from zoo.board_games.gomoku.config.gomoku_alphazero_bot_mode_config import gomoku_alphazero_config
elif self._cfg.simulation_env_config_type == 'self_play':
Expand Down
16 changes: 11 additions & 5 deletions zoo/board_games/gomoku/config/gomoku_alphazero_bot_mode_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
# begin of the most frequently changed config specified by the user
# ==============================================================
board_size = 6 # default_size is 15
collector_env_num = 32
n_episode = 32
evaluator_env_num = 5
collector_env_num = 8
n_episode = 8
evaluator_env_num = 3
num_simulations = 50
update_per_collect = 50
batch_size = 256
max_env_step = int(5e5)
prob_random_action_in_bot = 0.5
mcts_ctree = True

# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
Expand All @@ -21,9 +23,9 @@
env=dict(
board_size=board_size,
battle_mode='play_with_bot_mode',
bot_action_type='v0',
bot_action_type='v1',
prob_random_action_in_bot=prob_random_action_in_bot,
channel_last=False, # NOTE
channel_last=False,
collector_env_num=collector_env_num,
evaluator_env_num=evaluator_env_num,
n_evaluator_episode=evaluator_env_num,
Expand All @@ -35,9 +37,13 @@
prob_expert_agent=0,
scale=True,
check_action_to_connect4_in_bot_v0=False,
mcts_ctree=mcts_ctree,
screen_scaling=9,
render_mode=None,
# ==============================================================
),
policy=dict(
mcts_ctree=mcts_ctree,
# ==============================================================
# for the creation of simulation env
simulation_env_name='gomoku',
Expand Down
2 changes: 1 addition & 1 deletion zoo/board_games/gomoku/envs/gomoku_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def __init__(self, cfg: dict = None):
self.render_mode = cfg.render_mode
self.replay_name_suffix = "test"
self.replay_path = None
self.replay_format = 'gif' # 'mp4' #
self.replay_format = 'gif' # options={'gif', 'mp4'}
self.screen = None
self.frames = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
update_per_collect = 50
batch_size = 256
max_env_step = int(2e5)
mcts_ctree = True
# ==============================================================
# end of the most frequently changed config specified by the user
# ==============================================================
Expand All @@ -31,6 +32,7 @@
prob_random_agent=0,
prob_expert_agent=0,
scale=True,
mcts_ctree=mcts_ctree,
# ==============================================================
),
policy=dict(
Expand All @@ -39,6 +41,7 @@
simulation_env_name='tictactoe',
simulation_env_config_type='play_with_bot',
# ==============================================================
mcts_ctree=mcts_ctree,
model=dict(
observation_shape=(3, 3, 3),
action_space_size=int(1 * 3 * 3),
Expand Down

0 comments on commit 460ab9f

Please sign in to comment.