diff --git a/Rapfi/search/mcts/node.h b/Rapfi/search/mcts/node.h index b3cac90..d9fc997 100644 --- a/Rapfi/search/mcts/node.h +++ b/Rapfi/search/mcts/node.h @@ -193,14 +193,17 @@ class Node void updateStats(); /// Begin the visit of this node. - void beginVisit() { nVirtual.fetch_add(1, std::memory_order_acq_rel); } + void beginVisit(uint32_t newVisits) + { + nVirtual.fetch_add(newVisits, std::memory_order_acq_rel); + } /// Finish the visit of this node by incrementing the total visits of this node. - void finishVisit(uint32_t delta) + void finishVisit(uint32_t newVisits, uint32_t actualNewVisits) { - if (delta) - n.fetch_add(delta, std::memory_order_acq_rel); - nVirtual.fetch_add(-1, std::memory_order_release); + if (actualNewVisits) + n.fetch_add(actualNewVisits, std::memory_order_acq_rel); + nVirtual.fetch_add(-newVisits, std::memory_order_release); } /// Directly increment the total visits of this node by delta. diff --git a/Rapfi/search/mcts/parameter.h b/Rapfi/search/mcts/parameter.h index 246866f..4c60249 100644 --- a/Rapfi/search/mcts/parameter.h +++ b/Rapfi/search/mcts/parameter.h @@ -18,8 +18,13 @@ #pragma once +#include + namespace Search::MCTS { +constexpr uint32_t MaxNumVisitsPerPlayout = 100; +constexpr float MaxNewVisitsProp = 0.2f; + constexpr float CpuctExploration = 1.0f; constexpr float CpuctExplorationLog = 0.4f; constexpr float CpuctExplorationBase = 500; @@ -34,7 +39,7 @@ constexpr float FpuUtilityBlendPow = 2.0f; constexpr uint32_t MinTranspositionSkipVisits = 10; constexpr bool UseLCBForBestmoveSelection = false; -constexpr float LCBStdevs = 5; -constexpr float MinVisitPropForLCB = 0.2f; +constexpr float LCBStdevs = 5; +constexpr float MinVisitPropForLCB = 0.2f; } // namespace Search::MCTS diff --git a/Rapfi/search/mcts/search.cpp b/Rapfi/search/mcts/search.cpp index a87b6cd..691b931 100644 --- a/Rapfi/search/mcts/search.cpp +++ b/Rapfi/search/mcts/search.cpp @@ -75,15 +75,6 @@ inline float puctSelectionValue(float childUtility, return Q + U; } -inline uint32_t requiredVisitToBalance(float q1, - float q2, - float p1, - float p2, - uint32_t n1, - uint32_t n2, - float cpuct) -{} - /// allocateOrFindNode: allocate a new node if it does not exist in the node table /// @param nodeTable The node table to allocate or find the node /// @param hash The hash key of the node @@ -110,7 +101,7 @@ allocateOrFindNode(NodeTable &nodeTable, HashKey hash, uint32_t globalNodeAge) /// @return A pair of (the non-null best child edge pointer, the child node pointer) /// The child node pointer is nullptr if the edge is unexplored (has zero visit). template -std::pair selectChild(Node &node, NodeTable &nodeTable, const Board &board) +std::pair selectChild(Node &node, const Board &board) { assert(!node.isLeaf()); SearchThread *thisThread = board.thisThread(); @@ -265,9 +256,10 @@ bool expandNode(Node &node, const SearchOptions &options, const Board &board, in /// @param node The node to search, must been already allocated. /// @param board The board state of this node. The board's hash must be equal to the node's. /// @param ply The current search ply. Root node is zero. -/// @return The number of new visits added to this node. +/// @param visits The number of new visits for this playout. +/// @return The number of actual new visits added to this node. template -uint32_t searchNode(Node &node, Board &board, int ply) +uint32_t searchNode(Node &node, Board &board, int ply, uint32_t newVisits) { assert(node.getHash() == board.zobristKey()); @@ -275,22 +267,23 @@ uint32_t searchNode(Node &node, Board &board, int ply) SearchOptions &options = thisThread->options(); MCTSSearcher &searcher = static_cast(*thisThread->threads.searcher()); - // Return immediately if this node is unevaluated - if (node.getVisits() == 0) + // Discard visits in this node if it is unevaluated + uint32_t parentVisits = node.getVisits(); + if (parentVisits == 0) return 0; if (Root) thisThread->selDepth = 0; + // Cap new visits so that we dont do too much at one time + newVisits = std::min(newVisits, uint32_t(parentVisits * MaxNewVisitsProp) + 1); + // Return directly if this node is a terminal node and not at root if (!Root && node.isTerminal()) { - node.incrementVisits(1); - return 1; + node.incrementVisits(newVisits); + return newVisits; } - // Mark that we are now starting to visit this node - node.beginVisit(); - // Make sure the parent node is expanded before we select a child if (node.isLeaf()) { bool noValidMove = expandNode(node, options, board, ply); @@ -298,73 +291,95 @@ uint32_t searchNode(Node &node, Board &board, int ply) // If we found that there is no valid move, we mark this node as terminal // node the finish this visit. if (noValidMove) { - node.finishVisit(1); - return 1; + node.incrementVisits(newVisits); + return newVisits; } } - // Select the best edge to explore - auto [childEdge, childNode] = selectChild(node, *searcher.nodeTable, board); - - // Make the move to reach the child node - Pos move = childEdge->getMove(); - board.move(options.rule, move); - HashKey hash = board.zobristKey(); - - // Reaching a leaf node, expand it - bool allocatedNode = false; - if (!childNode) { - std::tie(childNode, allocatedNode) = - allocateOrFindNode(*searcher.nodeTable, hash, searcher.globalNodeAge); - - // Remember this child node in the edge - childEdge->setChild(childNode); - } - - uint32_t numNewVisits = 0; - // Evaluate the new child node if we are the one who really allocated the node - if (allocatedNode) { - evaluateNode(*childNode, options, board, ply + 1); - numNewVisits += 1; - } - // Continue to visit the selected child node - else { - // When transposition happens, we stop the playout if the child node has been - // visited more times than the parent node. Only continue the playout if the - // child node has been visited less times than the edge visits, or the absolute - // child node visits is less than the given threshold. - uint32_t childEdgeVisits = childEdge->getVisits(); - uint32_t childNodeVisits = childNode->getVisits(); - if (childEdgeVisits >= childNodeVisits || childNodeVisits < MinTranspositionSkipVisits) - numNewVisits += searchNode(*childNode, board, ply + 1); - else - numNewVisits += 1; // Increment edge visits without search the node - } - - if (numNewVisits > 0) { - // Increment child edge visit count - childEdge->addVisits(numNewVisits); + bool stopThisPlayout = false; + uint32_t actualNewVisits = 0; + while (!stopThisPlayout && newVisits > 0) { + // Select the best edge to explore + auto [childEdge, childNode] = selectChild(node, board); + + // Make the move to reach the child node + Pos move = childEdge->getMove(); + board.move(options.rule, move); + + // Reaching a leaf node, expand it + bool allocatedNode = false; + if (!childNode) { + HashKey hash = board.zobristKey(); + std::tie(childNode, allocatedNode) = + allocateOrFindNode(*searcher.nodeTable, hash, searcher.globalNodeAge); + + // Remember this child node in the edge + childEdge->setChild(childNode); + } - // Update the node's stats - node.updateStats(); - } + // Evaluate the new child node if we are the one who really allocated the node + if (allocatedNode) { + // Mark that we are now starting to visit this node + node.beginVisit(1); + evaluateNode(*childNode, options, board, ply + 1); + + // Increment child edge visit count + childEdge->addVisits(1); + node.updateStats(); + node.finishVisit(1, 1); + actualNewVisits++; + newVisits--; + } + else { + // When transposition happens, we stop the playout if the child node has been + // visited more times than the parent node. Only continue the playout if the + // child node has been visited less times than the edge visits, or the absolute + // child node visits is less than the given threshold. + uint32_t childEdgeVisits = childEdge->getVisits(); + uint32_t childNodeVisits = childNode->getVisits(); + if (childEdgeVisits >= childNodeVisits + || childNodeVisits < MinTranspositionSkipVisits) { + node.beginVisit(newVisits); + uint32_t actualChildNewVisits = searchNode(*childNode, board, ply + 1, newVisits); + assert(actualChildNewVisits <= newVisits); + + if (actualChildNewVisits > 0) { + childEdge->addVisits(actualChildNewVisits); + node.updateStats(); + actualNewVisits += actualChildNewVisits; + } + // Discard this playout if we can not make new visits to the best child, + // since some other thread is evaluating it + else + stopThisPlayout = true; - // Increment parent node visit - node.finishVisit(numNewVisits); + node.finishVisit(newVisits, actualChildNewVisits); + newVisits -= actualChildNewVisits; + } + else { + // Increment edge visits without search the node + childEdge->addVisits(1); + node.updateStats(); + node.incrementVisits(1); + actualNewVisits++; + newVisits--; + } + } - // Undo the move - board.undo(options.rule); + // Undo the move + board.undo(options.rule); - // Record root move's seldepth - if constexpr (Root) { - auto rmIt = std::find(thisThread->rootMoves.begin(), thisThread->rootMoves.end(), move); - if (rmIt != thisThread->rootMoves.end()) { - RootMove &rm = *rmIt; - rm.selDepth = std::max(rm.selDepth, thisThread->selDepth); + // Record root move's seldepth + if constexpr (Root) { + auto rmIt = std::find(thisThread->rootMoves.begin(), thisThread->rootMoves.end(), move); + if (rmIt != thisThread->rootMoves.end()) { + RootMove &rm = *rmIt; + rm.selDepth = std::max(rm.selDepth, thisThread->selDepth); + } } } - return numNewVisits; + return actualNewVisits; } /// Select best move to play for the given node. @@ -405,6 +420,7 @@ int selectBestmoveOfChildNode(const Node &node, float childPolicy = childEdge.getP(); float selectionValue = 2.0f * childPolicy; + assert(childNode->getVisits() > 0); uint32_t childVisits = childEdge.getVisits(); // Skip zero visits children if (childVisits > 0) { @@ -618,12 +634,25 @@ void MCTSSearcher::search(SearchThread &th) // Main search loop std::vector selectedPath; while (!th.threads.isTerminating()) { - uint32_t newNumNodes = searchNode(*root, board, 0); + uint32_t newNumPlayouts = MaxNumVisitsPerPlayout; + + // Cap new number of playouts to the maximum num nodes to visit + if (options.maxNodes) { + uint64_t nodesSearched = th.threads.nodesSearched(); + if (nodesSearched >= options.maxNodes) + break; + + uint64_t maxNodesToVisit = options.maxNodes - nodesSearched; + if (maxNodesToVisit < newNumPlayouts) + newNumPlayouts = maxNodesToVisit; + } + + uint32_t newNumNodes = searchNode(*root, board, 0, newNumPlayouts); th.numNodes.fetch_add(newNumNodes, std::memory_order_relaxed); if (th.isMainThread()) { MainSearchThread &mainThread = static_cast(th); - mainThread.checkExit(); + mainThread.checkExit(std::max(newNumNodes, 1u)); bool printRootMoves = false; if (Config::NodesToPrintMCTSRootmoves > 0) { diff --git a/Rapfi/search/searchthread.cpp b/Rapfi/search/searchthread.cpp index 5de7440..3aa3d47 100644 --- a/Rapfi/search/searchthread.cpp +++ b/Rapfi/search/searchthread.cpp @@ -187,12 +187,14 @@ void SearchThread::setBoardAndEvaluator(const Board &board) this->board = std::make_unique(board, this); } -void MainSearchThread::checkExit() +void MainSearchThread::checkExit(uint32_t elapsedCalls) { // We only check exit condition after a number of calls. // This is to avoid expensive calculation in timeup condition checking. - if (callsCnt-- > 0) + if (callsCnt > elapsedCalls) { + callsCnt -= elapsedCalls; return; + } // Resets callsCnt if (searchOptions.maxNodes) diff --git a/Rapfi/search/searchthread.h b/Rapfi/search/searchthread.h index b1f3d0d..48bb931 100644 --- a/Rapfi/search/searchthread.h +++ b/Rapfi/search/searchthread.h @@ -140,7 +140,7 @@ struct MainSearchThread : public SearchThread void search() override; /// Check exit condition (time/nodes) and set ThreadPool's terminate flag. /// @return True if we should stop the search now. - void checkExit(); + void checkExit(uint32_t elapsedCalls = 1); /// Mark pondering available for the last finished searching. void markPonderingAvailable(); /// Start all non-main search threads. This function should be called in