Add "write your own policy" in README #15

Open · wants to merge 13 commits into base: master
69 changes: 59 additions & 10 deletions README.md
100644 → 100755
@@ -1,36 +1,85 @@
# MCTS

This package provides a simple way of using Monte Carlo Tree Search in any perfect information domain.

## Installation

With pip: `pip install mcts`

Without pip: Download the zip/tar.gz file of the [latest release](https://github.com/pbsinclair42/MCTS/releases), extract it, and run `python setup.py install`

## Quick Usage

In order to run MCTS, you must implement a `State` class which can fully describe the state of the world. It must also implement the following five methods:

- `getCurrentPlayer()`: Returns 1 if it is the maximiser player's turn to choose an action, or -1 for the minimiser player
- `getPossibleActions()`: Returns an iterable of all `action`s which can be taken from this state
- `takeAction(action)`: Returns the state which results from taking action `action`
- `isTerminal()`: Returns `True` if this state is a terminal state
- `getReward()`: Returns the reward for this state. Only needed for terminal states.

You must also choose a hashable representation for an action as used in `getPossibleActions` and `takeAction`. Typically this would be a class with a custom `__hash__` method, but it could also simply be a tuple or a string.
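
For illustration, here is a minimal sketch of such a `State` class for a toy game (this `CountToTenState` is not part of the package; players alternately add 1 or 2 to a running total, and whoever reaches 10 first wins):

```python
from copy import deepcopy

class CountToTenState():
    def __init__(self):
        self.total = 0
        self.currentPlayer = 1

    def getCurrentPlayer(self):
        return self.currentPlayer  # 1 = maximiser, -1 = minimiser

    def getPossibleActions(self):
        return [1, 2]  # actions here are plain ints, which are hashable

    def takeAction(self, action):
        newState = deepcopy(self)  # return a new state rather than mutating this one
        newState.total += action
        newState.currentPlayer = -self.currentPlayer
        return newState

    def isTerminal(self):
        return self.total >= 10

    def getReward(self):
        # +1 if the maximiser reached 10 first, -1 if the minimiser did
        return -self.currentPlayer
```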

Once these have been implemented, running MCTS is as simple as initializing your starting state, then running:

```python
from mcts import mcts

searcher = mcts(timeLimit=1000)
bestAction = searcher.search(initialState=initialState)
```
Here `timeLimit=1000` is given in milliseconds. Alternatively, you can pass `iterationLimit=1600` to specify the number of rollouts; exactly one of `timeLimit` and `iterationLimit` must be specified. The expected reward of the best action can be obtained by passing `needDetails=True` to `searcher.search`.

```python
resultDict = searcher.search(initialState=initialState, needDetails=True)
print(resultDict.keys()) #currently includes dict_keys(['action', 'expectedReward'])
```
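
For instance, a sketch of limiting the search by rollout count instead of wall-clock time:

```python
from mcts import mcts

searcher = mcts(iterationLimit=1600)  # run 1600 rollouts instead of using a time budget
bestAction = searcher.search(initialState=initialState)
```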

See [naughtsandcrosses.py](https://github.com/pbsinclair42/MCTS/blob/master/naughtsandcrosses.py) for a simple example.

### Alpha-Beta Pruning

Alpha-beta pruning is used in almost the same way as MCTS. The only difference is that `getReward()` must be implemented for every state, not just terminal ones.

```python
from mcts import abpruning
searcher = abpruning(deep=3)
bestAction = searcher.search(initialState)
```

The parameters of `abpruning`'s constructor are:

* `deep`: search depth;
* `n_killer`: number of killer moves kept per node for the killer-heuristic move ordering, default 2;
* `gameinf`: an upper bound on `getReward()` return values, used as "inf" in the algorithm, default 65535.

After `search()` is called, the values of the root's children can be found in `searcher.children`, a dictionary of the form `{action: value}`, and `searcher.counter` records how many leaf nodes were visited.
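
Putting this together, a sketch of a full alpha-beta search and inspection of its details (assuming `initialState` belongs to a class that implements `getReward()` for every state, e.g. as a heuristic evaluation of non-terminal positions):

```python
from mcts import abpruning

searcher = abpruning(deep=3, n_killer=2, gameinf=65535)
bestAction = searcher.search(initialState)

print(searcher.children)  # {action: value} for every action available at the root
print(searcher.counter)   # number of leaf nodes evaluated during this search
```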

## Slow Usage

### Write Your Own Policy

The default policy for this package is `randomPolicy`, defined in `mcts.py`. Its structure is:

```python
def randomPolicy(state):
    while not state.isTerminal():
        action = random.choice(state.getPossibleActions())
        state = state.takeAction(action)
    return state.getReward()
```

By substituting a stronger policy, you can make the search more efficient. The new policy should be a function that takes `state` as input and returns a reward from the point of view of `state`'s current player; it is handed to the searcher through the `rolloutPolicy` argument of `mcts`'s constructor (the default is `rolloutPolicy=randomPolicy`). Pay attention to the sign of the reward your policy returns, otherwise it will play for its opponent. For example, suppose you have trained a neural network which can estimate the expected reward even when the state is not terminal; you can use it to accelerate the rollout:

```python
def nnPolicy(state):
    if state.isTerminal():
        return state.getReward()
    else:
        # placeholder: the value your trained network predicts for this state
        return reward_estimated_by_neural_network
```
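
The custom policy is then handed to the searcher via the `rolloutPolicy` argument, for example:

```python
searcher = mcts(timeLimit=1000, rolloutPolicy=nnPolicy)
bestAction = searcher.search(initialState=initialState)
```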

### More
//TODO

## Collaborating
2 changes: 2 additions & 0 deletions chess.py
@@ -0,0 +1,2 @@
class ChessState():
    pass
101 changes: 95 additions & 6 deletions mcts.py
100644 → 100755
@@ -3,6 +3,7 @@
import time
import math
import random
import heapq


def randomPolicy(state):
@@ -25,6 +26,13 @@ def __init__(self, state, parent):
        self.totalReward = 0
        self.children = {}

    def __str__(self):
        s = []
        s.append("totalReward: %s" % (self.totalReward))
        s.append("numVisits: %d" % (self.numVisits))
        s.append("isTerminal: %s" % (self.isTerminal))
        s.append("possibleActions: %s" % (self.children.keys()))
        return "%s: {%s}" % (self.__class__.__name__, ', '.join(s))

class mcts():
    def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2),
@@ -46,7 +54,7 @@ def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 /
        self.explorationConstant = explorationConstant
        self.rollout = rolloutPolicy

    def search(self, initialState, needDetails=False):
        self.root = treeNode(initialState, None)

        if self.limitType == 'time':
@@ -58,9 +66,16 @@ def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 /
                self.executeRound()

        bestChild = self.getBestChild(self.root, 0)
        action = next(action for action, node in self.root.children.items() if node is bestChild)
        if needDetails:
            return {"action": action, "expectedReward": bestChild.totalReward / bestChild.numVisits}
        else:
            return action

    def executeRound(self):
        """
        execute a selection-expansion-simulation-backpropagation round
        """
        node = self.selectNode(self.root)
        reward = self.rollout(node.state)
        self.backpropogate(node, reward)
@@ -104,7 +119,81 @@ def getBestChild(self, node, explorationValue):
                bestNodes.append(child)
        return random.choice(bestNodes)

    def getAction(self, root, bestChild):
        for action, node in root.children.items():
            if node is bestChild:
                return action

def trivialPolicy(state):
    return state.getReward()

class abpruning():
    def __init__(self, deep, rolloutPolicy=trivialPolicy, n_killer=2, gameinf=65535):
        """
        deep: how many layers to search, must be >= 1
        gameinf: an upper bound of getReward() return values, used as "inf" in the algorithm
        """
        self.deep = deep
        self.rollout = rolloutPolicy
        self.n_killer = n_killer
        self.gameinf = gameinf
        self.counter = 0

    def search(self, initialState, needDetails=False):
        children = {}
        killers = {}  # best actions of sibling branches, for the killer heuristic optimization
        for action in initialState.getPossibleActions():
            val, ks = self.alphabeta(initialState.takeAction(action), self.deep-1, -1*self.gameinf, self.gameinf, killers=killers)
            children[action] = val
            for k in ks:
                killers[k] = killers.setdefault(k, 0) + 1
        self.children = children

        CurrentPlayer = initialState.getCurrentPlayer()
        if CurrentPlayer == 1:
            bestAction = max(self.children.items(), key=lambda x: x[1])
        elif CurrentPlayer == -1:
            bestAction = min(self.children.items(), key=lambda x: x[1])
        else:
            raise Exception("getCurrentPlayer() should return 1 or -1 rather than %s" % (CurrentPlayer,))

        if needDetails:
            return {"action": bestAction[0], "expectedReward": bestAction[1]}
        else:
            return bestAction[0]

    def alphabeta(self, node, deep, alpha, beta, killers={}):
        if deep == 0 or node.isTerminal():
            self.counter += 1
            return self.rollout(node), []

        CurrentPlayer = node.getCurrentPlayer()
        actions = node.getPossibleActions()
        actions.sort(key=lambda x: killers.get(x, -1), reverse=True)
        subkillers = {}
        bestactions = []
        if CurrentPlayer == 1:
            maxeval = -1*self.gameinf
            for action in actions:
                val, ks = self.alphabeta(node.takeAction(action), deep-1, alpha, beta, killers=subkillers)
                maxeval = max(val, maxeval)
                alpha = max(val, alpha)
                bestactions.append((action, val))
                if beta <= alpha:
                    break
                for k in ks:
                    subkillers[k] = subkillers.setdefault(k, 0) + 1
            bestactions.sort(key=lambda x: x[1], reverse=True)
            bestactions = [i[0] for i in bestactions[0:min(len(bestactions), self.n_killer)]]
            return maxeval, bestactions
        elif CurrentPlayer == -1:
            mineval = self.gameinf
            for action in actions:
                val, ks = self.alphabeta(node.takeAction(action), deep-1, alpha, beta, killers=subkillers)
                mineval = min(val, mineval)
                beta = min(val, beta)
                bestactions.append((action, val))
                if beta <= alpha:
                    break
                for k in ks:
                    subkillers[k] = subkillers.setdefault(k, 0) + 1
            bestactions.sort(key=lambda x: x[1])
            bestactions = [i[0] for i in bestactions[0:min(len(bestactions), self.n_killer)]]
            return mineval, bestactions
        else:
            raise Exception("getCurrentPlayer() should return 1 or -1 rather than %s" % (CurrentPlayer,))
38 changes: 32 additions & 6 deletions naughtsandcrosses.py
100644 → 100755
@@ -73,9 +73,35 @@ def __eq__(self, other):
    def __hash__(self):
        return hash((self.x, self.y, self.player))


def test_01():
    initialState = NaughtsAndCrossesState()
    initialState.currentPlayer = 1

    # "1" to move, wins in 1 step
    # deep=1,2: {(0, 1): False, (0, 2): 1.0, (1, 2): False, (2, 1): False, (2, 2): False}
    # deep>=3 : {(0, 1): 1.0, (0, 2): 1.0, (1, 2): False, (2, 1): 1.0, (2, 2): 1.0}
    # searcher.counter| no pruning| vanilla alphabeta| n_killer=2
    # deep=2          | 17        | 17               | 17
    # deep=3          | 49        | 31               | 28
    #initialState.board = [[-1, 0, 0], [-1, 1, 0], [1, 0, 0]]

    # "1" to move, wins in 3 steps
    # deep=1,2: {(0, 0): False, (0, 1): False, (1, 2): False, (2, 1): False, (2, 2): False}
    # deep>=3 : {(0, 0): False, (0, 1): False, (1, 2): False, (2, 1): 1.0, (2, 2): 1.0}
    # searcher.counter| no pruning| vanilla alphabeta| n_killer=2
    # deep=2          | 20        | 20               | 20
    # deep=3          | 60        | 43               | 37
    initialState.board = [[0, 0, -1], [-1, 1, 0], [1, 0, 0]]

    from mcts import abpruning
    searcher = abpruning(deep=3, n_killer=2)
    action = searcher.search(initialState, needDetails=True)
    print(searcher.children)
    print(searcher.counter)

if __name__ == "__main__":
    #initialState = NaughtsAndCrossesState()
    #searcher = mcts(timeLimit=1000)
    #action = searcher.search(initialState=initialState)
    #print(action)
    test_01()