
Lucasborboleta #14 (Open)

wants to merge 12 commits into master
2 changes: 2 additions & 0 deletions .gitignore
@@ -0,0 +1,2 @@
+*.pyc
+__pycache__
Owner:
I prefer not to include generic .gitignores like this, in favour of each developer having their own global config, though I appreciate this view isn't shared by everyone. As such, I'd rather not include a .gitignore unless it's for generated files specific to this repository.

Author:

I understand your point.

35 changes: 24 additions & 11 deletions README.md
@@ -1,37 +1,50 @@
# MCTS

This package provides a simple way of using Monte Carlo Tree Search in any perfect information domain.

## Installation

With pip: `pip install mcts`

Without pip: Download the zip/tar.gz file of the [latest release](https://github.com/pbsinclair42/MCTS/releases), extract it, and run `python setup.py install`

## Quick Usage

-In order to run MCTS, you must implement a `State` class which can fully describe the state of the world. It must also implement four methods:
+In order to run MCTS, you must implement a `State` class which can fully describe the state of the world. It must also implement the following methods:

-- `getCurrentPlayer()`: Returns 1 if it is the maximizer player's turn to choose an action, or -1 for the minimiser player
+- `getCurrentPlayer()`: Returns 1 if it is the maximizer player's turn to choose an action, or -1 for the minimizer player
- `getPossibleActions()`: Returns an iterable of all actions which can be taken from this state
- `takeAction(action)`: Returns the state which results from taking action `action`
- `isTerminal()`: Returns whether this state is a terminal state
-- `getReward()`: Returns the reward for this state. Only needed for terminal states.
+- `getReward()`: Returns the reward for this state: 0 for a draw, positive for a win by the maximizer player or negative for a win by the minimizer player. Only needed for terminal states.

You must also choose a hashable representation for an action as used in `getPossibleActions` and `takeAction`. Typically this would be a class with a custom `__hash__` method, but it could also simply be a tuple or a string.
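For example, a minimal hashable action might look like the following sketch (the class and attribute names here are illustrative, not part of the package):

```python
class PlaceMark():
    """A hypothetical action: the given player marks cell (x, y)."""

    def __init__(self, player, x, y):
        self.player = player
        self.x = x
        self.y = y

    def __eq__(self, other):
        return (self.player, self.x, self.y) == (other.player, other.x, other.y)

    def __hash__(self):
        return hash((self.player, self.x, self.y))
```

Defining `__eq__` alongside `__hash__` keeps equal actions interchangeable as dictionary keys, which matters because the searcher stores child nodes keyed by action.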

Once these have been implemented, running MCTS is as simple as initializing your starting state, then running:

```python
from mcts import mcts

-mcts = mcts(timeLimit=1000)
-bestAction = mcts.search(initialState=initialState)
+currentState = MyState()
+...
+searcher = mcts(timeLimit=1000)
+bestAction = searcher.search(initialState=currentState)
+currentState = currentState.takeAction(bestAction)
+...
```

Owner:

This change has now been made in #13
See [naughtsandcrosses.py](https://github.com/pbsinclair42/MCTS/blob/master/naughtsandcrosses.py) for a simple example.

+## Detailed usage
+A few customizations are possible through the `mcts` constructor:

+- The number of MCTS search rounds can be limited by either a given time limit or a given iteration number.
+- The exploration constant $c$, which appears in the UCT score $w_i/n_i + c\sqrt{\ln N_i / n_i}$ with theoretical default setting $c = \sqrt{2}$, can be adapted to your game (see the sketch below).
+- The default uniform random rollout/playout policy can be changed.
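As a standalone illustration of that score (a sketch of the formula above, not the package's internal code; the variable names are chosen here for readability):

```python
import math

def uctScore(childTotalReward, childNumVisits, parentNumVisits, c=math.sqrt(2)):
    # w_i / n_i: average reward observed through this child (exploitation)
    exploitation = childTotalReward / childNumVisits
    # c * sqrt(ln N_i / n_i): bonus for rarely visited children (exploration)
    exploration = c * math.sqrt(math.log(parentNumVisits) / childNumVisits)
    return exploitation + exploration
```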

+A few statistics can be retrieved after each MCTS search call (see the `naughtsandcrosses.py` example).

## Slow Usage
-//TODO
+More on MCTS theory can be found at https://en.wikipedia.org/wiki/Monte_Carlo_tree_search and the references cited there.
Owner:
While this is useful, I'd still want to include more here before saying that the Detailed Usage / Slow Usage section is complete. Indeed, I feel like these notes would be useful in the quick usage section too, so maybe pop them there instead and put the slow usage todo note back?

Owner:
Also, could we simplify this to the following:

When initialising the MCTS searcher, there are a few optional parameters that can be used to optimise the search:

- timeLimit: the maximum duration of the search in milliseconds. Exactly one of timeLimit and iterationLimit must be set.
- iterationLimit: the maximum number of search iterations to be carried out. Exactly one of timeLimit and iterationLimit must be set.
- explorationConstant: a weight used when searching to help the algorithm balance exploring unknown areas vs exploring more deeply the areas it currently believes to be valuable. The higher this constant, the more the algorithm will prioritise exploring unknown areas. Default value is √2.
- rolloutPolicy: the policy to be used in the rollout phase when simulating one full playout. Default is a random uniform policy.
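For illustration, a sketch of these options in use; `myRolloutPolicy` is a hypothetical stand-in with the same contract as the default random policy (play the state out to a terminal state and return its reward):

```python
import math
import random

from mcts import mcts

# exactly one of timeLimit / iterationLimit must be set
searcher = mcts(iterationLimit=2000, explorationConstant=math.sqrt(2))

# a hypothetical custom rollout policy, same contract as the default one:
# take a state, play it out to a terminal state, return that state's reward
def myRolloutPolicy(state):
    while not state.isTerminal():
        state = state.takeAction(random.choice(state.getPossibleActions()))
    return state.getReward()

searcher = mcts(timeLimit=1000, rolloutPolicy=myRolloutPolicy)
```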

Author:
Very good synthesis for the quick usage 👍
Regarding the section "Slow usage", what do you mean by "slow"? Do you intend "detailed information"? Or do you have in mind another mode for MCTS that indirectly slows down the search (like a different rolloutPolicy)? Even if your material for this "slow usage" is not yet ready, could you either refine the title or put some hints in its TODO, if possible?

Owner:
Aye, that was going to be for more detailed information on the whole system, so sure changing that header to "Detailed Information" would probably be clearer, go for it.


## Collaborating

16 changes: 13 additions & 3 deletions mcts.py
@@ -27,7 +27,7 @@ def __init__(self, state, parent):


class mcts():
-    def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2),
+    def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=math.sqrt(2),
                 rolloutPolicy=randomPolicy):
        if timeLimit != None:
            if iterationLimit != None:
@@ -75,6 +75,7 @@ def selectNode(self, node):

    def expand(self, node):
        actions = node.state.getPossibleActions()
+        random.shuffle(actions)
Owner:
What's the purpose of this shuffle?

Author:
When I experimented with the full game of NaughtsAndCrosses extended to grids of large size, up to 10, I observed the following: allocating 1 second was not sufficient for the MCTS agent to win against a random agent; the game usually terminates as a draw. Printing the state of the game at each turn, I observed that the actions selected by MCTS on its successive turns are all located in the first rows. So MCTS behaviour is somewhat predictable when its exploration of the tree is not sufficient for finding at least a few winning paths. That is the reason for the added shuffle in the expand method. It could be interesting to quantify the gain over many (1000?) games: does the added shuffle increase the number of games won by MCTS when the tree is quite large?
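A sketch of such an experiment (a hypothetical harness, assuming this PR's NaughtsAndCrossesState with its gridSize parameter; at 1 second per MCTS move a 1000-game run is long, so shrink the limits for a quick check):

```python
import random

from mcts import mcts
from naughtsandcrosses import NaughtsAndCrossesState

def playOneGame(gridSize=10, mctsPlayer=1, timeLimit=1000):
    state = NaughtsAndCrossesState(gridSize)
    while not state.isTerminal():
        if state.getCurrentPlayer() == mctsPlayer:
            action = mcts(timeLimit=timeLimit).search(initialState=state)
        else:
            action = random.choice(state.getPossibleActions())
        state = state.takeAction(action)
    # getReward() is +1/-1 for a win by the maximizer/minimizer, 0 for a draw
    return state.getReward() * mctsPlayer

results = [playOneGame() for _ in range(1000)]
print("wins:", results.count(1), "draws:", results.count(0), "losses:", results.count(-1))
```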

Owner:
Gotcha. I would argue that such deterministic behaviour on large search spaces is a symptom suggesting the search time needs to be significantly increased for the search to be effective, but I don't see any harm in putting this in here to avoid the ordering bias, aye.

Author:
On my game project, named JERSI-CERTU, to address the above-mentioned issue while still using an official release of your MCTS package, I have just shuffled the possible actions on the client side. What is your opinion on the two solutions? If you want your code to stay simple, the shuffle by the client is better.
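For reference, the client-side variant is a sketch like this, with `MyState` standing in for any concrete State implementation:

```python
import random

class ShuffledState(MyState):  # MyState: your concrete State class
    """Hypothetical wrapper: return the possible actions in random order,
    so the searcher's expansion order carries no positional bias."""

    def getPossibleActions(self):
        actions = list(super().getPossibleActions())
        random.shuffle(actions)
        return actions
```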

        for action in actions:
            if action not in node.children:
                newNode = treeNode(node.state.takeAction(action), node)
@@ -95,8 +96,8 @@ def getBestChild(self, node, explorationValue):
        bestValue = float("-inf")
        bestNodes = []
        for child in node.children.values():
-            nodeValue = node.state.getCurrentPlayer() * child.totalReward / child.numVisits + explorationValue * math.sqrt(
-                2 * math.log(node.numVisits) / child.numVisits)
+            nodeValue = (node.state.getCurrentPlayer() * child.totalReward / child.numVisits +
+                         explorationValue * math.sqrt(math.log(node.numVisits) / child.numVisits))
Comment on lines +99 to +100
Owner:
Good change 👍

Author:
Thanks!

            if nodeValue > bestValue:
                bestValue = nodeValue
                bestNodes = [child]
@@ -108,3 +109,12 @@ def getAction(self, root, bestChild):
        for action, node in root.children.items():
            if node is bestChild:
                return action

+    def getStatistics(self, action=None):
+        statistics = {}
+        statistics['rootNumVisits'] = self.root.numVisits
+        statistics['rootTotalReward'] = self.root.totalReward
+        if action is not None:
+            statistics['actionNumVisits'] = self.root.children[action].numVisits
+            statistics['actionTotalReward'] = self.root.children[action].totalReward
+        return statistics
Owner:
I'm not sure of the purpose of this? While I can see the ability to extract extra statistics from the search tree could be useful, this is a relatively limited set of statistics that is being generated, so it seems very specific to one use case, and therefore maybe doesn't belong here, but instead in the calling code? I agree that there needs to be some documentation on the appropriate way of extracting extra statistics from the search tree though.

Author:
If you can guarantee the stability of the tree structure, either by direct access to the attributes or by an API, then, yes, I agree that the kind of statistics I used can live on the "client" side.
The purpose of my extracted statistics was monitoring: for a large tree, given the allocated CPU time or round number, I was checking whether the selected action came from a few winning paths or whether no winning paths had been found at all.

Owner:
Makes sense. I'd rather hold off on adding this here then, instead offering a full API for extracting these details appropriately.
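Meanwhile, a client-side sketch of extracting the same numbers, relying on the searcher's `root` attribute (an internal detail rather than a stable API, so this may break between releases):

```python
from mcts import mcts

searcher = mcts(timeLimit=1000)
action = searcher.search(initialState=currentState)  # currentState: any State instance

# peek into the search tree through internal attributes
root = searcher.root
print(f"root: {root.totalReward} total reward over {root.numVisits} visits")
child = root.children[action]
print(f"chosen action: {child.totalReward} total reward over {child.numVisits} visits")
```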

103 changes: 82 additions & 21 deletions naughtsandcrosses.py
@@ -4,54 +4,72 @@
from mcts import mcts
from functools import reduce
import operator
+import random


class NaughtsAndCrossesState():
Owner:
This is deliberately designed to be a super simple example that users can easily eyeball and see what's going on, so I'd rather not complicate it like this. If you'd like to create an alternative search space and include both a more complex state and a more complex search (doing the full game instead of just one round) then I'd be happy to include that example too, but separately to this naughts and crosses one.

Author (LucasBorboleta, Jan 10, 2021):
I understand your point: providing a first super simple example. However, beside this first example, a second example that exercises the complete API, like playing the full game, could also help attract new users. Also, varying the tree depth and permuting the roles of the MCTS and random agents is instructive. Practically, I could propose: 1) the naughtsandcrosses.py file kept unchanged; 2) in another file, a generalization of the Connect Four game with a grid size parameter, with code for the full game, in the way I did with NaughtsAndCrosses. Would you agree?

Owner:
Aye, that sounds like a great idea, thanks!

-    def __init__(self):
-        self.board = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
+    playerNames = {1: 'O', -1: 'X'}
+
+    def __init__(self, gridSize=3):
+        self.gridSize = gridSize
+        self.board = [[0 for _ in range(self.gridSize)] for _ in range(self.gridSize)]
        self.currentPlayer = 1
+        self.possibleActions = None

+    def show(self):
+        for row in self.board:
+            row_text = ""
+            for cell in row:
+                if cell in self.playerNames:
+                    row_text += f" {self.playerNames[cell]} "
+                else:
+                    row_text += " . "
+            print(row_text)

    def getCurrentPlayer(self):
        return self.currentPlayer

    def getPossibleActions(self):
-        possibleActions = []
-        for i in range(len(self.board)):
-            for j in range(len(self.board[i])):
-                if self.board[i][j] == 0:
-                    possibleActions.append(Action(player=self.currentPlayer, x=i, y=j))
-        return possibleActions
+        if self.possibleActions is None:
+            self.possibleActions = []
+            for i in range(len(self.board)):
+                for j in range(len(self.board[i])):
+                    if self.board[i][j] == 0:
+                        self.possibleActions.append(Action(player=self.currentPlayer, x=i, y=j))
+        return self.possibleActions

    def takeAction(self, action):
        newState = deepcopy(self)
        newState.board[action.x][action.y] = action.player
        newState.currentPlayer = self.currentPlayer * -1
+        newState.possibleActions = None
        return newState

    def isTerminal(self):
        for row in self.board:
-            if abs(sum(row)) == 3:
+            if abs(sum(row)) == self.gridSize:
                return True
        for column in list(map(list, zip(*self.board))):
-            if abs(sum(column)) == 3:
+            if abs(sum(column)) == self.gridSize:
                return True
        for diagonal in [[self.board[i][i] for i in range(len(self.board))],
                         [self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]:
-            if abs(sum(diagonal)) == 3:
+            if abs(sum(diagonal)) == self.gridSize:
                return True
        return reduce(operator.mul, sum(self.board, []), 1)

    def getReward(self):
        for row in self.board:
-            if abs(sum(row)) == 3:
-                return sum(row) / 3
+            if abs(sum(row)) == self.gridSize:
+                return sum(row) / self.gridSize
        for column in list(map(list, zip(*self.board))):
-            if abs(sum(column)) == 3:
-                return sum(column) / 3
+            if abs(sum(column)) == self.gridSize:
+                return sum(column) / self.gridSize
        for diagonal in [[self.board[i][i] for i in range(len(self.board))],
                         [self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]:
-            if abs(sum(diagonal)) == 3:
-                return sum(diagonal) / 3
+            if abs(sum(diagonal)) == self.gridSize:
+                return sum(diagonal) / self.gridSize
        return False


@@ -74,8 +92,51 @@ def __hash__(self):
        return hash((self.x, self.y, self.player))


-initialState = NaughtsAndCrossesState()
-mcts = mcts(timeLimit=1000)
-action = mcts.search(initialState=initialState)
+def main():
+    """Example of a full NaughtsAndCrosses game played between MCTS and random searchers.
+    The standard 3x3 grid is randomly extended up to 10x10 in order to exercise the MCTS time resource.
+    One of the two players is randomly assigned to the MCTS searcher for the purpose of correctness checking.
+    Basic statistics are printed at each MCTS turn."""
+
+    playerNames = NaughtsAndCrossesState.playerNames
+    mctsPlayer = random.choice(sorted(playerNames.keys()))
+    gridSize = random.choice(list(range(3, 11)))
+
+    currentState = NaughtsAndCrossesState(gridSize)
+    turn = 0
+    currentState.show()
+    while not currentState.isTerminal():
+        turn += 1
+        player = currentState.getCurrentPlayer()
+        action_count = len(currentState.getPossibleActions())
+
+        if player == mctsPlayer:
+            searcher = mcts(timeLimit=1_000)
+            searcherName = "mcts-1-second"
+            action = searcher.search(initialState=currentState)
+            statistics = searcher.getStatistics(action)
+        else:
+            searcherName = "random"
+            action = random.choice(currentState.getPossibleActions())
+            statistics = None
+
+        currentState = currentState.takeAction(action)
+        print(f"at turn {turn} player {playerNames[player]}={player} ({searcherName}) takes action {action} amongst {action_count} possibilities")
+
+        if statistics is not None:
+            print(f"mcts statistics for the chosen action: {statistics['actionTotalReward']} total reward over {statistics['actionNumVisits']} visits")
+            print(f"mcts statistics for all explored actions: {statistics['rootTotalReward']} total reward over {statistics['rootNumVisits']} visits")
+
+        print('-' * 90)
+        currentState.show()
+
+    print('-' * 90)
+    if currentState.getReward() == 0:
+        print(f"game {gridSize}x{gridSize} terminates; nobody wins")
+    else:
+        print(f"game {gridSize}x{gridSize} terminates; player {playerNames[player]}={player} ({searcherName}) wins")
+
+
+if __name__ == "__main__":
+    main()
-
-print(action)