-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(pu): add ctree version of mcts in alphazero
- Loading branch information
1 parent
6b8d853
commit 460ab9f
Showing
11 changed files
with
165 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# Declare the minimum version of CMake that can be used | ||
# To understand and build the project | ||
cmake_minimum_required(VERSION 3.4...3.18) | ||
|
||
# Set the project name to mcts_alphazero and set the version to 1.0 | ||
project(mcts_alphazero VERSION 1.0) | ||
|
||
# Find and get the details of Python package | ||
# This is required for embedding Python in the project | ||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED) | ||
|
||
# Add pybind11 as a subdirectory, | ||
# so that its build files are generated alongside the current project. | ||
# This is necessary because the current project depends on pybind11 | ||
add_subdirectory(pybind11) | ||
|
||
# Add two .cpp files to the mcts_alphazero module | ||
# These files are compiled and linked into the module | ||
pybind11_add_module(mcts_alphazero mcts_alphazero.cpp node_alphazero.cpp) | ||
|
||
# Add the Python header file paths to the include paths | ||
# of the mcts_alphazero library. This is necessary for the | ||
# project to find the Python header files it needs to include | ||
target_include_directories(mcts_alphazero PRIVATE ${Python3_INCLUDE_DIRS}) | ||
|
||
# Link the mcts_alphazero library with the pybind11::module target. | ||
# This is necessary for the mcts_alphazero library to use the functions and classes defined by pybind11 | ||
target_link_libraries(mcts_alphazero PRIVATE pybind11::module) | ||
|
||
# Set the Python standard to the version of Python found by find_package(Python3) | ||
# This ensures that the code will be compiled against the correct version of Python | ||
set_target_properties(mcts_alphazero PROPERTIES PYTHON_STANDARD ${Python3_VERSION}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Declare the minimum version of CMake that can be used | ||
# To understand and build the project | ||
cmake_minimum_required(VERSION 3.4...3.18) | ||
|
||
# Set the project name to mcts_alphazero and set the version to 1.0 | ||
project(mcts_alphazero VERSION 1.0) | ||
|
||
# Find and get the details of Python package | ||
# This is required for embedding Python in the project | ||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED) | ||
|
||
# Add pybind11 as a subdirectory, | ||
# so that its build files are generated alongside the current project. | ||
# This is necessary because the current project depends on pybind11 | ||
add_subdirectory(pybind11) | ||
pybind11_add_module(mcts_alphazero mcts_alphazero.cpp) | ||
|
||
# Add the Python header file paths to the include paths | ||
# of the mcts_alphazero library. This is necessary for the | ||
# project to find the Python header files it needs to include | ||
target_include_directories(mcts_alphazero PRIVATE ${Python3_INCLUDE_DIRS}) | ||
|
||
# Link the mcts_alphazero library with the pybind11::module target. | ||
# This is necessary for the mcts_alphazero library to use the functions and classes defined by pybind11 | ||
target_link_libraries(mcts_alphazero PRIVATE pybind11::module) | ||
|
||
# Set the Python standard to the version of Python found by find_package(Python3) | ||
# This ensures that the code will be compiled against the correct version of Python | ||
set_target_properties(mcts_alphazero PROPERTIES PYTHON_STANDARD ${Python3_VERSION}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
# Declare the minimum version of CMake that can be used | ||
# To understand and build the project | ||
cmake_minimum_required(VERSION 3.4...3.18) | ||
|
||
# Set the project name to node_alphazero and set the version to 1.0 | ||
project(node_alphazero VERSION 1.0) | ||
|
||
# Find and get the details of Python package | ||
# This is required for embedding Python in the project | ||
find_package(Python3 COMPONENTS Interpreter Development REQUIRED) | ||
|
||
# Add pybind11 as a subdirectory, | ||
# so that its build files are generated alongside the current project. | ||
# This is necessary because the current project depends on pybind11 | ||
add_subdirectory(pybind11) | ||
pybind11_add_module(node_alphazero node_alphazero.cpp) | ||
|
||
# Add the Python header file paths to the include paths | ||
# of the node_alphazero library. This is necessary for the | ||
# project to find the Python header files it needs to include | ||
target_include_directories(node_alphazero PRIVATE ${Python3_INCLUDE_DIRS}) | ||
|
||
# Link the node_alphazero library with the pybind11::module target. | ||
# This is necessary for the node_alphazero library to use the functions and classes defined by pybind11 | ||
target_link_libraries(node_alphazero PRIVATE pybind11::module) | ||
|
||
# Set the Python standard to the version of Python found by find_package(Python3) | ||
# This ensures that the code will be compiled against the correct version of Python | ||
set_target_properties(node_alphazero PROPERTIES PYTHON_STANDARD ${Python3_VERSION}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
# Navigate to the project directory | ||
cd /Users/puyuan/code/LightZero/lzero/mcts/ctree/ctree_alphazero/ | ||
|
||
# Create a new directory named "build." The build directory is where the compiled files will be stored. | ||
mkdir build | ||
|
||
# Navigate into the "build" directory | ||
cd build | ||
|
||
# Run cmake on the parent directory. The ".." refers to the parent directory of the current directory. | ||
# The -DCMAKE_OSX_ARCHITECTURES="arm64" flag specifies that the generated build files should be suitable for the arm64 architecture. | ||
cmake .. -DCMAKE_OSX_ARCHITECTURES="arm64" | ||
|
||
# Run the "make" command. This command uses the files generated by cmake to compile the project. | ||
make |
Submodule pybind11
added at
f26069
34 changes: 34 additions & 0 deletions
34
lzero/mcts/ctree/ctree_alphazero/test/test_mcts_alphazero.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
""" | ||
This is an illustrative example of Python interfacing with a MCTS (Monte Carlo Tree Search) object implemented in C++. | ||
Please note that this code is not designed for actual execution. | ||
It serves as a conceptual demonstration, providing an understanding of how Python can interact with C++ objects, | ||
specifically within the context of MCTS. | ||
""" | ||
import sys | ||
|
||
import torch | ||
|
||
sys.path.append('/Users/puyuan/code/LightZero/lzero/mcts/ctree/ctree_alphazero/build') | ||
|
||
import mcts_alphazero | ||
mcts_alphazero = mcts_alphazero.MCTS() | ||
|
||
def _policy_value_fn(self, env: 'Env') -> Tuple[Dict[int, np.ndarray], float]: # noqa | ||
legal_actions = env.legal_actions | ||
current_state, current_state_scale = env.current_state() | ||
current_state_scale = torch.from_numpy(current_state_scale).to( | ||
device=self._device, dtype=torch.float | ||
).unsqueeze(0) | ||
with torch.no_grad(): | ||
action_probs, value = self._policy_model.compute_policy_value(current_state_scale) | ||
action_probs_dict = dict(zip(legal_actions, action_probs.squeeze(0)[legal_actions].detach().cpu().numpy())) | ||
return action_probs_dict, value.item() | ||
|
||
action, mcts_probs = mcts_alphazero.get_next_action( | ||
simulate_env=simulate_env, | ||
policy_value_func=_policy_value_fn, | ||
temperature=1, | ||
sample=True, | ||
) | ||
|
||
print(action, mcts_probs) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import sys | ||
sys.path.append('/Users/puyuan/code/LightZero/lzero/mcts/ctree/ctree_alphazero/build') | ||
|
||
import mcts_alphazero | ||
n = mcts_alphazero.Node() | ||
print(n.is_leaf()) | ||
print(n.update(5.0)) | ||
# print(n.value()) | ||
print(n) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters