Skip to content

Commit

Permalink
backprop with softmax weights (invT=0)
Browse files Browse the repository at this point in the history
test f15
  • Loading branch information
dhbloo committed Oct 15, 2024
1 parent 58ed3f3 commit 49d00f6
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions Rapfi/search/mcts/node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ void Node::updateStats()
if (!edgeArray)
return;

uint32_t nSum = 1;
constexpr float invTemp = 0.0f;

float wSum = 1 * std::exp((utility - 1.0f) * invTemp);
float qSum = utility;
float qSqrSum = utility * utility;
float dSum = drawRate;
Expand All @@ -168,14 +170,19 @@ void Node::updateStats()
float childQ = childNode->q.load(std::memory_order_relaxed);
float childQSqr = childNode->qSqr.load(std::memory_order_relaxed);
float childD = childNode->d.load(std::memory_order_relaxed);
nSum += childN;
qSum += childN * (-childQ); // Flip side for child's utility
qSqrSum += childN * childQSqr;
dSum += childN * childD;

// Compute the weight of this child node using softmax
// We minus childQ by the maximum Q value to avoid overflow.
float childW = childN * std::exp((childQ - 1.0f) * invTemp);

wSum += childW;
qSum += childW * (-childQ); // Flip side for child's utility
qSqrSum += childW * childQSqr;
dSum += childW * childD;
maxBound |= childNode->bound.load(std::memory_order_relaxed);
}

float norm = 1.0f / nSum;
float norm = 1.0f / wSum;
q.store(qSum * norm, std::memory_order_relaxed);
qSqr.store(qSqrSum * norm, std::memory_order_relaxed);
d.store(dSum * norm, std::memory_order_relaxed);
Expand Down

0 comments on commit 49d00f6

Please sign in to comment.