Skip to content

Commit

Permalink
search: add transposition table
Browse files Browse the repository at this point in the history
  • Loading branch information
e0ff committed Nov 26, 2023
1 parent 87e80fb commit e32a1ef
Show file tree
Hide file tree
Showing 3 changed files with 202 additions and 5 deletions.
2 changes: 1 addition & 1 deletion cmd/rosaline/interfaces/uci.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ loop:
bestScore = results.Score
}

fmt.Printf("info depth %d score cp %d nodes %d nps %f time %d\n", depth, results.Score, results.Nodes, results.NPS, results.Time.Milliseconds())
fmt.Printf("info depth %d score cp %d nodes %d nps %f time %d tbhits %d\n", depth, results.Score, results.Nodes, results.NPS, results.Time.Milliseconds(), results.Hits)
}

position.MakeMove(bestMove)
Expand Down
57 changes: 53 additions & 4 deletions internal/search/negamax.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ type SearchResults struct {
Nodes int
Time time.Duration
NPS float64
Hits int
Misses int
}

type NegamaxSearcher struct {
Expand All @@ -34,6 +36,8 @@ type NegamaxSearcher struct {
killerMoves map[chess.Color][]chess.Move
killerMoveIndex int

ttable TranspositionTable

stop bool

nodes int
Expand All @@ -45,6 +49,7 @@ func NewNegamaxSearcher(evaluator evaluation.Evaluator) NegamaxSearcher {
drawTable: newDrawTable(),
killerMoves: make(map[chess.Color][]chess.Move),
killerMoveIndex: 0,
ttable: NewTranspositionTable(),
nodes: 0,
}
}
Expand All @@ -53,6 +58,8 @@ func (s NegamaxSearcher) Search(position *chess.Position, depth int) SearchResul
s.nodes = 0
s.stop = false

s.ttable.ResetCounters()

start := time.Now()

bestMove := chess.Move{}
Expand Down Expand Up @@ -89,6 +96,8 @@ func (s NegamaxSearcher) Search(position *chess.Position, depth int) SearchResul
Nodes: s.nodes,
Time: elapsed,
NPS: nps,
Hits: s.ttable.hits,
Misses: s.ttable.misses,
}
}

Expand Down Expand Up @@ -139,6 +148,30 @@ func (s *NegamaxSearcher) doSearch(position *chess.Position, alpha int, beta int
}
}

s.nodes++

entry, ok := s.ttable.Get(position.Hash())
if ok {
if entry.Depth >= depth {
switch entry.Type {
case ExactNode:
return entry.Score
case UpperNode:
if entry.Score < alpha {
return alpha
}

break
case LowerNode:
if entry.Score > beta {
return beta
}

break
}
}
}

// null move pruning
doNullPruning := !inCheck && !pvNode
if doNullPruning && depth >= 3 && ply != 0 {
Expand All @@ -155,12 +188,14 @@ func (s *NegamaxSearcher) doSearch(position *chess.Position, alpha int, beta int
}
}

s.nodes++

slices.SortFunc(moves, func(m1, m2 chess.Move) int {
return cmp.Compare(s.scoreMove(position, m1), s.scoreMove(position, m2))
})

bestMove := chess.NullMove
bestScore := math.MinInt
nodeType := UpperNode

for _, move := range moves {
s.drawTable.Push(position.Hash())

Expand All @@ -170,7 +205,14 @@ func (s *NegamaxSearcher) doSearch(position *chess.Position, alpha int, beta int

s.drawTable.Pop()

if score > bestScore {
bestScore = score
bestMove = move
}

if score >= beta {
nodeType = LowerNode

if !move.HasFlag(chess.CaputureMoveFlag) {
turn := position.Turn()
length := len(s.killerMoves[turn])
Expand All @@ -190,15 +232,21 @@ func (s *NegamaxSearcher) doSearch(position *chess.Position, alpha int, beta int
}
}

return beta
break
}

if score > alpha {
alpha = score
nodeType = ExactNode
}
}

return alpha
if !s.stop {
entry := NewTableEntry(nodeType, bestMove, bestScore, depth, position.Plies())
s.ttable.Insert(position.Hash(), entry)
}

return bestScore
}

