Skip to content
This repository has been archived by the owner on Oct 8, 2024. It is now read-only.

Commit

Permalink
feat: implemented TT for move ordering
Browse files Browse the repository at this point in the history
  • Loading branch information
dannyhammer committed Aug 19, 2024
1 parent 05750e0 commit 8dcfbd1
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 24 deletions.
27 changes: 17 additions & 10 deletions brogle/src/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::{
sync::{
atomic::{AtomicBool, Ordering},
mpsc::{self, Sender},
Arc, LazyLock,
Arc, LazyLock, RwLock,
},
time::{Duration, Instant},
};
Expand All @@ -14,11 +14,11 @@ use brogle_core::{print_perft, Bitboard, Color, Game, Move, Position, Tile, FEN_
use log::{error, warn};
use threadpool::ThreadPool;

use crate::protocols::UciInfo;

use super::{
protocols::{UciCommand, UciEngine, UciOption, UciResponse, UciScore, UciSearchOptions},
search::Searcher,
protocols::{
UciCommand, UciEngine, UciInfo, UciOption, UciResponse, UciScore, UciSearchOptions,
},
search::{Searcher, TTable},
Evaluator, MATE, MAX_DEPTH, MAX_MATE,
};

Expand Down Expand Up @@ -73,9 +73,9 @@ pub struct Engine {
/// etc.
game: Game,

// /// Transposition table for game states.
// ttable: TranspositionTable,
//
/// Transposition table for game states.
ttable: Arc<RwLock<TTable>>,

/// Whether to display additional information in `info` commands.
///
/// Defaults to `false`.`
Expand Down Expand Up @@ -613,6 +613,7 @@ impl UciEngine for Engine {
let sender = self.sender.clone().unwrap();

let game = self.game.clone();
let ttable = Arc::clone(&self.ttable);
let max_depth = options.depth.unwrap_or(MAX_DEPTH);

// Initialize bestmove to the first move available, if there are any
Expand All @@ -627,7 +628,13 @@ impl UciEngine for Engine {
}

// Create a search instance with the appropriate thread data
let search = Searcher::new(&game, starttime, timeout, Arc::clone(&is_searching));
let search = Searcher::new(
&game,
starttime,
timeout,
Arc::clone(&ttable),
Arc::clone(&is_searching),
);

// If we received an error, that means the search was stopped externally
match search.start(depth) {
Expand Down Expand Up @@ -733,7 +740,7 @@ impl Default for Engine {
fn default() -> Self {
Self {
game: Game::default(),
// ttable: TranspositionTable::default(),
ttable: Arc::default(),
debug: Arc::default(),
is_searching: Arc::default(),
sender: None,
Expand Down
2 changes: 2 additions & 0 deletions brogle/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ pub use engine::*;

pub mod search {
pub mod searcher;
pub mod transposition_table;
pub use searcher::*;
pub use transposition_table::*;
}

pub mod eval {
Expand Down
57 changes: 46 additions & 11 deletions brogle/src/search/searcher.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use std::time::Instant;

use anyhow::{bail, Result};
use brogle_core::{Game, Move, PieceKind};
use brogle_core::{Game, Move, PieceKind, ZobristKey};

use crate::{value_of, Evaluator, Score, INF, MATE};

use super::{NodeType, TTable, TTableEntry};

pub struct SearchData {
pub nodes_searched: usize,
pub score: Score,
Expand All @@ -28,6 +30,7 @@ impl Default for SearchData {
pub struct Searcher<'a> {
game: &'a Game,
timeout: Duration,
ttable: Arc<RwLock<TTable>>,
stopper: Arc<AtomicBool>,
starttime: Instant,

Expand All @@ -46,13 +49,15 @@ impl<'a> Searcher<'a> {
game: &'a Game,
starttime: Instant,
timeout: Duration,
ttable: Arc<RwLock<TTable>>,
stopper: Arc<AtomicBool>,
) -> Self {
Self {
game,
starttime,
timeout,
stopper,
ttable,
data: SearchData::default(),
}
}
Expand All @@ -71,15 +76,18 @@ impl<'a> Searcher<'a> {

self.data.bestmove = moves.first().cloned();

moves.sort_by_cached_key(|mv| score_move(self.game, mv));
let tt_bestmove = self.get_tt_bestmove(self.game.key());
moves.sort_by_cached_key(|mv| score_move(self.game, mv, tt_bestmove));

// Start with a default (very bad) result.
let mut alpha = -INF;
let original_alpha = alpha;
let beta = INF;
let ply = 0;

for i in 0..moves.len() {
let mv = moves[i];

// Make the score move on the position, getting a new position in return
let new_pos = self.game.clone().with_move_made(mv);

Expand Down Expand Up @@ -122,6 +130,9 @@ impl<'a> Searcher<'a> {
}
}

let bestmove = self.data.bestmove.unwrap(); // safe unwrap because if `moves` was empty, we would have returned earlier. So `bestmove` is guaranteed to be *something*
let flag = NodeType::new(self.data.score, original_alpha, beta);
self.save_to_ttable(self.game.key(), bestmove, self.data.score, 0, flag);
Ok(self.data)
}

Expand Down Expand Up @@ -153,11 +164,13 @@ impl<'a> Searcher<'a> {
});
}

moves.sort_by_cached_key(|mv| score_move(game, mv));
let tt_bestmove = self.get_tt_bestmove(game.key());
moves.sort_by_cached_key(|mv| score_move(game, mv, tt_bestmove));

// Start with a default (very bad) result.
let mut best = -INF;
let mut bestmove = moves[0]; // Safe because we already checked that moves isn't empty
let original_alpha = alpha;

for i in 0..moves.len() {
let mv = moves[i];
Expand Down Expand Up @@ -206,6 +219,8 @@ impl<'a> Searcher<'a> {
}
}

let flag = NodeType::new(best, original_alpha, beta);
self.save_to_ttable(game.key(), bestmove, best, depth, flag);
Ok(best)
}

Expand All @@ -230,11 +245,13 @@ impl<'a> Searcher<'a> {
return Ok(stand_pat);
}

captures.sort_by_cached_key(|mv| score_move(game, mv));
let tt_bestmove = self.get_tt_bestmove(game.key());
captures.sort_by_cached_key(|mv| score_move(game, mv, tt_bestmove));

// let original_alpha = alpha;
let mut best = stand_pat;
let mut bestmove = captures[0]; // Safe because we already checked that moves isn't empty
let original_alpha = alpha;

// Only search captures
for i in 0..captures.len() {
Expand Down Expand Up @@ -278,21 +295,39 @@ impl<'a> Searcher<'a> {
}
}

let flag = NodeType::new(best, original_alpha, beta);
self.save_to_ttable(game.key(), bestmove, best, 0, flag);
Ok(alpha)
}

fn save_to_ttable(
&mut self,
key: ZobristKey,
bestmove: Move,
score: Score,
depth: u32,
flag: NodeType,
) {
let entry = TTableEntry::new(key, bestmove, score, depth, flag);
let mut tt = self.ttable.write().unwrap();
tt.store(entry);
}

fn get_tt_bestmove(&self, key: ZobristKey) -> Option<Move> {
let tt = self.ttable.read().unwrap();
tt.get(&key).map(|entry| entry.bestmove)
}
}

// TODO: verify the values are good: https://discord.com/channels/719576389245993010/719576389690589244/1268914745298391071
fn mvv_lva(kind: PieceKind, captured: PieceKind) -> i32 {
10 * value_of(captured) - value_of(kind)
}

fn score_move(game: &Game, mv: &Move) -> i32 {
// if let Some(ponder) = self.result.ponder.take() {
// if *mv == ponder {
// return i32::MIN;
// }
// }
fn score_move(game: &Game, mv: &Move, tt_bestmove: Option<Move>) -> i32 {
if tt_bestmove.is_some_and(|tt_mv| tt_mv == *mv) {
return Score::MIN;
}

let mut score = 0;
let kind = game.kind_at(mv.from()).unwrap();
Expand Down
142 changes: 142 additions & 0 deletions brogle/src/search/transposition_table.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
use brogle_core::{Move, ZobristKey};

#[derive(PartialEq, Eq, Clone, Copy, Debug, Hash, Default)]
pub enum NodeType {
/// The score is exact
#[default]
Pv,
/// The score is less than alpha (upper bound)
All,
/// The score is greater than or equal to beta (lower bound)
Cut,
}

impl NodeType {
pub fn new(score: i32, alpha: i32, beta: i32) -> Self {
if score <= alpha {
Self::All
} else if score >= beta {
Self::Cut
} else {
Self::Pv
}
}
}

#[derive(PartialEq, Eq, Clone, Debug, Hash, Default)]
pub struct TTableEntry {
pub(crate) key: ZobristKey,
pub(crate) depth: u32,
pub(crate) bestmove: Move,
pub(crate) score: i32,
pub(crate) flag: NodeType,
pub(crate) age: usize,
}

impl TTableEntry {
pub fn new(key: ZobristKey, bestmove: Move, score: i32, depth: u32, flag: NodeType) -> Self {
Self {
key,
bestmove,
score,
depth,
flag,
age: 0,
}
}
}

/// Default size of the Transposition Table, in bytes
const DEFAULT_TTABLE_SIZE: usize = 1_048_576; // 1 mb

#[derive(Debug)]
pub struct TTable(pub Vec<Option<TTableEntry>>);

impl TTable {
/// Create a new [`TTable`] that is `size` bytes.
///
/// Its size will be `size_of::<TTableEntry>() * capacity`
pub fn new(size: usize) -> Self {
Self::from_capacity(size / size_of::<TTableEntry>())
}

/// Create a new [`TTable`] that can hold `capacity` entries.
pub fn from_capacity(capacity: usize) -> Self {
Self(vec![None; capacity])
}

pub fn index(&self, key: &ZobristKey) -> usize {
// TODO: Enforce size as a power of two so you can use & instead of %
key.inner() as usize % self.0.len()
}

/// Get the entry, without regards for whether it matches the provided key
fn get_entry(&self, key: &ZobristKey) -> Option<&TTableEntry> {
// We can safely index as we've initialized this ttable to be non-empty
self.0[self.index(key)].as_ref()
}

/// Mutably get the entry, without regards for whether it matches the provided key
fn get_entry_mut(&mut self, key: &ZobristKey) -> Option<&mut TTableEntry> {
let index = self.index(key);
self.0[index].as_mut()
}

/// Get the entry if and only if it matches the provided key
pub fn get(&self, key: &ZobristKey) -> Option<&TTableEntry> {
if let Some(entry) = self.get_entry(key) {
if &entry.key == key {
return Some(entry);
}
}
None
}

pub fn get_mut(&mut self, key: &ZobristKey) -> Option<&mut TTableEntry> {
if let Some(entry) = self.get_entry_mut(key) {
if &entry.key == key {
return Some(entry);
}
}
None
}

pub fn update_flag(&mut self, key: &ZobristKey, flag: NodeType) {
if let Some(entry) = self.get_mut(key) {
entry.flag = flag;
}
}

pub fn update_score(&mut self, key: &ZobristKey, new_score: i32) {
if let Some(entry) = self.get_mut(key) {
entry.score = new_score;
}
}

/// Store `entry` in the table at `entry.key`, overriding whatever was there.
pub fn store(&mut self, entry: TTableEntry) {
self.insert(entry);
}

/// Store `entry` in the table at `entry.key`, if the existing entry at `entry.key` is either `None` or has a lower `depth` than `entry`.
pub fn store_if_greater_depth(&mut self, entry: TTableEntry) {
if self
.get(&entry.key)
.is_some_and(|old_entry| old_entry.depth < entry.depth)
{
self.store(entry);
}
}

/// Inserts `entry` in the table at `entry.key`, overriding whatever was there and returning a mutable reference to it.
pub fn insert(&mut self, entry: TTableEntry) -> &mut TTableEntry {
let index = self.index(&entry.key);
self.0[index].insert(entry)
}
}

impl Default for TTable {
fn default() -> Self {
Self::new(DEFAULT_TTABLE_SIZE)
}
}
4 changes: 2 additions & 2 deletions brogle_core/src/position.rs
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,8 @@ impl Position {
}

/// Fetch the Zobrist hash key of this position.
pub fn key(&self) -> &ZobristKey {
&self.key
pub fn key(&self) -> ZobristKey {
self.key
}

/// Returns `true` if the half-move counter is 50 or greater.
Expand Down
Loading

0 comments on commit 8dcfbd1

Please sign in to comment.