-
Notifications
You must be signed in to change notification settings - Fork 73
/
naughtsandcrosses.py
executable file
·81 lines (65 loc) · 2.54 KB
/
naughtsandcrosses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from __future__ import division
from copy import deepcopy
from mcts import mcts
from functools import reduce
import operator
class NaughtsAndCrossesState():
    """Game state for naughts and crosses (tic-tac-toe) on a square board.

    ``board`` is a list of rows; each cell holds 0 when empty, or the mark of
    the player who claimed it (1 or -1). ``currentPlayer`` is the mark of the
    player whose turn it is; it starts at 1 and flips sign after every move.

    States are treated as values: ``takeAction`` never mutates the state it
    is called on, it returns a modified deep copy instead.
    """

    def __init__(self):
        """Create an empty 3x3 board with player 1 to move."""
        self.board = [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
        self.currentPlayer = 1

    def getCurrentPlayer(self):
        """Return the mark (1 or -1) of the player whose turn it is."""
        return self.currentPlayer

    def getPossibleActions(self):
        """Return one ``Action`` for the current player per empty cell.

        Cells are visited row by row, left to right, so the returned list is
        in row-major order.
        """
        return [
            Action(player=self.currentPlayer, x=i, y=j)
            for i, row in enumerate(self.board)
            for j, cell in enumerate(row)
            if cell == 0
        ]

    def takeAction(self, action):
        """Return the state that results from applying ``action``.

        ``action`` must expose ``x`` (row index), ``y`` (column index) and
        ``player`` (the mark to place). The target cell is overwritten
        without checking that it is empty. ``self`` is left unchanged.
        """
        newState = deepcopy(self)
        newState.board[action.x][action.y] = action.player
        newState.currentPlayer = self.currentPlayer * -1
        return newState

    def _lines(self):
        """Return every winnable line: rows, columns, then both diagonals."""
        size = len(self.board)
        lines = [list(row) for row in self.board]
        lines.extend(list(column) for column in zip(*self.board))
        lines.append([self.board[i][i] for i in range(size)])
        lines.append([self.board[i][size - i - 1] for i in range(size)])
        return lines

    def _winningSum(self):
        """Return the sum of the first fully-claimed line, or 0 if none.

        A line is won when all of its cells carry the same non-zero mark, in
        which case its sum is +size or -size. Lines are checked in the order
        produced by ``_lines``.
        """
        size = len(self.board)
        for line in self._lines():
            total = sum(line)
            if abs(total) == size:
                return total
        return 0

    def isTerminal(self):
        """Return True if the game is over (a line is won or the board is full)."""
        if self._winningSum() != 0:
            return True
        # No winner: the game only ends once no empty (0) cell remains.
        return all(cell != 0 for row in self.board for cell in row)

    def getReward(self):
        """Return 1.0 if player 1 has won, -1.0 if player -1 has won, else 0.0."""
        # Dividing the winning sum (+/-size, or 0) by the board size
        # normalises it to +/-1 (or 0) and always yields a float under true
        # division.
        return self._winningSum() / len(self.board)
class Action():
    """A single move: ``player`` placing their mark at row ``x``, column ``y``.

    Instances compare equal, and hash equally, when they are of the same
    class and agree on ``x``, ``y`` and ``player``, so they can be used as
    dictionary keys or set members.
    """

    def __init__(self, player, x, y):
        self.x = x
        self.y = y
        self.player = player

    def __str__(self):
        # Same text as str((x, y)): the coordinates only, without the player.
        return "({0!r}, {1!r})".format(self.x, self.y)

    def __repr__(self):
        # The repr deliberately mirrors the string form.
        return self.__str__()

    def __eq__(self, other):
        # Objects of a different class are never equal to an Action.
        if self.__class__ != other.__class__:
            return False
        return all(
            getattr(self, name) == getattr(other, name)
            for name in ("x", "y", "player")
        )

    def __hash__(self):
        identity = (self.x, self.y, self.player)
        return hash(identity)
if __name__ == "__main__":
    # Demo: run a Monte Carlo tree search from the empty board and print the
    # move it selects for the first player.
    startingState = NaughtsAndCrossesState()
    # NOTE(review): timeLimit is handed straight to mcts; presumably it is in
    # milliseconds — confirm against the mcts package.
    treeSearch = mcts(timeLimit=1000)
    bestAction = treeSearch.search(initialState=startingState)
    print(bestAction)