-
Notifications
You must be signed in to change notification settings - Fork 72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Lucasborboleta #14
Open
LucasBorboleta
wants to merge
12
commits into
pbsinclair42:master
Choose a base branch
from
LucasBorboleta:lucasborboleta
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Lucasborboleta #14
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
c37fcd2
Improved example + improved expand method
LucasBorboleta 689d895
Impove game play example with MCTS statistics on the selected action
LucasBorboleta 256af87
Update naughtsandcrosses.py
LucasBorboleta 8c3d805
Contribution prepared
LucasBorboleta b297d2b
Create kinarow.py
LucasBorboleta cfb0753
Update kinarow.py
LucasBorboleta 6a3483d
Update kinarow.py
LucasBorboleta e651677
Polishing new example renamed connectmnk
LucasBorboleta f48311a
Preparing pull request
LucasBorboleta 50bea15
Delete .gitignore
LucasBorboleta 7f8e159
Update naughtsandcrosses.py
LucasBorboleta 4e62716
Update naughtsandcrosses.py
LucasBorboleta File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,255 @@ | ||
from __future__ import division | ||
|
||
import copy | ||
from mcts import mcts | ||
import random | ||
|
||
|
||
class ConnectMNKState: | ||
"""ConnectMNKState models a Connect(m,n,k,1,1) game that generalizes | ||
the famous "Connect Four" itself equal to the Connect(7,6,4,1,1) game. | ||
|
||
Background from wikipedia: | ||
Connect(m,n,k,p,q) games are another generalization of gomoku to a board | ||
with m×n intersections, k in a row needed to win, p stones for each player | ||
to place, and q stones for the first player to place for the first move | ||
only. Each player may play only at the lowest unoccupied place in a column. | ||
In particular, Connect(m,n,6,2,1) is called Connect6. | ||
""" | ||
|
||
playerNames = {1:'O', -1:'X'} | ||
|
||
def __init__(self, mColumns=7, nRows=6, kConnections=4): | ||
self.mColumns = mColumns | ||
self.nRows = nRows | ||
self.kConnections = kConnections | ||
self.board = [ [0 for _ in range(self.mColumns)] for _ in range(self.nRows)] | ||
self.currentPlayer = max(ConnectMNKState.playerNames.keys()) | ||
self.isTerminated = None | ||
self.reward = None | ||
self.possibleActions = None | ||
self.winingPattern = None | ||
|
||
def show(self): | ||
rowText = "" | ||
for columnIndex in range(self.mColumns): | ||
rowText += f" {columnIndex % 10} " | ||
print(rowText) | ||
|
||
for rowIndex in reversed(range(self.nRows)): | ||
rowText = "" | ||
for x in self.board[rowIndex]: | ||
if x in self.playerNames: | ||
rowText += f" {self.playerNames[x]} " | ||
else: | ||
rowText += " . " | ||
rowText += f" {rowIndex % 10} " | ||
print(rowText) | ||
|
||
def getCurrentPlayer(self): | ||
return self.currentPlayer | ||
|
||
def getPossibleActions(self): | ||
if self.possibleActions is None: | ||
self.possibleActions = [] | ||
for columnIndex in range(self.mColumns): | ||
for rowIndex in range(self.nRows): | ||
if self.board[rowIndex][columnIndex] == 0: | ||
action = Action(player=self.currentPlayer, | ||
columnIndex=columnIndex, | ||
rowIndex=rowIndex) | ||
self.possibleActions.append(action) | ||
break | ||
return self.possibleActions | ||
|
||
def takeAction(self, action): | ||
newState = copy.copy(self) | ||
newState.board = copy.deepcopy(newState.board) | ||
newState.board[action.rowIndex][action.columnIndex] = action.player | ||
newState.currentPlayer = self.currentPlayer * -1 | ||
newState.isTerminated = None | ||
newState.possibleActions = None | ||
newState.winingPattern = None | ||
return newState | ||
|
||
def isTerminal(self): | ||
if self.isTerminated is None: | ||
self.isTerminated = False | ||
for rowIndex in range(self.nRows): | ||
line = self.board[rowIndex] | ||
lineReward = self.__getLineReward(line) | ||
if lineReward != 0: | ||
self.isTerminated = True | ||
self.reward = lineReward | ||
self.winingPattern = "k-in-row" | ||
break | ||
|
||
if not self.isTerminated: | ||
for columnIndex in range(self.mColumns): | ||
line = [] | ||
for rowIndex in range(self.nRows): | ||
line.append(self.board[rowIndex][columnIndex]) | ||
lineReward = self.__getLineReward(line) | ||
if lineReward != 0: | ||
self.isTerminated = True | ||
self.reward = lineReward | ||
self.winingPattern = "k-in-column" | ||
break | ||
|
||
if not self.isTerminated: | ||
# diagonal: rowIndex = columnIndex + parameter | ||
for parameter in range(1 - self.mColumns, self.nRows): | ||
line = [] | ||
for columnIndex in range(self.mColumns): | ||
rowIndex = columnIndex + parameter | ||
if 0 <= rowIndex < self.nRows: | ||
line.append(self.board[rowIndex][columnIndex]) | ||
lineReward = self.__getLineReward(line) | ||
if lineReward != 0: | ||
self.isTerminated = True | ||
self.reward = lineReward | ||
self.winingPattern = "k-in-diagonal" | ||
break | ||
|
||
if not self.isTerminated: | ||
# antidiagonal: rowIndex = - columnIndex + parameter | ||
for parameter in range(0, self.mColumns + self.nRows): | ||
line = [] | ||
for columnIndex in range(self.mColumns): | ||
rowIndex = -columnIndex + parameter | ||
if 0 <= rowIndex < self.nRows: | ||
line.append(self.board[rowIndex][columnIndex]) | ||
lineReward = self.__getLineReward(line) | ||
if lineReward != 0: | ||
self.isTerminated = True | ||
self.reward = lineReward | ||
self.winingPattern = "k-in-antidiagonal" | ||
break | ||
|
||
if not self.isTerminated and len(self.getPossibleActions()) == 0: | ||
self.isTerminated = True | ||
self.reward = 0 | ||
|
||
return self.isTerminated | ||
|
||
def __getLineReward(self, line): | ||
lineReward = 0 | ||
if len(line) >= self.kConnections: | ||
for player in ConnectMNKState.playerNames.keys(): | ||
playerLine = [x == player for x in line] | ||
playerConnections = 0 | ||
for x in playerLine: | ||
if x: | ||
playerConnections += 1 | ||
if playerConnections == self.kConnections: | ||
lineReward = player | ||
break | ||
else: | ||
playerConnections = 0 | ||
if lineReward != 0: | ||
break | ||
return lineReward | ||
|
||
def getReward(self): | ||
assert self.isTerminal() | ||
assert self.reward is not None | ||
return self.reward | ||
|
||
|
||
class Action(): | ||
def __init__(self, player, columnIndex, rowIndex): | ||
self.player = player | ||
self.rowIndex = rowIndex | ||
self.columnIndex = columnIndex | ||
|
||
def __str__(self): | ||
return str((self.columnIndex, self.rowIndex)) | ||
|
||
def __repr__(self): | ||
return str(self) | ||
|
||
def __eq__(self, other): | ||
return self.__class__ == (other.__class__ and | ||
self.player == other.player and | ||
self.columnIndex == other.columnIndex and | ||
self.rowIndex == other.rowIndex) | ||
|
||
def __hash__(self): | ||
return hash((self.columnIndex, self.rowIndex, self.player)) | ||
|
||
|
||
def extractStatistics(searcher, action): | ||
statistics = {} | ||
statistics['rootNumVisits'] = searcher.root.numVisits | ||
statistics['rootTotalReward'] = searcher.root.totalReward | ||
statistics['actionNumVisits'] = searcher.root.children[action].numVisits | ||
statistics['actionTotalReward'] = searcher.root.children[action].totalReward | ||
return statistics | ||
|
||
|
||
def main(): | ||
"""Run a full match between two MCTS searchers, possibly with different | ||
parametrization, playing a Connect(m,n,k) game. | ||
|
||
Extraction of MCTS statistics is examplified. | ||
|
||
The game parameters (m,n,k) are randomly chosen. | ||
""" | ||
|
||
searchers = {} | ||
searchers["mcts-1500ms"] = mcts(timeLimit=1_500) | ||
searchers["mcts-1000ms"] = mcts(timeLimit=1_000) | ||
searchers["mcts-500ms"] = mcts(timeLimit=500) | ||
searchers["mcts-250ms"] = mcts(timeLimit=250) | ||
|
||
playerNames = ConnectMNKState.playerNames | ||
|
||
playerSearcherNames = {} | ||
for player in sorted(playerNames.keys()): | ||
playerSearcherNames[player] = random.choice(sorted(searchers.keys())) | ||
|
||
runnableGames = list() | ||
runnableGames.append((3, 3, 3)) | ||
runnableGames.append((7, 6, 4)) | ||
runnableGames.append((8, 7, 5)) | ||
runnableGames.append((9, 8, 6)) | ||
(m, n, k) = random.choice(runnableGames) | ||
currentState = ConnectMNKState(mColumns=m, nRows=n, kConnections=k) | ||
|
||
turn = 0 | ||
currentState.show() | ||
while not currentState.isTerminal(): | ||
turn += 1 | ||
player = currentState.getCurrentPlayer() | ||
action_count = len(currentState.getPossibleActions()) | ||
|
||
searcherName = playerSearcherNames[player] | ||
searcher = searchers[searcherName] | ||
|
||
action = searcher.search(initialState=currentState) | ||
statistics = extractStatistics(searcher, action) | ||
currentState = currentState.takeAction(action) | ||
|
||
print(f"at turn {turn} player {playerNames[player]}={player} ({searcherName})" + | ||
f" takes action (column, row)={action} amongst {action_count} possibilities") | ||
|
||
print("mcts statitics:" + | ||
f" chosen action= {statistics['actionTotalReward']} total reward" + | ||
f" over {statistics['actionNumVisits']} visits /" | ||
f" all explored actions= {statistics['rootTotalReward']} total reward" + | ||
f" over {statistics['rootNumVisits']} visits") | ||
|
||
print('-'*120) | ||
currentState.show() | ||
|
||
print('-'*120) | ||
if currentState.getReward() == 0: | ||
print(f"Connect(m={m},n={n},k={k}) game terminates; nobody wins") | ||
else: | ||
print(f"Connect(m={m},n={n},k={k}) game terminates;" + | ||
f" player {playerNames[player]}={player} ({searcherName}) wins" + | ||
f" with pattern {currentState.winingPattern}") | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,7 +27,7 @@ def __init__(self, state, parent): | |
|
||
|
||
class mcts(): | ||
def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=1 / math.sqrt(2), | ||
def __init__(self, timeLimit=None, iterationLimit=None, explorationConstant=math.sqrt(2), | ||
rolloutPolicy=randomPolicy): | ||
if timeLimit != None: | ||
if iterationLimit != None: | ||
|
@@ -75,6 +75,7 @@ def selectNode(self, node): | |
|
||
def expand(self, node): | ||
actions = node.state.getPossibleActions() | ||
random.shuffle(actions) | ||
for action in actions: | ||
if action not in node.children: | ||
newNode = treeNode(node.state.takeAction(action), node) | ||
|
@@ -95,8 +96,8 @@ def getBestChild(self, node, explorationValue): | |
bestValue = float("-inf") | ||
bestNodes = [] | ||
for child in node.children.values(): | ||
nodeValue = node.state.getCurrentPlayer() * child.totalReward / child.numVisits + explorationValue * math.sqrt( | ||
2 * math.log(node.numVisits) / child.numVisits) | ||
nodeValue = (node.state.getCurrentPlayer() * child.totalReward / child.numVisits + | ||
explorationValue * math.sqrt(math.log(node.numVisits) / child.numVisits)) | ||
Comment on lines
+99
to
+100
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good change 👍 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks ! |
||
if nodeValue > bestValue: | ||
bestValue = nodeValue | ||
bestNodes = [child] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the purpose of this shuffle?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
When I experimented the full game of NaughtsAndCrosses extended to grid of large size, up to 10, I observed the following : allocating 1 second was not sufficient for the MCTS agent for winning against a random agent ; the game usually terminates as a draw. Printing the state of the game at each turn, I observed that the selected actions by MCTS on its successive turns are all located in the first rows. So MCTS behavior is somehow predictable when its exploration of the tree is not sufficient for finding at least a few winning paths. That is the reason for the added shuffle in the expand method. It could be interesting to quantify the gain of many (1000 ?) games : does the added shuffle increase or not the winned games of MCTS, when the tree is quite large ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Gotcha. I would argue that that deterministic behaviour on large search spaces is a symptom suggesting the search time needs to be significantly increased for the search to be effective, but I don't see any harm in putting this in here to avoid the ordering bias, aye.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
On my game project, nammed JERSI-CERTU, to address the above mentionned issue, but still using some official release of your package MCTS, I have just shuffle the possible actions on the client side. What is your opinion on the two solutions ? If you want your code to be simple, the suffle by the client is better.