diff --git a/src/mcts/mod.rs b/src/mcts/mod.rs index de9249d..d0aded7 100644 --- a/src/mcts/mod.rs +++ b/src/mcts/mod.rs @@ -46,37 +46,38 @@ impl Searcher { let start = time::Instant::now(); + let mut playouts = 0; + let mut depth = 0; let mut seldepth = 0; let mut cumulative_depth = 0; loop { let mut new_depth = 0; - let node = self.playout(&mut new_depth); + self.playout(&mut new_depth); + playouts += 1; cumulative_depth += new_depth; if new_depth > seldepth { seldepth = new_depth; } - let avg_depth = cumulative_depth / self.playouts(); + let avg_depth = cumulative_depth / playouts; if avg_depth > depth { depth = avg_depth; - let node = self.node(node); - // Make a new info report. println!( "info depth {} seldepth {} score cp {:.0} nodes {} nps {}", depth, seldepth, - node.q() * 100.0, - self.playouts(), + 0.0, + playouts, self.nodes() * 1000 / start.elapsed().as_millis().max(1) as usize ); } - if self.playouts() & 127 == 0 { + if playouts & 127 == 0 { if start.elapsed().as_millis() >= movetime || depth >= maxdepth || self.nodes() >= maxnodes @@ -96,15 +97,15 @@ impl Searcher { println!( "info depth {} seldepth {} score cp {:.0} nodes {} nps {}", - cumulative_depth / self.playouts(), + cumulative_depth / playouts, seldepth, 100.0, - self.playouts(), + playouts, self.nodes() * 1000 / start.elapsed().as_millis().max(1) as usize ); // Verify the self. - debug_assert!(self.verify().is_ok()); + // debug_assert!(self.verify().is_ok()); self.best_move() } @@ -134,7 +135,7 @@ impl Searcher { break; } - if (node_ptr == 0 && node.playouts == 0) || (node_ptr != 0 && node.playouts == 1) { + if !node.expanded() { // If the selected Node's Edges haven't been expanded, expand. node.expand(position, policy); } @@ -160,39 +161,28 @@ impl Searcher { } // v-----------------------v exploitation - // node-q + policy * cpuct * sqrt(node-visits) / (1 + child-visits = 1) - // child-q + policy * cpuct * sqrt(node-visits) / (1 + child-visits) + // node-q + policy * cpuct * sqrt(node-visits) / (1 + child-visits = 0) // not expanded + // child-q + policy * cpuct * sqrt(node-visits) / (1 + child-visits) // expanded // ^-----^ score / visits fn select_edge(&self, ptr: NodePtr) -> EdgePtr { let node = self.node(ptr); + let parent = self.edge(node.parent_node, node.parent_edge); // Node exploitation factor (cpuct * sqrt(parent-playouts)) - let e = self.params.cpuct() * f64::sqrt(node.playouts.max(1) as f64); + let e = self.params.cpuct() * f64::sqrt(parent.visits.max(1) as f64); let mut best_ptr: EdgePtr = -1; let mut best_uct = 0.0; - // Q value (score / playouts) for parent node. - let node_q = node.q(); - for (ptr, edge) in node.edges.iter().enumerate() { - let ptr = ptr as EdgePtr; - - // Fetch the Q value, Policy value, and Playout count. - let (q, p, c) = if edge.ptr == -1 { - // Edge hasn't been expanded, so no node information available. - // Use the parent (current) node's information instead for uct. - (node_q, edge.policy, 1.0) // No child playouts, so playouts + 1 = 1 - } else { - let child = self.node(edge.ptr); - (child.q(), edge.policy, (child.playouts + 1) as f64) - }; + // If the edge hasn't been expanded yet, use the parent's q value. + let q = if edge.ptr == -1 { parent.q() } else { edge.q() }; - let child_uct = q + p * e / c; + let child_uct = q + edge.policy * e / (edge.visits as f64 + 1.0); // Check if we have a better UCT score for this edge. if child_uct > best_uct { - best_ptr = ptr; + best_ptr = ptr as EdgePtr; best_uct = child_uct; } } @@ -214,12 +204,11 @@ impl Searcher { *position = position.after_move::(edge.mov); // Expand the Edge into a new Node. - let new_node = Node::new(parent); + let new_node = Node::new(parent, edge_ptr); // Add the new Node to the Tree. let new_ptr = self.push_node(new_node); - - let edge = self.node_mut(parent).edge_mut(edge_ptr); + let edge = self.edge_mut(parent, edge_ptr); // Make the Edge point to the new Node. edge.ptr = new_ptr; @@ -247,17 +236,21 @@ impl Searcher { let mut result = result; loop { - let node = self.node_mut(node_ptr); + let node = self.node(node_ptr); + let parent_node = node.parent_node; + let parent_edge = node.parent_edge; + + let edge = self.edge_mut(parent_node, parent_edge); - node.playouts += 1; - node.total_score += result; + edge.visits += 1; + edge.scores += result; // Stop backpropagation if root has been reached. if node_ptr == 0 { break; } - node_ptr = node.parent_node; + node_ptr = parent_node; result = 1.0 - result; } } diff --git a/src/mcts/node.rs b/src/mcts/node.rs index 83b0b66..30c6842 100644 --- a/src/mcts/node.rs +++ b/src/mcts/node.rs @@ -1,27 +1,25 @@ +use ataxx::MoveStore; + use super::policy; use core::slice; pub type NodePtr = isize; -pub type Result = f64; +pub type Score = f64; #[derive(Clone)] pub struct Node { pub edges: Edges, - pub playouts: usize, - pub total_score: Result, pub parent_node: NodePtr, + pub parent_edge: EdgePtr, } impl Node { - pub fn new(parent_node: NodePtr) -> Node { + pub fn new(parent_node: NodePtr, parent_edge: EdgePtr) -> Node { Node { - // position, edges: Edges::new(), - - playouts: 0, - total_score: 0.0, parent_node, + parent_edge, } } @@ -47,14 +45,11 @@ impl Node { pub fn edge_mut(&mut self, ptr: EdgePtr) -> &mut Edge { &mut self.edges.edges[ptr as usize] } -} -impl Node { - pub fn q(&self) -> f64 { - self.total_score / self.playouts.max(1) as f64 + pub fn expanded(&mut self) -> bool { + self.edges.len() > 0 } } - #[derive(Clone)] pub struct Edges { edges: Vec, @@ -93,6 +88,10 @@ pub type EdgePtr = isize; pub struct Edge { pub mov: ataxx::Move, pub ptr: NodePtr, + + pub visits: usize, + pub scores: Score, + pub policy: f64, } @@ -101,7 +100,15 @@ impl Edge { Edge { mov: m, ptr: -1, + + visits: 0, + scores: 0.0, + policy: 0.0, } } + + pub fn q(&self) -> f64 { + self.scores / self.visits.max(1) as f64 + } } diff --git a/src/mcts/tree.rs b/src/mcts/tree.rs index 3905aa5..95de077 100644 --- a/src/mcts/tree.rs +++ b/src/mcts/tree.rs @@ -1,18 +1,19 @@ -use super::{Node, NodePtr}; -use ataxx::MoveStore; +use super::{Edge, EdgePtr, Node, NodePtr}; #[derive(Clone)] pub struct Tree { root_pos: ataxx::Position, + root_edge: Edge, nodes: Vec, } impl Tree { pub fn new(position: ataxx::Position) -> Tree { - let root = Node::new(-1); + let root = Node::new(-1, -1); Tree { root_pos: position, + root_edge: Edge::new(ataxx::Move::NULL), nodes: vec![root], } } @@ -21,10 +22,6 @@ impl Tree { self.nodes.len() } - pub fn playouts(&self) -> usize { - self.node(0).playouts - } - pub fn root_position(&self) -> ataxx::Position { self.root_pos } @@ -37,6 +34,22 @@ impl Tree { &mut self.nodes[ptr as usize] } + pub fn edge(&self, parent: NodePtr, edge_ptr: EdgePtr) -> &Edge { + if parent == -1 { + &self.root_edge + } else { + self.node(parent).edge(edge_ptr) + } + } + + pub fn edge_mut(&mut self, parent: NodePtr, edge_ptr: EdgePtr) -> &mut Edge { + if parent == -1 { + &mut self.root_edge + } else { + self.node_mut(parent).edge_mut(edge_ptr) + } + } + pub fn push_node(&mut self, node: Node) -> NodePtr { self.nodes.push(node); self.nodes.len() as NodePtr - 1 @@ -52,8 +65,7 @@ impl Tree { continue; } - let node = self.node(edge.ptr); - let score = 1.0 - node.q(); + let score = 1.0 - edge.q(); if best_mov == ataxx::Move::NULL || score > best_scr { best_mov = edge.mov; @@ -65,48 +77,48 @@ impl Tree { } } -impl Tree { - pub fn verify(&self) -> Result<(), String> { - self.verify_node(0, self.root_pos) - } - - fn verify_node(&self, ptr: NodePtr, position: ataxx::Position) -> Result<(), String> { - let node = self.node(ptr); - if !(node.total_score >= 0.0 && node.total_score <= node.playouts as f64) { - return Err("node score out of bounds [0, playouts]".to_string()); - } - - let mut child_playouts = 0; - let mut policy_sum = 0.0; - for edge in node.edges.iter() { - policy_sum += edge.policy; - - if edge.ptr == -1 { - continue; - } - - let child_position = position.after_move::(edge.mov); - let child = self.node(edge.ptr); - - self.verify_node(edge.ptr, child_position)?; - - child_playouts += child.playouts; - } - - if node.edges.len() > 0 && (1.0 - policy_sum).abs() > 0.00001 { - return Err(format!("total playout probability {} not 1", policy_sum)); - } - - if (ptr == 0 && node.playouts != child_playouts) - || (ptr != 0 && !position.is_game_over() && node.playouts != child_playouts + 1) - { - println!("{}", position); - Err(format!( - "node playouts {} while child playouts {}", - node.playouts, child_playouts - )) - } else { - Ok(()) - } - } -} +// impl Tree { +// pub fn verify(&self) -> Result<(), String> { +// self.verify_node(0, self.root_pos) +// } + +// fn verify_node(&self, ptr: NodePtr, position: ataxx::Position) -> Result<(), String> { +// let node = self.node(ptr); +// if !(node.total_score >= 0.0 && node.total_score <= node.playouts as f64) { +// return Err("node score out of bounds [0, playouts]".to_string()); +// } + +// let mut child_playouts = 0; +// let mut policy_sum = 0.0; +// for edge in node.edges.iter() { +// policy_sum += edge.policy; + +// if edge.ptr == -1 { +// continue; +// } + +// let child_position = position.after_move::(edge.mov); +// let child = self.node(edge.ptr); + +// self.verify_node(edge.ptr, child_position)?; + +// child_playouts += child.playouts; +// } + +// if node.edges.len() > 0 && (1.0 - policy_sum).abs() > 0.00001 { +// return Err(format!("total playout probability {} not 1", policy_sum)); +// } + +// if (ptr == 0 && node.playouts != child_playouts) +// || (ptr != 0 && !position.is_game_over() && node.playouts != child_playouts + 1) +// { +// println!("{}", position); +// Err(format!( +// "node playouts {} while child playouts {}", +// node.playouts, child_playouts +// )) +// } else { +// Ok(()) +// } +// } +// }