Skip to content

Commit

Permalink
Use central finite difference for gradient estimation (#58)
Browse files Browse the repository at this point in the history
Bench: 11035603
  • Loading branch information
Dannyj1 authored Nov 28, 2023
1 parent 8559233 commit 6053725
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 17 deletions.
2 changes: 1 addition & 1 deletion src/features.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <iostream>

namespace Zagreus {
int evalValues[70] = { 95, 105, 358, 349, 377, 356, 535, 529, 1011, 1004, 7, 3, 5, 4, 0, 7, 0, 18, 18, 1, -7, -1, 0, 7, 4, -2, -9, -3, -9, -1, -11, -22, -8, -14, 3, 20, 1, -3, -7, -8, -15, -11, -5, -4, 7, 0, -5, -6, -3, -6, -35, -32, 22, 12, 7, 6, 25, 21, 14, 8, 1, 13, 3, -2, 1, 18, 5, -5, 8, 6, };
int evalValues[70] = { 105, 103, 371, 354, 398, 370, 543, 548, 1032, 1021, 9, 5, 6, 4, 0, 7, 0, 22, 21, -1, -6, -1, -12, 9, 6, -5, -12, 0, -20, 1, -13, -30, -13, -6, 0, 23, -5, -10, -10, -9, -21, -9, -5, -9, 7, -7, -5, -9, -5, -17, -31, -17, 22, 5, 8, 7, 36, 22, 19, 0, -5, 0, 11, 4, 6, 16, 0, -6, 11, 0, };

int baseEvalValues[70] = {
100, // MIDGAME_PAWN_MATERIAL
Expand Down
25 changes: 12 additions & 13 deletions src/pst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,23 @@
#include <vector>

namespace Zagreus {
int midgamePawnTable[64] = { 6, 6, 6, 6, 6, 6, 6, 6, 103, 135, 66, 101, 76, 129, 42, -2, 2, 6, 36, 36, 64, 61, 26, -13, -8, 8, 5, 24, 25, 18, 24, -11, -41, -9, -4, 17, 20, 10, 5, -25, -25, -3, -7, -2, 0, 2, 10, -16, -23, 8, -6, -16, -14, 5, 27, -13, 6, 6, 6, 6, 6, 6, 6, 6 };
int endgamePawnTable[64] = { 6, 6, 6, 6, 6, 6, 6, 6, 184, 177, 162, 141, 153, 135, 172, 194, 99, 104, 93, 75, 57, 56, 86, 89, 37, 27, 14, -5, -2, 13, 18, 25, 7, 8, -1, -6, -4, -7, 1, 1, 2, 11, 1, 10, 5, 0, -1, -11, 17, 2, 13, 17, 21, 0, 0, 0, 6, 6, 6, 6, 6, 6, 6, 6 };
int midgamePawnTable[64] = { 0, 0, 0, 0, 0, 0, 0, 0, 89, 115, 57, 97, 72, 113, 37, -2, -6, 7, 39, 28, 42, 57, 20, -13, -12, 5, 0, 25, 27, 15, 14, -10, -41, -14, -3, 21, 23, 9, 0, -27, -19, -6, -4, 2, 0, 0, 6, -17, -22, 13, -7, -11, -12, 4, 25, -13, 0, 0, 0, 0, 0, 0, 0, 0 };
int endgamePawnTable[64] = { 0, 0, 0, 0, 0, 0, 0, 0, 174, 159, 150, 131, 147, 121, 169, 187, 94, 97, 86, 67, 42, 59, 80, 77, 33, 26, 22, 1, -3, 5, 7, 26, 14, 10, -7, -10, -12, -17, -3, -1, 4, 10, 0, 12, 13, -2, -10, 0, 11, 5, 17, 23, 18, 4, -8, 0, 0, 0, 0, 0, 0, 0, 0, 0 };

int midgameKnightTable[64] = { -162, -82, -26, -42, 66, -89, -8, -102, -66, -32, 78, 43, 30, 73, 13, -13, -37, 61, 43, 68, 87, 126, 76, 49, -3, 16, 22, 47, 38, 68, 17, 27, -9, 11, 16, 21, 30, 26, 27, -5, -20, -3, 14, 15, 22, 12, 19, -12, -20, -42, -4, 6, 10, 23, -6, -10, -99, -21, -49, -27, -10, -22, -14, -17 };
int endgameKnightTable[64] = { -51, -31, -6, -21, -25, -20, -56, -92, -20, 0, -17, 3, -2, -15, -17, -45, -15, -15, 18, 14, 5, -5, -12, -33, -10, 10, 28, 27, 30, 18, 12, -12, -13, 0, 17, 33, 15, 21, 11, -13, -16, 1, 0, 18, 18, -8, -15, -16, -35, -12, -2, 0, 1, -14, -18, -38, -22, -40, -13, -9, -15, -11, -46, -57 };
int midgameKnightTable[64] = { -167, -88, -41, -45, 58, -90, -14, -110, -65, -33, 73, 43, 32, 73, 0, -21, -38, 53, 35, 64, 78, 110, 61, 39, -2, 15, 16, 34, 38, 38, 15, 23, -6, 7, 18, 19, 23, 18, 18, -6, -27, 2, 11, 12, 17, 10, 9, -18, -15, -41, -10, 16, 16, 8, -5, -8, -100, -15, -46, -28, -19, -21, -13, -26 };
int endgameKnightTable[64] = { -55, -41, -17, -26, -33, -23, -63, -98, -26, -5, -22, -6, -5, -12, -19, -50, -19, -23, 13, 9, -4, -16, -21, -41, -12, 13, 25, 21, 20, 5, 11, -14, -9, -5, 15, 26, 18, 16, 0, -12, -21, 0, -17, 16, 10, -14, -30, -16, -39, -15, -6, -4, -9, -25, -23, -39, -27, -35, -13, -11, -12, -12, -42, -65 };

int midgameBishopTable[64] = { -22, 10, -73, -29, -19, -36, 12, 0, -17, 11, -14, -5, 35, 61, 21, -39, 0, 40, 45, 44, 37, 57, 41, 6, 0, 0, 17, 45, 38, 34, 8, 0, 2, 17, 6, 24, 35, 9, 14, 5, 10, 22, 20, 16, 12, 32, 18, 23, 9, 29, 20, 5, 11, 24, 45, 4, -26, 2, -6, -14, -4, 0, -30, -11 };
int endgameBishopTable[64] = { -7, -14, -3, -1, -1, -3, -11, -16, -3, -6, 11, -4, 2, -8, 0, -6, 12, -3, 5, 5, 3, 10, 3, 9, -1, 10, 14, 12, 19, 10, 7, 8, 0, 7, 14, 19, 12, 13, 1, -4, -4, 1, 14, 15, 11, 9, 0, -7, -7, -8, 0, 6, 9, -4, -6, -22, -16, -4, -20, 2, 0, -6, 2, -8 };
int midgameBishopTable[64] = { -26, 1, -72, -35, -29, -43, 0, -5, -18, 5, -22, -7, 22, 46, 10, -24, -5, 33, 39, 27, 27, 48, 26, 6, -6, 0, 9, 26, 25, 14, 3, -12, -4, 3, 0, 24, 28, 3, 7, -3, 9, 20, 18, 8, 17, 30, 12, 20, 11, 37, 17, 13, 16, 23, 47, -1, -20, -4, 0, -16, -3, 0, -23, -12 };
int endgameBishopTable[64] = { -10, -21, -4, -3, -10, -9, -22, -21, -3, -8, 7, -8, -9, -27, -9, -11, 8, -12, -1, -9, 0, 5, -8, 3, -5, 4, 3, 3, 12, 0, 0, -1, -10, 1, 8, 16, 3, 0, -9, -11, -5, -2, 15, 22, 7, 10, -4, -13, -12, -5, 0, 8, 8, -11, -7, -25, -18, -10, -7, 2, -8, -9, -3, -12 };

int midgameRookTable[64] = { 38, 48, 37, 58, 70, 16, 38, 49, 28, 32, 59, 64, 83, 70, 29, 46, 3, 24, 32, 42, 24, 51, 66, 20, -18, -7, 12, 32, 26, 38, 0, -12, -29, -21, -11, 0, 11, -1, 12, -20, -36, -20, -11, -11, 5, 5, 3, -26, -41, -8, -11, -2, 6, 7, -2, -64, -6, -13, 0, 9, 10, 2, -18, -13 };
int endgameRookTable[64] = { 19, 16, 23, 21, 18, 18, 16, 11, 14, 19, 15, 16, 3, 7, 14, 8, 13, 13, 13, 12, 10, 3, 1, 3, 9, 8, 18, 9, 5, 5, 4, 9, 10, 10, 10, 5, 0, 0, 0, -5, 2, 6, 0, 5, -2, -7, -3, -8, 0, 0, 4, 7, -2, -8, -5, 2, -3, 1, 3, -4, -6, -9, 14, -11 };
int midgameRookTable[64] = { 31, 41, 27, 53, 62, 11, 30, 40, 16, 12, 57, 49, 68, 57, 21, 38, 0, 15, 22, 37, 31, 44, 52, 13, -15, -8, 6, 26, 15, 30, -2, -14, -33, -24, -22, -10, 5, -12, 4, -28, -42, -23, -14, -12, 1, -1, -8, -27, -46, -13, -22, -10, -1, -5, -2, -52, -4, -11, -1, 0, 9, 4, -12, -9 };
int endgameRookTable[64] = { 13, 11, 13, 17, 9, 16, 8, 2, 4, 10, 7, 6, -10, 2, 15, 0, 10, 3, 8, 5, 12, -2, -4, -1, 2, 6, 12, 1, -6, 0, -1, 7, 0, 8, 5, 2, -7, -10, -10, -13, -6, 0, -7, 5, -8, -13, -11, -11, -7, -5, 0, 0, -8, -10, -4, 7, 0, 2, 0, -6, -9, -1, 6, -9 };

int midgameQueenTable[64] = { -21, 5, 36, 20, 65, 50, 48, 48, -21, -46, -1, 2, -10, 62, 22, 55, -5, -10, 14, 11, 39, 56, 52, 55, -15, -20, -12, -11, -2, 22, 0, 2, -10, -18, -9, -10, 0, 4, 3, 1, -11, 0, -5, -4, -5, 5, 17, 4, -29, 0, 8, 11, 14, 16, 7, 6, 4, -16, -3, 0, 7, -18, -19, -35 };
int endgameQueenTable[64] = { -2, 27, 28, 35, 33, 25, 16, 24, -10, 26, 38, 45, 66, 30, 33, 5, -14, 12, 15, 56, 56, 39, 25, 12, 11, 30, 30, 50, 60, 45, 62, 38, -11, 36, 23, 49, 35, 43, 41, 30, -10, -26, 18, 9, 14, 24, 16, 7, -16, -16, -24, -11, -5, -22, -27, -26, -27, -21, -15, -38, 0, -25, -12, -32 };

int midgameKingTable[64] = { -58, 29, 22, -8, -49, -27, 8, 19, 35, 5, -13, 0, -1, 2, -31, -22, -2, 30, 8, -9, -13, 12, 28, -15, -10, -13, -6, -20, -23, -18, -7, -29, -42, 4, -21, -32, -39, -36, -26, -44, -7, -7, -14, -41, -41, -23, -11, -19, 9, 15, -4, -60, -45, -16, 12, 13, 0, 23, 0, -16, 0, -31, 14, 10 };
int endgameKingTable[64] = { -67, -28, -11, -11, -4, 21, 10, -10, -5, 23, 20, 23, 23, 44, 29, 17, 16, 23, 29, 21, 26, 51, 50, 19, -1, 28, 30, 33, 32, 39, 32, 9, -11, 2, 26, 30, 33, 30, 15, -4, -13, 1, 19, 27, 29, 23, 11, 0, -22, -5, 9, 25, 21, 10, 1, -11, -43, -40, -21, 5, -32, -10, -25, -44 };
int midgameQueenTable[64] = { -26, -1, 31, 14, 55, 43, 43, 36, -15, -48, -6, -9, -24, 42, 1, 36, -5, -5, 5, 4, 28, 41, 41, 36, -13, -19, -18, -18, -11, 6, -11, 0, -16, -21, -17, -18, -17, -6, 2, -3, -17, 4, -3, -10, -4, -7, 6, 2, -24, 0, 8, 10, 12, 12, 13, 4, -4, -18, -12, 7, 7, -12, -21, -25 };
int endgameQueenTable[64] = { -6, 21, 22, 26, 23, 18, 7, 15, -9, 20, 33, 35, 49, 18, 19, -4, -17, 10, 14, 52, 49, 25, 17, 0, 6, 29, 23, 52, 51, 36, 56, 32, -4, 32, 18, 44, 29, 34, 38, 18, -15, -27, 18, 4, 3, 17, 5, 0, -19, -20, -29, -2, 0, -25, -27, -32, -34, -28, -25, -30, 0, -27, -16, -26 };

int midgameKingTable[64] = { -65, 22, 16, -15, -56, -34, 2, 13, 28, -1, -20, -6, -7, -3, -37, -29, -8, 24, 1, -15, -19, 5, 21, -22, -17, -21, -12, -27, -30, -25, -12, -36, -48, 0, -28, -37, -47, -46, -31, -52, -15, -15, -20, -53, -45, -24, -9, -26, 15, 16, -21, -70, -59, -24, 10, 9, 4, 20, -24, -9, -9, -37, 10, 0 };
int endgameKingTable[64] = { -74, -35, -18, -18, -11, 15, 4, -17, -12, 16, 13, 17, 17, 38, 23, 10, 9, 17, 23, 15, 20, 44, 43, 13, -8, 20, 24, 27, 25, 31, 28, 3, -16, -3, 19, 25, 23, 25, 12, -13, -16, 0, 17, 20, 24, 25, 15, -8, -25, -15, 8, 28, 24, 13, -5, -28, -53, -39, -25, -11, -32, -5, -28, -56 };

// Base tables from https://www.chessprogramming.org/PeSTO%27s_Evaluation_Function
int baseMidgamePawnTable[64] = {
Expand Down
5 changes: 2 additions & 3 deletions src/tuner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,13 +348,12 @@ namespace Zagreus {
updateEvalValues(bestParameters);
double lossPlus = evaluationLoss(position, 1, maxEndTime, engine);

/*bestParameters[paramIndex] = oldParam - delta;
bestParameters[paramIndex] = oldParam - delta;
updateEvalValues(bestParameters);
double lossMinus = evaluationLoss(position, 1, maxEndTime, engine);

gradients[paramIndex] += (lossPlus - lossMinus) / (2 * delta);*/
gradients[paramIndex] += (lossPlus - lossMinus) / (2 * delta);

gradients[paramIndex] += (lossPlus - loss) / delta;
// reset
bestParameters[paramIndex] = oldParam;
}
Expand Down

0 comments on commit 6053725

Please sign in to comment.