Skip to content

Commit

Permalink
introduce multi visit to reduce NNUE cost
Browse files Browse the repository at this point in the history
Instead of performing a single visit per playout, we conduct multiple visits per playout. This approach reduces the incremental update cost as we traverse down and up through the graph, leading to an increase in the number of visits per second. As a result, the original best-first MCTS search now incorporates some "depth-first" characteristics.

However, this method introduces potential bias into the search. To prevent the addition of too many nodes at once, which could distort the visit distribution among child nodes, we limit the number of new visits added to a node to a specified proportion of the parent node's visits.
  • Loading branch information
dhbloo committed Sep 3, 2024
1 parent 1784068 commit af7013f
Show file tree
Hide file tree
Showing 5 changed files with 127 additions and 88 deletions.
13 changes: 8 additions & 5 deletions Rapfi/search/mcts/node.h
Original file line number Diff line number Diff line change
Expand Up @@ -193,14 +193,17 @@ class Node
void updateStats();

/// Begin the visit of this node.
void beginVisit() { nVirtual.fetch_add(1, std::memory_order_acq_rel); }
void beginVisit(uint32_t newVisits)
{
nVirtual.fetch_add(newVisits, std::memory_order_acq_rel);
}

/// Finish the visit of this node by incrementing the total visits of this node.
void finishVisit(uint32_t delta)
void finishVisit(uint32_t newVisits, uint32_t actualNewVisits)
{
if (delta)
n.fetch_add(delta, std::memory_order_acq_rel);
nVirtual.fetch_add(-1, std::memory_order_release);
if (actualNewVisits)
n.fetch_add(actualNewVisits, std::memory_order_acq_rel);
nVirtual.fetch_add(-newVisits, std::memory_order_release);
}

/// Directly increment the total visits of this node by delta.
Expand Down
9 changes: 7 additions & 2 deletions Rapfi/search/mcts/parameter.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@

#pragma once

#include <cstdint>

namespace Search::MCTS {

constexpr uint32_t MaxNumVisitsPerPlayout = 100;
constexpr float MaxNewVisitsProp = 0.2f;

constexpr float CpuctExploration = 1.0f;
constexpr float CpuctExplorationLog = 0.4f;
constexpr float CpuctExplorationBase = 500;
Expand All @@ -34,7 +39,7 @@ constexpr float FpuUtilityBlendPow = 2.0f;
constexpr uint32_t MinTranspositionSkipVisits = 10;

constexpr bool UseLCBForBestmoveSelection = false;
constexpr float LCBStdevs = 5;
constexpr float MinVisitPropForLCB = 0.2f;
constexpr float LCBStdevs = 5;
constexpr float MinVisitPropForLCB = 0.2f;

} // namespace Search::MCTS
185 changes: 107 additions & 78 deletions Rapfi/search/mcts/search.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,6 @@ inline float puctSelectionValue(float childUtility,
return Q + U;
}

inline uint32_t requiredVisitToBalance(float q1,
float q2,
float p1,
float p2,
uint32_t n1,
uint32_t n2,
float cpuct)
{}

