From 8dcfbd1a065fe4b9c51e9ef601af546a6597dbca Mon Sep 17 00:00:00 2001 From: Danny Hammer Date: Sun, 18 Aug 2024 22:36:26 -0600 Subject: [PATCH] feat: implemented TT for move ordering --- brogle/src/engine.rs | 27 +++-- brogle/src/lib.rs | 2 + brogle/src/search/searcher.rs | 57 +++++++-- brogle/src/search/transposition_table.rs | 142 +++++++++++++++++++++++ brogle_core/src/position.rs | 4 +- brogle_core/src/zobrist.rs | 7 +- 6 files changed, 215 insertions(+), 24 deletions(-) create mode 100644 brogle/src/search/transposition_table.rs diff --git a/brogle/src/engine.rs b/brogle/src/engine.rs index 46ce72b..8382de0 100644 --- a/brogle/src/engine.rs +++ b/brogle/src/engine.rs @@ -4,7 +4,7 @@ use std::{ sync::{ atomic::{AtomicBool, Ordering}, mpsc::{self, Sender}, - Arc, LazyLock, + Arc, LazyLock, RwLock, }, time::{Duration, Instant}, }; @@ -14,11 +14,11 @@ use brogle_core::{print_perft, Bitboard, Color, Game, Move, Position, Tile, FEN_ use log::{error, warn}; use threadpool::ThreadPool; -use crate::protocols::UciInfo; - use super::{ - protocols::{UciCommand, UciEngine, UciOption, UciResponse, UciScore, UciSearchOptions}, - search::Searcher, + protocols::{ + UciCommand, UciEngine, UciInfo, UciOption, UciResponse, UciScore, UciSearchOptions, + }, + search::{Searcher, TTable}, Evaluator, MATE, MAX_DEPTH, MAX_MATE, }; @@ -73,9 +73,9 @@ pub struct Engine { /// etc. game: Game, - // /// Transposition table for game states. - // ttable: TranspositionTable, - // + /// Transposition table for game states. + ttable: Arc>, + /// Whether to display additional information in `info` commands. /// /// Defaults to `false`.` @@ -613,6 +613,7 @@ impl UciEngine for Engine { let sender = self.sender.clone().unwrap(); let game = self.game.clone(); + let ttable = Arc::clone(&self.ttable); let max_depth = options.depth.unwrap_or(MAX_DEPTH); // Initialize bestmove to the first move available, if there are any @@ -627,7 +628,13 @@ impl UciEngine for Engine { } // Create a search instance with the appropriate thread data - let search = Searcher::new(&game, starttime, timeout, Arc::clone(&is_searching)); + let search = Searcher::new( + &game, + starttime, + timeout, + Arc::clone(&ttable), + Arc::clone(&is_searching), + ); // If we received an error, that means the search was stopped externally match search.start(depth) { @@ -733,7 +740,7 @@ impl Default for Engine { fn default() -> Self { Self { game: Game::default(), - // ttable: TranspositionTable::default(), + ttable: Arc::default(), debug: Arc::default(), is_searching: Arc::default(), sender: None, diff --git a/brogle/src/lib.rs b/brogle/src/lib.rs index 6f229d5..4f89849 100644 --- a/brogle/src/lib.rs +++ b/brogle/src/lib.rs @@ -5,7 +5,9 @@ pub use engine::*; pub mod search { pub mod searcher; + pub mod transposition_table; pub use searcher::*; + pub use transposition_table::*; } pub mod eval { diff --git a/brogle/src/search/searcher.rs b/brogle/src/search/searcher.rs index 240cf69..16a51f7 100644 --- a/brogle/src/search/searcher.rs +++ b/brogle/src/search/searcher.rs @@ -1,13 +1,15 @@ use std::sync::atomic::{AtomicBool, Ordering}; -use std::sync::Arc; +use std::sync::{Arc, RwLock}; use std::time::Duration; use std::time::Instant; use anyhow::{bail, Result}; -use brogle_core::{Game, Move, PieceKind}; +use brogle_core::{Game, Move, PieceKind, ZobristKey}; use crate::{value_of, Evaluator, Score, INF, MATE}; +use super::{NodeType, TTable, TTableEntry}; + pub struct SearchData { pub nodes_searched: usize, pub score: Score, @@ -28,6 +30,7 @@ impl Default for SearchData { pub struct Searcher<'a> { game: &'a Game, timeout: Duration, + ttable: Arc>, stopper: Arc, starttime: Instant, @@ -46,6 +49,7 @@ impl<'a> Searcher<'a> { game: &'a Game, starttime: Instant, timeout: Duration, + ttable: Arc>, stopper: Arc, ) -> Self { Self { @@ -53,6 +57,7 @@ impl<'a> Searcher<'a> { starttime, timeout, stopper, + ttable, data: SearchData::default(), } } @@ -71,15 +76,18 @@ impl<'a> Searcher<'a> { self.data.bestmove = moves.first().cloned(); - moves.sort_by_cached_key(|mv| score_move(self.game, mv)); + let tt_bestmove = self.get_tt_bestmove(self.game.key()); + moves.sort_by_cached_key(|mv| score_move(self.game, mv, tt_bestmove)); // Start with a default (very bad) result. let mut alpha = -INF; + let original_alpha = alpha; let beta = INF; let ply = 0; for i in 0..moves.len() { let mv = moves[i]; + // Make the score move on the position, getting a new position in return let new_pos = self.game.clone().with_move_made(mv); @@ -122,6 +130,9 @@ impl<'a> Searcher<'a> { } } + let bestmove = self.data.bestmove.unwrap(); // safe unwrap because if `moves` was empty, we would have returned earlier. So `bestmove` is guaranteed to be *something* + let flag = NodeType::new(self.data.score, original_alpha, beta); + self.save_to_ttable(self.game.key(), bestmove, self.data.score, 0, flag); Ok(self.data) } @@ -153,11 +164,13 @@ impl<'a> Searcher<'a> { }); } - moves.sort_by_cached_key(|mv| score_move(game, mv)); + let tt_bestmove = self.get_tt_bestmove(game.key()); + moves.sort_by_cached_key(|mv| score_move(game, mv, tt_bestmove)); // Start with a default (very bad) result. let mut best = -INF; let mut bestmove = moves[0]; // Safe because we already checked that moves isn't empty + let original_alpha = alpha; for i in 0..moves.len() { let mv = moves[i]; @@ -206,6 +219,8 @@ impl<'a> Searcher<'a> { } } + let flag = NodeType::new(best, original_alpha, beta); + self.save_to_ttable(game.key(), bestmove, best, depth, flag); Ok(best) } @@ -230,11 +245,13 @@ impl<'a> Searcher<'a> { return Ok(stand_pat); } - captures.sort_by_cached_key(|mv| score_move(game, mv)); + let tt_bestmove = self.get_tt_bestmove(game.key()); + captures.sort_by_cached_key(|mv| score_move(game, mv, tt_bestmove)); // let original_alpha = alpha; let mut best = stand_pat; let mut bestmove = captures[0]; // Safe because we already checked that moves isn't empty + let original_alpha = alpha; // Only search captures for i in 0..captures.len() { @@ -278,8 +295,28 @@ impl<'a> Searcher<'a> { } } + let flag = NodeType::new(best, original_alpha, beta); + self.save_to_ttable(game.key(), bestmove, best, 0, flag); Ok(alpha) } + + fn save_to_ttable( + &mut self, + key: ZobristKey, + bestmove: Move, + score: Score, + depth: u32, + flag: NodeType, + ) { + let entry = TTableEntry::new(key, bestmove, score, depth, flag); + let mut tt = self.ttable.write().unwrap(); + tt.store(entry); + } + + fn get_tt_bestmove(&self, key: ZobristKey) -> Option { + let tt = self.ttable.read().unwrap(); + tt.get(&key).map(|entry| entry.bestmove) + } } // TODO: verify the values are good: https://discord.com/channels/719576389245993010/719576389690589244/1268914745298391071 @@ -287,12 +324,10 @@ fn mvv_lva(kind: PieceKind, captured: PieceKind) -> i32 { 10 * value_of(captured) - value_of(kind) } -fn score_move(game: &Game, mv: &Move) -> i32 { - // if let Some(ponder) = self.result.ponder.take() { - // if *mv == ponder { - // return i32::MIN; - // } - // } +fn score_move(game: &Game, mv: &Move, tt_bestmove: Option) -> i32 { + if tt_bestmove.is_some_and(|tt_mv| tt_mv == *mv) { + return Score::MIN; + } let mut score = 0; let kind = game.kind_at(mv.from()).unwrap(); diff --git a/brogle/src/search/transposition_table.rs b/brogle/src/search/transposition_table.rs new file mode 100644 index 0000000..d600256 --- /dev/null +++ b/brogle/src/search/transposition_table.rs @@ -0,0 +1,142 @@ +use brogle_core::{Move, ZobristKey}; + +#[derive(PartialEq, Eq, Clone, Copy, Debug, Hash, Default)] +pub enum NodeType { + /// The score is exact + #[default] + Pv, + /// The score is less than alpha (upper bound) + All, + /// The score is greater than or equal to beta (lower bound) + Cut, +} + +impl NodeType { + pub fn new(score: i32, alpha: i32, beta: i32) -> Self { + if score <= alpha { + Self::All + } else if score >= beta { + Self::Cut + } else { + Self::Pv + } + } +} + +#[derive(PartialEq, Eq, Clone, Debug, Hash, Default)] +pub struct TTableEntry { + pub(crate) key: ZobristKey, + pub(crate) depth: u32, + pub(crate) bestmove: Move, + pub(crate) score: i32, + pub(crate) flag: NodeType, + pub(crate) age: usize, +} + +impl TTableEntry { + pub fn new(key: ZobristKey, bestmove: Move, score: i32, depth: u32, flag: NodeType) -> Self { + Self { + key, + bestmove, + score, + depth, + flag, + age: 0, + } + } +} + +/// Default size of the Transposition Table, in bytes +const DEFAULT_TTABLE_SIZE: usize = 1_048_576; // 1 mb + +#[derive(Debug)] +pub struct TTable(pub Vec>); + +impl TTable { + /// Create a new [`TTable`] that is `size` bytes. + /// + /// Its size will be `size_of::() * capacity` + pub fn new(size: usize) -> Self { + Self::from_capacity(size / size_of::()) + } + + /// Create a new [`TTable`] that can hold `capacity` entries. + pub fn from_capacity(capacity: usize) -> Self { + Self(vec![None; capacity]) + } + + pub fn index(&self, key: &ZobristKey) -> usize { + // TODO: Enforce size as a power of two so you can use & instead of % + key.inner() as usize % self.0.len() + } + + /// Get the entry, without regards for whether it matches the provided key + fn get_entry(&self, key: &ZobristKey) -> Option<&TTableEntry> { + // We can safely index as we've initialized this ttable to be non-empty + self.0[self.index(key)].as_ref() + } + + /// Mutably get the entry, without regards for whether it matches the provided key + fn get_entry_mut(&mut self, key: &ZobristKey) -> Option<&mut TTableEntry> { + let index = self.index(key); + self.0[index].as_mut() + } + + /// Get the entry if and only if it matches the provided key + pub fn get(&self, key: &ZobristKey) -> Option<&TTableEntry> { + if let Some(entry) = self.get_entry(key) { + if &entry.key == key { + return Some(entry); + } + } + None + } + + pub fn get_mut(&mut self, key: &ZobristKey) -> Option<&mut TTableEntry> { + if let Some(entry) = self.get_entry_mut(key) { + if &entry.key == key { + return Some(entry); + } + } + None + } + + pub fn update_flag(&mut self, key: &ZobristKey, flag: NodeType) { + if let Some(entry) = self.get_mut(key) { + entry.flag = flag; + } + } + + pub fn update_score(&mut self, key: &ZobristKey, new_score: i32) { + if let Some(entry) = self.get_mut(key) { + entry.score = new_score; + } + } + + /// Store `entry` in the table at `entry.key`, overriding whatever was there. + pub fn store(&mut self, entry: TTableEntry) { + self.insert(entry); + } + + /// Store `entry` in the table at `entry.key`, if the existing entry at `entry.key` is either `None` or has a lower `depth` than `entry`. + pub fn store_if_greater_depth(&mut self, entry: TTableEntry) { + if self + .get(&entry.key) + .is_some_and(|old_entry| old_entry.depth < entry.depth) + { + self.store(entry); + } + } + + /// Inserts `entry` in the table at `entry.key`, overriding whatever was there and returning a mutable reference to it. + pub fn insert(&mut self, entry: TTableEntry) -> &mut TTableEntry { + let index = self.index(&entry.key); + self.0[index].insert(entry) + } +} + +impl Default for TTable { + fn default() -> Self { + Self::new(DEFAULT_TTABLE_SIZE) + } +} diff --git a/brogle_core/src/position.rs b/brogle_core/src/position.rs index 836f93f..babfe34 100644 --- a/brogle_core/src/position.rs +++ b/brogle_core/src/position.rs @@ -409,8 +409,8 @@ impl Position { } /// Fetch the Zobrist hash key of this position. - pub fn key(&self) -> &ZobristKey { - &self.key + pub fn key(&self) -> ZobristKey { + self.key } /// Returns `true` if the half-move counter is 50 or greater. diff --git a/brogle_core/src/zobrist.rs b/brogle_core/src/zobrist.rs index 9e60444..30f8c9c 100644 --- a/brogle_core/src/zobrist.rs +++ b/brogle_core/src/zobrist.rs @@ -11,7 +11,7 @@ use super::{ const ZOBRIST_TABLE: ZobristHashTable = ZobristHashTable::new(); /// Represents a key generated from a Zobrist Hash -#[derive(Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone)] +#[derive(Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Clone, Copy)] pub struct ZobristKey(u64); impl ZobristKey { @@ -51,6 +51,11 @@ impl ZobristKey { key } + /// Return the inner `u64` of this key. + pub fn inner(&self) -> u64 { + self.0 + } + /// Adds/removes `hash_key` to this [`ZobristKey`]. /// /// This is done internally with the XOR operator.