diff --git a/cmd/rosaline/interfaces/uci.go b/cmd/rosaline/interfaces/uci.go index d6baebf..ac6a062 100644 --- a/cmd/rosaline/interfaces/uci.go +++ b/cmd/rosaline/interfaces/uci.go @@ -68,7 +68,7 @@ loop: bestScore = results.Score } - fmt.Printf("info depth %d score cp %d nodes %d nps %f time %d\n", depth, results.Score, results.Nodes, results.NPS, results.Time.Milliseconds()) + fmt.Printf("info depth %d score cp %d nodes %d nps %f time %d tbhits %d\n", depth, results.Score, results.Nodes, results.NPS, results.Time.Milliseconds(), results.Hits) } position.MakeMove(bestMove) diff --git a/internal/search/negamax.go b/internal/search/negamax.go index 9a3699a..a60c9c0 100644 --- a/internal/search/negamax.go +++ b/internal/search/negamax.go @@ -25,6 +25,8 @@ type SearchResults struct { Nodes int Time time.Duration NPS float64 + Hits int + Misses int } type NegamaxSearcher struct { @@ -34,6 +36,8 @@ type NegamaxSearcher struct { killerMoves map[chess.Color][]chess.Move killerMoveIndex int + ttable TranspositionTable + stop bool nodes int @@ -45,6 +49,7 @@ func NewNegamaxSearcher(evaluator evaluation.Evaluator) NegamaxSearcher { drawTable: newDrawTable(), killerMoves: make(map[chess.Color][]chess.Move), killerMoveIndex: 0, + ttable: NewTranspositionTable(), nodes: 0, } } @@ -53,6 +58,8 @@ func (s NegamaxSearcher) Search(position *chess.Position, depth int) SearchResul s.nodes = 0 s.stop = false + s.ttable.ResetCounters() + start := time.Now() bestMove := chess.Move{} @@ -89,6 +96,8 @@ func (s NegamaxSearcher) Search(position *chess.Position, depth int) SearchResul Nodes: s.nodes, Time: elapsed, NPS: nps, + Hits: s.ttable.hits, + Misses: s.ttable.misses, } } @@ -139,6 +148,30 @@ func (s *NegamaxSearcher) doSearch(position *chess.Position, alpha int, beta int } } + s.nodes++ + + entry, ok := s.ttable.Get(position.Hash()) + if ok { + if entry.Depth >= depth { + switch entry.Type { + case ExactNode: + return entry.Score + case UpperNode: + if entry.Score < alpha { + return alpha + } + + break + case LowerNode: + if entry.Score > beta { + return beta + } + + break + } + } + } + // null move pruning doNullPruning := !inCheck && !pvNode if doNullPruning && depth >= 3 && ply != 0 { @@ -155,12 +188,14 @@ func (s *NegamaxSearcher) doSearch(position *chess.Position, alpha int, beta int } } - s.nodes++ - slices.SortFunc(moves, func(m1, m2 chess.Move) int { return cmp.Compare(s.scoreMove(position, m1), s.scoreMove(position, m2)) }) + bestMove := chess.NullMove + bestScore := math.MinInt + nodeType := UpperNode + for _, move := range moves { s.drawTable.Push(position.Hash()) @@ -170,7 +205,14 @@ func (s *NegamaxSearcher) doSearch(position *chess.Position, alpha int, beta int s.drawTable.Pop() + if score > bestScore { + bestScore = score + bestMove = move + } + if score >= beta { + nodeType = LowerNode + if !move.HasFlag(chess.CaputureMoveFlag) { turn := position.Turn() length := len(s.killerMoves[turn]) @@ -190,15 +232,21 @@ func (s *NegamaxSearcher) doSearch(position *chess.Position, alpha int, beta int } } - return beta + break } if score > alpha { alpha = score + nodeType = ExactNode } } - return alpha + if !s.stop { + entry := NewTableEntry(nodeType, bestMove, bestScore, depth, position.Plies()) + s.ttable.Insert(position.Hash(), entry) + } + + return bestScore } func (s NegamaxSearcher) quiescence(position *chess.Position, alpha int, beta int) int { @@ -246,4 +294,5 @@ func (s *NegamaxSearcher) ClearPreviousSearch() { func (s *NegamaxSearcher) Reset() { s.drawTable.Clear() s.ClearPreviousSearch() + s.ttable.Clear() } diff --git a/internal/search/transposition.go b/internal/search/transposition.go new file mode 100644 index 0000000..e1e33b0 --- /dev/null +++ b/internal/search/transposition.go @@ -0,0 +1,148 @@ +package search + +import ( + "fmt" + "rosaline/internal/chess" + "unsafe" +) + +type NodeType uint8 + +const ( + ExactNode NodeType = iota + UpperNode + LowerNode +) + +func (t NodeType) String() string { + switch t { + case ExactNode: + return "Exact" + case UpperNode: + return "Upper" + case LowerNode: + return "Lower" + } + + panic(fmt.Sprintf("unknown NodeType '%d' encountered", t)) +} + +type TableEntry struct { + Type NodeType + Move chess.Move + Score int + Depth int + Age int +} + +var emptyEntry = TableEntry{} + +const ( + entrySize = int(unsafe.Sizeof(emptyEntry)) + + kb = 1024 + mb = kb * kb + maxTableSize = (64 * mb) / entrySize +) + +// NewTableEntry creates a new TableEntry. +func NewTableEntry(nodeType NodeType, move chess.Move, score, depth int, age int) TableEntry { + return TableEntry{ + Type: nodeType, + Move: move, + Score: score, + Depth: depth, + Age: age, + } +} + +func (e TableEntry) String() string { + return fmt.Sprintf("", e.Type, e.Move, e.Score, e.Depth) +} + +type TranspositionTable struct { + table map[uint64]TableEntry + hits int + misses int +} + +// NewTranspositionTable creates a new TranspositionTable. +func NewTranspositionTable() TranspositionTable { + return TranspositionTable{ + table: make(map[uint64]TableEntry), + hits: 0, + misses: 0, + } +} + +// Insert adds a new entry to the table. +func (t *TranspositionTable) Insert(hash uint64, entry TableEntry) { + if len(t.table) >= maxTableSize { // TODO: look into replacing old positions instead of clearing table + clear(t.table) + } + + t.table[hash] = entry +} + +// Remove removes an entry from the table. +func (t *TranspositionTable) Remove(hash uint64) { + delete(t.table, hash) +} + +// Get retreives the entry that corresponds to the given hash. +func (t *TranspositionTable) Get(hash uint64) (TableEntry, bool) { + value, ok := t.table[hash] + + if ok { + t.hits++ + } else { + t.misses++ + } + + return value, ok +} + +// Size returns the size of the table. +func (t TranspositionTable) Size() int { + return len(t.table) +} + +// Hits returns the number times a position has been found in the table. +func (t TranspositionTable) Hits() int { + return t.hits +} + +// Misses returns the number times a position wa not found in the table. +func (t TranspositionTable) Misses() int { + return t.misses +} + +func (t TranspositionTable) List() { + entries := map[NodeType]int{ + ExactNode: 0, + UpperNode: 0, + LowerNode: 0, + } + + for hash, entry := range t.table { + fmt.Printf("%d: %s\n", hash, entry) + entries[entry.Type]++ + } + + for key, value := range entries { + fmt.Printf("%s: %d\n", key, value) + } + fmt.Println("# of entries:", len(t.table)) +} + +// ResetCounters resets the hits and misses counters. +func (t *TranspositionTable) ResetCounters() { + t.hits = 0 + t.misses = 0 +} + +// Clear clears the table and resets the hits and misses counters. +func (t *TranspositionTable) Clear() { + clear(t.table) + t.ResetCounters() +}