/// allocateOrFindNode: allocate a new node if it does not exist in the node table
/// @param nodeTable The node table to allocate or find the node
/// @param hash The hash key of the node
Expand All @@ -110,7 +101,7 @@ allocateOrFindNode(NodeTable &nodeTable, HashKey hash, uint32_t globalNodeAge)
/// @return A pair of (the non-null best child edge pointer, the child node pointer)
/// The child node pointer is nullptr if the edge is unexplored (has zero visit).
template <bool Root>
std::pair<Edge *, Node *> selectChild(Node &node, NodeTable &nodeTable, const Board &board)
std::pair<Edge *, Node *> selectChild(Node &node, const Board &board)
{
assert(!node.isLeaf());
SearchThread *thisThread = board.thisThread();
Expand Down Expand Up @@ -265,106 +256,130 @@ bool expandNode(Node &node, const SearchOptions &options, const Board &board, in
/// @param node The node to search, must been already allocated.
/// @param board The board state of this node. The board's hash must be equal to the node's.
/// @param ply The current search ply. Root node is zero.
/// @return The number of new visits added to this node.
/// @param visits The number of new visits for this playout.
/// @return The number of actual new visits added to this node.
template <bool Root = false>
uint32_t searchNode(Node &node, Board &board, int ply)
uint32_t searchNode(Node &node, Board &board, int ply, uint32_t newVisits)
{
assert(node.getHash() == board.zobristKey());

SearchThread *thisThread = board.thisThread();
SearchOptions &options = thisThread->options();
MCTSSearcher &searcher = static_cast<MCTSSearcher &>(*thisThread->threads.searcher());

// Return immediately if this node is unevaluated
if (node.getVisits() == 0)
// Discard visits in this node if it is unevaluated
uint32_t parentVisits = node.getVisits();
if (parentVisits == 0)
return 0;

if (Root)
thisThread->selDepth = 0;

// Cap new visits so that we dont do too much at one time
newVisits = std::min(newVisits, uint32_t(parentVisits * MaxNewVisitsProp) + 1);

// Return directly if this node is a terminal node and not at root
if (!Root && node.isTerminal()) {
node.incrementVisits(1);
return 1;
node.incrementVisits(newVisits);
return newVisits;
}

// Mark that we are now starting to visit this node
node.beginVisit();

// Make sure the parent node is expanded before we select a child
if (node.isLeaf()) {
bool noValidMove = expandNode<Root>(node, options, board, ply);

// If we found that there is no valid move, we mark this node as terminal
// node the finish this visit.
if (noValidMove) {
node.finishVisit(1);
return 1;
node.incrementVisits(newVisits);
return newVisits;
}
}

// Select the best edge to explore
auto [childEdge, childNode] = selectChild<Root>(node, *searcher.nodeTable, board);

// Make the move to reach the child node
Pos move = childEdge->getMove();
board.move(options.rule, move);
HashKey hash = board.zobristKey();

// Reaching a leaf node, expand it
bool allocatedNode = false;
if (!childNode) {
std::tie(childNode, allocatedNode) =
allocateOrFindNode(*searcher.nodeTable, hash, searcher.globalNodeAge);

// Remember this child node in the edge
childEdge->setChild(childNode);
}

uint32_t numNewVisits = 0;
// Evaluate the new child node if we are the one who really allocated the node
if (allocatedNode) {
evaluateNode(*childNode, options, board, ply + 1);
numNewVisits += 1;
}
// Continue to visit the selected child node
else {
// When transposition happens, we stop the playout if the child node has been
// visited more times than the parent node. Only continue the playout if the
// child node has been visited less times than the edge visits, or the absolute
// child node visits is less than the given threshold.
uint32_t childEdgeVisits = childEdge->getVisits();
uint32_t childNodeVisits = childNode->getVisits();
if (childEdgeVisits >= childNodeVisits || childNodeVisits < MinTranspositionSkipVisits)
numNewVisits += searchNode(*childNode, board, ply + 1);
else
numNewVisits += 1; // Increment edge visits without search the node
}

if (numNewVisits > 0) {
// Increment child edge visit count
childEdge->addVisits(numNewVisits);
bool stopThisPlayout = false;
uint32_t actualNewVisits = 0;
while (!stopThisPlayout && newVisits > 0) {
// Select the best edge to explore
auto [childEdge, childNode] = selectChild<Root>(node, board);

// Make the move to reach the child node
Pos move = childEdge->getMove();
board.move(options.rule, move);

// Reaching a leaf node, expand it
bool allocatedNode = false;
if (!childNode) {
HashKey hash = board.zobristKey();
std::tie(childNode, allocatedNode) =
allocateOrFindNode(*searcher.nodeTable, hash, searcher.globalNodeAge);

// Remember this child node in the edge
childEdge->setChild(childNode);
}

// Update the node's stats
node.updateStats();
}
// Evaluate the new child node if we are the one who really allocated the node
if (allocatedNode) {
// Mark that we are now starting to visit this node
node.beginVisit(1);
evaluateNode(*childNode, options, board, ply + 1);

// Increment child edge visit count
childEdge->addVisits(1);
node.updateStats();
node.finishVisit(1, 1);
actualNewVisits++;
newVisits--;
}
else {
// When transposition happens, we stop the playout if the child node has been
// visited more times than the parent node. Only continue the playout if the
// child node has been visited less times than the edge visits, or the absolute
// child node visits is less than the given threshold.
uint32_t childEdgeVisits = childEdge->getVisits();
uint32_t childNodeVisits = childNode->getVisits();
if (childEdgeVisits >= childNodeVisits
|| childNodeVisits < MinTranspositionSkipVisits) {
node.beginVisit(newVisits);
uint32_t actualChildNewVisits = searchNode(*childNode, board, ply + 1, newVisits);
assert(actualChildNewVisits <= newVisits);

if (actualChildNewVisits > 0) {
childEdge->addVisits(actualChildNewVisits);
node.updateStats();
actualNewVisits += actualChildNewVisits;
}
// Discard this playout if we can not make new visits to the best child,
// since some other thread is evaluating it
else
stopThisPlayout = true;

// Increment parent node visit
node.finishVisit(numNewVisits);
node.finishVisit(newVisits, actualChildNewVisits);
newVisits -= actualChildNewVisits;
}
else {
// Increment edge visits without search the node
childEdge->addVisits(1);
node.updateStats();
node.incrementVisits(1);
actualNewVisits++;
newVisits--;
}
}

// Undo the move
board.undo(options.rule);
// Undo the move
board.undo(options.rule);

// Record root move's seldepth
if constexpr (Root) {
auto rmIt = std::find(thisThread->rootMoves.begin(), thisThread->rootMoves.end(), move);
if (rmIt != thisThread->rootMoves.end()) {
RootMove &rm = *rmIt;
rm.selDepth = std::max(rm.selDepth, thisThread->selDepth);
// Record root move's seldepth
if constexpr (Root) {
auto rmIt = std::find(thisThread->rootMoves.begin(), thisThread->rootMoves.end(), move);
if (rmIt != thisThread->rootMoves.end()) {
RootMove &rm = *rmIt;
rm.selDepth = std::max(rm.selDepth, thisThread->selDepth);
}
}
}

return numNewVisits;
return actualNewVisits;
}

/// Select best move to play for the given node.
Expand Down Expand Up @@ -405,6 +420,7 @@ int selectBestmoveOfChildNode(const Node &node,
float childPolicy = childEdge.getP();
float selectionValue = 2.0f * childPolicy;

assert(childNode->getVisits() > 0);
uint32_t childVisits = childEdge.getVisits();
// Skip zero visits children
if (childVisits > 0) {
Expand Down Expand Up @@ -618,12 +634,25 @@ void MCTSSearcher::search(SearchThread &th)
// Main search loop
std::vector<Node *> selectedPath;
while (!th.threads.isTerminating()) {
uint32_t newNumNodes = searchNode<true>(*root, board, 0);
uint32_t newNumPlayouts = MaxNumVisitsPerPlayout;

// Cap new number of playouts to the maximum num nodes to visit
if (options.maxNodes) {
uint64_t nodesSearched = th.threads.nodesSearched();
if (nodesSearched >= options.maxNodes)
break;

uint64_t maxNodesToVisit = options.maxNodes - nodesSearched;
if (maxNodesToVisit < newNumPlayouts)
newNumPlayouts = maxNodesToVisit;
}

uint32_t newNumNodes = searchNode<true>(*root, board, 0, newNumPlayouts);
th.numNodes.fetch_add(newNumNodes, std::memory_order_relaxed);

if (th.isMainThread()) {
MainSearchThread &mainThread = static_cast<MainSearchThread &>(th);
mainThread.checkExit();
mainThread.checkExit(std::max(newNumNodes, 1u));

bool printRootMoves = false;
if (Config::NodesToPrintMCTSRootmoves > 0) {
Expand Down
6 changes: 4 additions & 2 deletions Rapfi/search/searchthread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,14 @@ void SearchThread::setBoardAndEvaluator(const Board &board)
this->board = std::make_unique<Board>(board, this);
}

void MainSearchThread::checkExit()
void MainSearchThread::checkExit(uint32_t elapsedCalls)
{
// We only check exit condition after a number of calls.
// This is to avoid expensive calculation in timeup condition checking.
if (callsCnt-- > 0)
if (callsCnt > elapsedCalls) {
callsCnt -= elapsedCalls;
return;
}

// Resets callsCnt
if (searchOptions.maxNodes)
Expand Down
2 changes: 1 addition & 1 deletion Rapfi/search/searchthread.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ struct MainSearchThread : public SearchThread
void search() override;
/// Check exit condition (time/nodes) and set ThreadPool's terminate flag.
/// @return True if we should stop the search now.
void checkExit();
void checkExit(uint32_t elapsedCalls = 1);
/// Mark pondering available for the last finished searching.
void markPonderingAvailable();
/// Start all non-main search threads. This function should be called in
Expand Down

0 comments on commit af7013f

Please sign in to comment.