
PR #14 (Open): Lucasborboleta wants to merge 12 commits into `master`
32 changes: 23 additions & 9 deletions README.md
@@ -1,36 +1,50 @@
# MCTS

This package provides a simple way of using Monte Carlo Tree Search in any perfect information domain.

## Installation

With pip: `pip install mcts`

Without pip: Download the zip/tar.gz file of the [latest release](https://github.com/pbsinclair42/MCTS/releases), extract it, and run `python setup.py install`

## Quick Usage

In order to run MCTS, you must implement a `State` class which can fully describe the state of the world. It must also implement the following methods:

- `getCurrentPlayer()`: Returns 1 if it is the maximizer player's turn to choose an action, or -1 for the minimizer player
- `getPossibleActions()`: Returns an iterable of all actions which can be taken from this state
- `takeAction(action)`: Returns the state which results from taking action `action`
- `isTerminal()`: Returns whether this state is a terminal state
- `getReward()`: Returns the reward for this state: 0 for a draw, positive for a win by the maximizer player, or negative for a win by the minimizer player. Only needed for terminal states.
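As an illustration of this interface (the game and every name below are invented for this sketch, not part of the package), a minimal `State` for a take-1-or-2 pile game could look like:

```python
import copy

class PileState:
    """Illustrative State: players alternately remove 1 or 2 items from a
    pile; the player who takes the last item wins."""
    def __init__(self, pile=5, currentPlayer=1):
        self.pile = pile
        self.currentPlayer = currentPlayer

    def getCurrentPlayer(self):
        return self.currentPlayer  # 1 = maximizer, -1 = minimizer

    def getPossibleActions(self):
        # Actions are plain ints here, which are hashable as required.
        return [n for n in (1, 2) if n <= self.pile]

    def takeAction(self, action):
        newState = copy.copy(self)
        newState.pile = self.pile - action
        newState.currentPlayer = -self.currentPlayer
        return newState

    def isTerminal(self):
        return self.pile == 0

    def getReward(self):
        # The player who just moved took the last item and wins.
        return -self.currentPlayer
```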

You must also choose a hashable representation for an action as used in `getPossibleActions` and `takeAction`. Typically this would be a class with a custom `__hash__` method, but it could also simply be a tuple or a string.
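For instance, a sketch of an action class with a custom `__hash__` (the class name and fields are illustrative, not part of the package):

```python
class PlaceStone:
    """Illustrative hashable action: place a stone at (row, col)."""
    def __init__(self, row, col):
        self.row = row
        self.col = col

    def __eq__(self, other):
        # Equal actions must hash equally so the tree can use them as keys.
        return (self.__class__ == other.__class__
                and self.row == other.row and self.col == other.col)

    def __hash__(self):
        return hash((self.row, self.col))

    def __repr__(self):
        return f"PlaceStone({self.row}, {self.col})"
```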

Once these have been implemented, running MCTS is as simple as initializing your starting state, then running:

```python
from mcts import mcts

currentState = MyState()
...
searcher = mcts(timeLimit=1000)
bestAction = searcher.search(initialState=currentState)
...
```
See [naughtsandcrosses.py](./naughtsandcrosses.py) for a simple example.

See [connectmnk.py](./connectmnk.py) for another example that runs a full *Connect(m,n,k,1,1)* game between two MCTS searchers.

## Slow Usage

When initializing the MCTS searcher, there are a few optional parameters that can be used to optimize the search:

- `timeLimit`: the maximum duration of the search in milliseconds. Exactly one of `timeLimit` and `iterationLimit` must be set.
- `iterationLimit`: the maximum number of search iterations to be carried out. Exactly one of `timeLimit` and `iterationLimit` must be set.
- `explorationConstant`: a weight used during the search to balance exploring unknown areas against deeper exploration of areas the algorithm currently believes to be valuable. The higher this constant, the more the algorithm prioritizes exploring unknown areas. Default value is √2.
- `rolloutPolicy`: the policy used in the roll-out phase to simulate one full play-out. Default is a uniform random policy.



## Detailed Information
//TODO

## Collaborating
255 changes: 255 additions & 0 deletions connectmnk.py
@@ -0,0 +1,255 @@
from __future__ import division

import copy
from mcts import mcts
import random


class ConnectMNKState:
"""ConnectMNKState models a Connect(m,n,k,1,1) game that generalizes
the famous "Connect Four" itself equal to the Connect(7,6,4,1,1) game.

Background from wikipedia:
Connect(m,n,k,p,q) games are another generalization of gomoku to a board
with m×n intersections, k in a row needed to win, p stones for each player
to place, and q stones for the first player to place for the first move
only. Each player may play only at the lowest unoccupied place in a column.
In particular, Connect(m,n,6,2,1) is called Connect6.
"""

playerNames = {1:'O', -1:'X'}

def __init__(self, mColumns=7, nRows=6, kConnections=4):
self.mColumns = mColumns
self.nRows = nRows
self.kConnections = kConnections
self.board = [ [0 for _ in range(self.mColumns)] for _ in range(self.nRows)]
self.currentPlayer = max(ConnectMNKState.playerNames.keys())
self.isTerminated = None
self.reward = None
self.possibleActions = None
self.winingPattern = None

def show(self):
rowText = ""
for columnIndex in range(self.mColumns):
rowText += f" {columnIndex % 10} "
print(rowText)

for rowIndex in reversed(range(self.nRows)):
rowText = ""
for x in self.board[rowIndex]:
if x in self.playerNames:
rowText += f" {self.playerNames[x]} "
else:
rowText += " . "
rowText += f" {rowIndex % 10} "
print(rowText)

def getCurrentPlayer(self):
return self.currentPlayer

def getPossibleActions(self):
if self.possibleActions is None:
self.possibleActions = []
for columnIndex in range(self.mColumns):
for rowIndex in range(self.nRows):
if self.board[rowIndex][columnIndex] == 0:
action = Action(player=self.currentPlayer,
columnIndex=columnIndex,
rowIndex=rowIndex)
self.possibleActions.append(action)
break
return self.possibleActions

def takeAction(self, action):
newState = copy.copy(self)
newState.board = copy.deepcopy(newState.board)
newState.board[action.rowIndex][action.columnIndex] = action.player
newState.currentPlayer = self.currentPlayer * -1
newState.isTerminated = None
newState.possibleActions = None
newState.winingPattern = None
return newState

def isTerminal(self):
if self.isTerminated is None:
self.isTerminated = False
for rowIndex in range(self.nRows):
line = self.board[rowIndex]
lineReward = self.__getLineReward(line)
if lineReward != 0:
self.isTerminated = True
self.reward = lineReward
self.winingPattern = "k-in-row"
break

if not self.isTerminated:
for columnIndex in range(self.mColumns):
line = []
for rowIndex in range(self.nRows):
line.append(self.board[rowIndex][columnIndex])
lineReward = self.__getLineReward(line)
if lineReward != 0:
self.isTerminated = True
self.reward = lineReward
self.winingPattern = "k-in-column"
break

if not self.isTerminated:
# diagonal: rowIndex = columnIndex + parameter
for parameter in range(1 - self.mColumns, self.nRows):
line = []
for columnIndex in range(self.mColumns):
rowIndex = columnIndex + parameter
if 0 <= rowIndex < self.nRows:
line.append(self.board[rowIndex][columnIndex])
lineReward = self.__getLineReward(line)
if lineReward != 0:
self.isTerminated = True
self.reward = lineReward
self.winingPattern = "k-in-diagonal"
break

if not self.isTerminated:
# antidiagonal: rowIndex = - columnIndex + parameter
for parameter in range(0, self.mColumns + self.nRows):
line = []
for columnIndex in range(self.mColumns):
rowIndex = -columnIndex + parameter
if 0 <= rowIndex < self.nRows:
line.append(self.board[rowIndex][columnIndex])
lineReward = self.__getLineReward(line)
if lineReward != 0:
self.isTerminated = True
self.reward = lineReward
self.winingPattern = "k-in-antidiagonal"
break

if not self.isTerminated and len(self.getPossibleActions()) == 0:
self.isTerminated = True
self.reward = 0

return self.isTerminated

def __getLineReward(self, line):
lineReward = 0
if len(line) >= self.kConnections:
for player in ConnectMNKState.playerNames.keys():
playerLine = [x == player for x in line]
playerConnections = 0
for x in playerLine:
if x:
playerConnections += 1
if playerConnections == self.kConnections:
lineReward = player
break
else:
playerConnections = 0
if lineReward != 0:
break
return lineReward

def getReward(self):
assert self.isTerminal()
assert self.reward is not None
return self.reward


class Action():
def __init__(self, player, columnIndex, rowIndex):
self.player = player
self.rowIndex = rowIndex
self.columnIndex = columnIndex

def __str__(self):
return str((self.columnIndex, self.rowIndex))

def __repr__(self):
return str(self)

def __eq__(self, other):
        return (self.__class__ == other.__class__ and
                self.player == other.player and
                self.columnIndex == other.columnIndex and
                self.rowIndex == other.rowIndex)

def __hash__(self):
return hash((self.columnIndex, self.rowIndex, self.player))


def extractStatistics(searcher, action):
statistics = {}
statistics['rootNumVisits'] = searcher.root.numVisits
statistics['rootTotalReward'] = searcher.root.totalReward
statistics['actionNumVisits'] = searcher.root.children[action].numVisits
statistics['actionTotalReward'] = searcher.root.children[action].totalReward
return statistics


def main():
"""Run a full match between two MCTS searchers, possibly with different
parametrization, playing a Connect(m,n,k) game.

    Extraction of MCTS statistics is exemplified.

The game parameters (m,n,k) are randomly chosen.
"""

searchers = {}
searchers["mcts-1500ms"] = mcts(timeLimit=1_500)
searchers["mcts-1000ms"] = mcts(timeLimit=1_000)
searchers["mcts-500ms"] = mcts(timeLimit=500)
searchers["mcts-250ms"] = mcts(timeLimit=250)

playerNames = ConnectMNKState.playerNames

playerSearcherNames = {}
for player in sorted(playerNames.keys()):
playerSearcherNames[player] = random.choice(sorted(searchers.keys()))

runnableGames = list()
runnableGames.append((3, 3, 3))
runnableGames.append((7, 6, 4))
runnableGames.append((8, 7, 5))
runnableGames.append((9, 8, 6))
(m, n, k) = random.choice(runnableGames)
currentState = ConnectMNKState(mColumns=m, nRows=n, kConnections=k)

turn = 0
currentState.show()
while not currentState.isTerminal():
turn += 1
player = currentState.getCurrentPlayer()
action_count = len(currentState.getPossibleActions())

searcherName = playerSearcherNames[player]
searcher = searchers[searcherName]

action = searcher.search(initialState=currentState)
statistics = extractStatistics(searcher, action)
currentState = currentState.takeAction(action)

print(f"at turn {turn} player {playerNames[player]}={player} ({searcherName})" +
f" takes action (column, row)={action} amongst {action_count} possibilities")

        print("mcts statistics:" +
f" chosen action= {statistics['actionTotalReward']} total reward" +
f" over {statistics['actionNumVisits']} visits /"
f" all explored actions= {statistics['rootTotalReward']} total reward" +
f" over {statistics['rootNumVisits']} visits")

print('-'*120)
currentState.show()

print('-'*120)
if currentState.getReward() == 0:
print(f"Connect(m={m},n={n},k={k}) game terminates; nobody wins")
else:
print(f"Connect(m={m},n={n},k={k}) game terminates;" +
f" player {playerNames[player]}={player} ({searcherName}) wins" +
f" with pattern {currentState.winingPattern}")


if __name__ == "__main__":
main()
7 changes: 4 additions & 3 deletions mcts.py
@@ -27,7 +27,7 @@ def __init__(self, state, parent):


class mcts():
    def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=math.sqrt(2),
rolloutPolicy=randomPolicy):
if timeLimit != None:
if iterationLimit != None:
@@ -75,6 +75,7 @@ def selectNode(self, node):

def expand(self, node):
actions = node.state.getPossibleActions()
random.shuffle(actions)
Owner commented:

What's the purpose of this shuffle?

Author commented:

When I experimented with the full NaughtsAndCrosses game extended to grids of large size (up to 10), I observed that allocating 1 second was not sufficient for the MCTS agent to win against a random agent; the game usually ends in a draw. Printing the game state at each turn, I saw that the actions selected by MCTS on its successive turns were all located in the first rows. So MCTS behaves somewhat predictably when its exploration of the tree is not sufficient to find at least a few winning paths. That is the reason for the shuffle added in the expand method. It could be interesting to quantify the gain over many (1000?) games: does the added shuffle increase the number of games won by MCTS when the tree is quite large?

Owner commented:

Gotcha. I would argue that deterministic behaviour on large search spaces is a symptom suggesting the search time needs to be significantly increased for the search to be effective, but I don't see any harm in putting this in here to avoid the ordering bias, aye.

Author commented:

In my game project, named JERSI-CERTU, to address the above-mentioned issue while still using an official release of your MCTS package, I simply shuffle the possible actions on the client side. What is your opinion on the two solutions? If you want your code to stay simple, the shuffle by the client is better.
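The client-side alternative discussed above can be sketched as follows (the class and its cell list are purely illustrative): the `State` itself returns its legal actions in a random order, so `expand()` in the library needs no patching.

```python
import random

class ShuffledActionsState:
    """Sketch of the client-side fix: shuffle the legal actions inside
    the State so the searcher sees them in random order."""
    def __init__(self, cells=(0, 1, 2, 3)):
        self.cells = cells

    def getPossibleActions(self):
        actions = list(self.cells)  # compute legal actions as usual
        random.shuffle(actions)     # remove the deterministic ordering bias
        return actions
```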

for action in actions:
if action not in node.children:
newNode = treeNode(node.state.takeAction(action), node)
@@ -95,8 +96,8 @@ def getBestChild(self, node, explorationValue):
bestValue = float("-inf")
bestNodes = []
for child in node.children.values():
nodeValue = (node.state.getCurrentPlayer() * child.totalReward / child.numVisits +
explorationValue * math.sqrt(math.log(node.numVisits) / child.numVisits))
Comment on lines +99 to +100

Owner commented:

Good change 👍

Author commented:

Thanks!

if nodeValue > bestValue:
bestValue = nodeValue
bestNodes = [child]
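For reference, the selection value computed in `getBestChild` is the standard UCT score, with the √2 factor now folded into `explorationConstant` rather than hard-coded in the formula. A minimal standalone sketch, with illustrative names:

```python
import math

def uctValue(playerSign, childTotalReward, childNumVisits, parentNumVisits,
             explorationValue=math.sqrt(2)):
    # Exploitation: the child's mean reward from the current player's
    # perspective (playerSign is +1 for the maximizer, -1 for the minimizer).
    exploitation = playerSign * childTotalReward / childNumVisits
    # Exploration: grows for children visited rarely relative to the parent.
    exploration = explorationValue * math.sqrt(
        math.log(parentNumVisits) / childNumVisits)
    return exploitation + exploration
```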
8 changes: 4 additions & 4 deletions naughtsandcrosses.py
@@ -39,7 +39,7 @@ def isTerminal(self):
[self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]:
if abs(sum(diagonal)) == 3:
return True
return reduce(operator.mul, sum(self.board, []), 1) != 0

def getReward(self):
for row in self.board:
@@ -52,7 +52,7 @@ def getReward(self):
[self.board[i][len(self.board) - i - 1] for i in range(len(self.board))]]:
if abs(sum(diagonal)) == 3:
return sum(diagonal) / 3
return 0


class Action():
@@ -75,7 +75,7 @@ def __hash__(self):


initialState = NaughtsAndCrossesState()
searcher = mcts(timeLimit=1000)
action = searcher.search(initialState=initialState)

print(action)