Skip to content

Commit

Permalink
chore: move visit and score info to the edge
Browse files Browse the repository at this point in the history
  • Loading branch information
raklaptudirm committed Jun 6, 2024
1 parent 5c80c80 commit 913f392
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 104 deletions.
67 changes: 30 additions & 37 deletions src/mcts/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,37 +46,38 @@ impl Searcher {

let start = time::Instant::now();

let mut playouts = 0;

let mut depth = 0;
let mut seldepth = 0;
let mut cumulative_depth = 0;

loop {
let mut new_depth = 0;
let node = self.playout(&mut new_depth);
self.playout(&mut new_depth);
playouts += 1;

cumulative_depth += new_depth;
if new_depth > seldepth {
seldepth = new_depth;
}

let avg_depth = cumulative_depth / self.playouts();
let avg_depth = cumulative_depth / playouts;
if avg_depth > depth {
depth = avg_depth;

let node = self.node(node);

// Make a new info report.
println!(
"info depth {} seldepth {} score cp {:.0} nodes {} nps {}",
depth,
seldepth,
node.q() * 100.0,
self.playouts(),
0.0,
playouts,
self.nodes() * 1000 / start.elapsed().as_millis().max(1) as usize
);
}

if self.playouts() & 127 == 0 {
if playouts & 127 == 0 {
if start.elapsed().as_millis() >= movetime
|| depth >= maxdepth
|| self.nodes() >= maxnodes
Expand All @@ -96,15 +97,15 @@ impl Searcher {

println!(
"info depth {} seldepth {} score cp {:.0} nodes {} nps {}",
cumulative_depth / self.playouts(),
cumulative_depth / playouts,
seldepth,
100.0,
self.playouts(),
playouts,
self.nodes() * 1000 / start.elapsed().as_millis().max(1) as usize
);

// Verify the self.
debug_assert!(self.verify().is_ok());
// debug_assert!(self.verify().is_ok());

self.best_move()
}
Expand Down Expand Up @@ -134,7 +135,7 @@ impl Searcher {
break;
}

if (node_ptr == 0 && node.playouts == 0) || (node_ptr != 0 && node.playouts == 1) {
if !node.expanded() {
// If the selected Node's Edges haven't been expanded, expand.
node.expand(position, policy);
}
Expand All @@ -160,39 +161,28 @@ impl Searcher {
}

// v-----------------------v exploitation
// node-q + policy * cpuct * sqrt(node-visits) / (1 + child-visits = 1)
// child-q + policy * cpuct * sqrt(node-visits) / (1 + child-visits)
// node-q + policy * cpuct * sqrt(node-visits) / (1 + child-visits = 0) // not expanded
// child-q + policy * cpuct * sqrt(node-visits) / (1 + child-visits) // expanded
// ^-----^ score / visits
fn select_edge(&self, ptr: NodePtr) -> EdgePtr {
let node = self.node(ptr);
let parent = self.edge(node.parent_node, node.parent_edge);

// Node exploitation factor (cpuct * sqrt(parent-playouts))
let e = self.params.cpuct() * f64::sqrt(node.playouts.max(1) as f64);
let e = self.params.cpuct() * f64::sqrt(parent.visits.max(1) as f64);

let mut best_ptr: EdgePtr = -1;
let mut best_uct = 0.0;

// Q value (score / playouts) for parent node.
let node_q = node.q();

for (ptr, edge) in node.edges.iter().enumerate() {
let ptr = ptr as EdgePtr;

// Fetch the Q value, Policy value, and Playout count.
let (q, p, c) = if edge.ptr == -1 {
// Edge hasn't been expanded, so no node information available.
// Use the parent (current) node's information instead for uct.
(node_q, edge.policy, 1.0) // No child playouts, so playouts + 1 = 1
} else {
let child = self.node(edge.ptr);
(child.q(), edge.policy, (child.playouts + 1) as f64)
};
// If the edge hasn't been expanded yet, use the parent's q value.
let q = if edge.ptr == -1 { parent.q() } else { edge.q() };

let child_uct = q + p * e / c;
let child_uct = q + edge.policy * e / (edge.visits as f64 + 1.0);

// Check if we have a better UCT score for this edge.
if child_uct > best_uct {
best_ptr = ptr;
best_ptr = ptr as EdgePtr;
best_uct = child_uct;
}
}
Expand All @@ -214,12 +204,11 @@ impl Searcher {
*position = position.after_move::<true>(edge.mov);

// Expand the Edge into a new Node.
let new_node = Node::new(parent);
let new_node = Node::new(parent, edge_ptr);

// Add the new Node to the Tree.
let new_ptr = self.push_node(new_node);

let edge = self.node_mut(parent).edge_mut(edge_ptr);
let edge = self.edge_mut(parent, edge_ptr);

// Make the Edge point to the new Node.
edge.ptr = new_ptr;
Expand Down Expand Up @@ -247,17 +236,21 @@ impl Searcher {
let mut result = result;

loop {
let node = self.node_mut(node_ptr);
let node = self.node(node_ptr);
let parent_node = node.parent_node;
let parent_edge = node.parent_edge;

let edge = self.edge_mut(parent_node, parent_edge);

node.playouts += 1;
node.total_score += result;
edge.visits += 1;
edge.scores += result;

// Stop backpropagation if root has been reached.
if node_ptr == 0 {
break;
}

node_ptr = node.parent_node;
node_ptr = parent_node;
result = 1.0 - result;
}
}
Expand Down
33 changes: 20 additions & 13 deletions src/mcts/node.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,25 @@
use ataxx::MoveStore;

use super::policy;
use core::slice;

pub type NodePtr = isize;
pub type Result = f64;
pub type Score = f64;

#[derive(Clone)]
pub struct Node {
pub edges: Edges,

pub playouts: usize,
pub total_score: Result,
pub parent_node: NodePtr,
pub parent_edge: EdgePtr,
}

impl Node {
pub fn new(parent_node: NodePtr) -> Node {
pub fn new(parent_node: NodePtr, parent_edge: EdgePtr) -> Node {
Node {
// position,
edges: Edges::new(),

playouts: 0,
total_score: 0.0,
parent_node,
parent_edge,
}
}

Expand All @@ -47,14 +45,11 @@ impl Node {
pub fn edge_mut(&mut self, ptr: EdgePtr) -> &mut Edge {
&mut self.edges.edges[ptr as usize]
}
}

impl Node {
pub fn q(&self) -> f64 {
self.total_score / self.playouts.max(1) as f64
pub fn expanded(&mut self) -> bool {
self.edges.len() > 0
}
}

#[derive(Clone)]
pub struct Edges {
edges: Vec<Edge>,
Expand Down Expand Up @@ -93,6 +88,10 @@ pub type EdgePtr = isize;
pub struct Edge {
pub mov: ataxx::Move,
pub ptr: NodePtr,

pub visits: usize,
pub scores: Score,

pub policy: f64,
}

Expand All @@ -101,7 +100,15 @@ impl Edge {
Edge {
mov: m,
ptr: -1,

visits: 0,
scores: 0.0,

policy: 0.0,
}
}

pub fn q(&self) -> f64 {
self.scores / self.visits.max(1) as f64
}
}
Loading

0 comments on commit 913f392

Please sign in to comment.