func (s NegamaxSearcher) quiescence(position *chess.Position, alpha int, beta int) int {
Expand Down Expand Up @@ -246,4 +294,5 @@ func (s *NegamaxSearcher) ClearPreviousSearch() {
func (s *NegamaxSearcher) Reset() {
s.drawTable.Clear()
s.ClearPreviousSearch()
s.ttable.Clear()
}
148 changes: 148 additions & 0 deletions internal/search/transposition.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package search

import (
"fmt"
"rosaline/internal/chess"
"unsafe"
)

type NodeType uint8

const (
ExactNode NodeType = iota
UpperNode
LowerNode
)

func (t NodeType) String() string {
switch t {
case ExactNode:
return "Exact"
case UpperNode:
return "Upper"
case LowerNode:
return "Lower"
}

panic(fmt.Sprintf("unknown NodeType '%d' encountered", t))
}

type TableEntry struct {
Type NodeType
Move chess.Move
Score int
Depth int
Age int
}

var emptyEntry = TableEntry{}

const (
entrySize = int(unsafe.Sizeof(emptyEntry))

kb = 1024
mb = kb * kb
maxTableSize = (64 * mb) / entrySize
)

// NewTableEntry creates a new TableEntry.
func NewTableEntry(nodeType NodeType, move chess.Move, score, depth int, age int) TableEntry {
return TableEntry{
Type: nodeType,
Move: move,
Score: score,
Depth: depth,
Age: age,
}
}

func (e TableEntry) String() string {
return fmt.Sprintf("<Entry: type: %s move: %s score: %d depth: %d>", e.Type, e.Move, e.Score, e.Depth)
}

type TranspositionTable struct {
table map[uint64]TableEntry
hits int
misses int
}

// NewTranspositionTable creates a new TranspositionTable.
func NewTranspositionTable() TranspositionTable {
return TranspositionTable{
table: make(map[uint64]TableEntry),
hits: 0,
misses: 0,
}
}

// Insert adds a new entry to the table.
func (t *TranspositionTable) Insert(hash uint64, entry TableEntry) {
if len(t.table) >= maxTableSize { // TODO: look into replacing old positions instead of clearing table
clear(t.table)
}

t.table[hash] = entry
}

// Remove removes an entry from the table.
func (t *TranspositionTable) Remove(hash uint64) {
delete(t.table, hash)
}

// Get retreives the entry that corresponds to the given hash.
func (t *TranspositionTable) Get(hash uint64) (TableEntry, bool) {
value, ok := t.table[hash]

if ok {
t.hits++
} else {
t.misses++
}

return value, ok
}

// Size returns the size of the table.
func (t TranspositionTable) Size() int {
return len(t.table)
}

// Hits returns the number times a position has been found in the table.
func (t TranspositionTable) Hits() int {
return t.hits
}

// Misses returns the number times a position wa not found in the table.
func (t TranspositionTable) Misses() int {
return t.misses
}

func (t TranspositionTable) List() {
entries := map[NodeType]int{
ExactNode: 0,
UpperNode: 0,
LowerNode: 0,
}

for hash, entry := range t.table {
fmt.Printf("%d: %s\n", hash, entry)
entries[entry.Type]++
}

for key, value := range entries {
fmt.Printf("%s: %d\n", key, value)
}
fmt.Println("# of entries:", len(t.table))
}

// ResetCounters resets the hits and misses counters.
func (t *TranspositionTable) ResetCounters() {
t.hits = 0
t.misses = 0
}

// Clear clears the table and resets the hits and misses counters.
func (t *TranspositionTable) Clear() {
clear(t.table)
t.ResetCounters()
}

0 comments on commit e32a1ef

Please sign in to comment.