Skip to content

Commit

Permalink
Transform search output to engine callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
Disservin committed Apr 3, 2024
1 parent 63886c6 commit cf62ca0
Show file tree
Hide file tree
Showing 12 changed files with 365 additions and 105 deletions.
4 changes: 2 additions & 2 deletions src/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,15 @@ PGOBENCH = $(WINE_PATH) ./$(EXE) bench
SRCS = benchmark.cpp bitboard.cpp evaluate.cpp main.cpp \
misc.cpp movegen.cpp movepick.cpp position.cpp \
search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \
nnue/nnue_misc.cpp nnue/features/half_ka_v2_hm.cpp nnue/network.cpp engine.cpp
nnue/nnue_misc.cpp nnue/features/half_ka_v2_hm.cpp nnue/network.cpp engine.cpp score.cpp

HEADERS = benchmark.h bitboard.h evaluate.h misc.h movegen.h movepick.h \
nnue/nnue_misc.h nnue/features/half_ka_v2_hm.h nnue/layers/affine_transform.h \
nnue/layers/affine_transform_sparse_input.h nnue/layers/clipped_relu.h nnue/layers/simd.h \
nnue/layers/sqr_clipped_relu.h nnue/nnue_accumulator.h nnue/nnue_architecture.h \
nnue/nnue_common.h nnue/nnue_feature_transformer.h position.h \
search.h syzygy/tbprobe.h thread.h thread_win32_osx.h timeman.h \
tt.h tune.h types.h uci.h ucioption.h perft.h nnue/network.h engine.h
tt.h tune.h types.h uci.h ucioption.h perft.h nnue/network.h engine.h score.h

OBJS = $(notdir $(SRCS:.cpp=.o))

Expand Down
41 changes: 24 additions & 17 deletions src/engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,22 @@

#include "engine.h"

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <deque>
#include <memory>
#include <optional>
#include <sstream>
#include <ostream>
#include <utility>
#include <vector>

#include "benchmark.h"
#include "evaluate.h"
#include "movegen.h"
#include "misc.h"
#include "nnue/network.h"
#include "nnue/nnue_common.h"
#include "perft.h"
#include "position.h"
#include "search.h"
#include "syzygy/tbprobe.h"
#include "types.h"
#include "uci.h"
#include "ucioption.h"

namespace Stockfish {
Expand All @@ -54,7 +48,6 @@ Engine::Engine(std::string path) :
networks(NN::Networks(
NN::NetworkBig({EvalFileDefaultNameBig, "None", ""}, NN::EmbeddedNNUEType::BIG),
NN::NetworkSmall({EvalFileDefaultNameSmall, "None", ""}, NN::EmbeddedNNUEType::SMALL))) {
Tune::init(options);
pos.set(StartFEN, false, &states->back());
}

Expand All @@ -77,10 +70,26 @@ void Engine::search_clear() {
tt.clear(options["Threads"]);
threads.clear();

// @TODO wont work multiple instances
// @TODO wont work with multiple instances
Tablebases::init(options["SyzygyPath"]); // Free mapped files
}

void Engine::set_on_update_no_moves(std::function<void(const Engine::InfoShort&)>&& f) {
updateContext.onUpdateNoMoves = std::move(f);
}

void Engine::set_on_update_full(std::function<void(const Engine::InfoFull&)>&& f) {
updateContext.onUpdateFull = std::move(f);
}

void Engine::set_on_iter(std::function<void(const Engine::InfoIter&)>&& f) {
updateContext.onIter = std::move(f);
}

void Engine::set_on_bestmove(std::function<void(const std::string&, const std::string&)>&& f) {
updateContext.onBestmove = std::move(f);
}

void Engine::wait_for_search_finished() { threads.main_thread()->wait_for_search_finished(); }

void Engine::set_position(const std::string& fen, const std::vector<std::string>& moves) {
Expand All @@ -102,7 +111,7 @@ void Engine::set_position(const std::string& fen, const std::vector<std::string>

// modifiers

void Engine::resize_threads() { threads.set({options, threads, tt, networks}); }
void Engine::resize_threads() { threads.set({options, threads, tt, networks}, updateContext); }

void Engine::set_tt_size(size_t mb) {
wait_for_search_finished();
Expand All @@ -113,7 +122,7 @@ void Engine::set_ponderhit(bool b) { threads.main_manager()->ponder = b; }

// network related

void Engine::verify_networks() {
void Engine::verify_networks() const {
networks.big.verify(options["EvalFile"]);
networks.small.verify(options["EvalFileSmall"]);
}
Expand All @@ -138,9 +147,7 @@ void Engine::save_network(const std::pair<std::optional<std::string>, std::strin

OptionsMap& Engine::get_options() { return options; }

uint64_t Engine::nodes_searched() const { return threads.nodes_searched(); }

void Engine::trace_eval() {
void Engine::trace_eval() const {
StateListPtr trace_states(new std::deque<StateInfo>(1));
Position p;
p.set(pos.fen(), options["UCI_Chess960"], &trace_states->back());
Expand Down
25 changes: 20 additions & 5 deletions src/engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,13 @@
#ifndef ENGINE_H_INCLUDED
#define ENGINE_H_INCLUDED

#include "misc.h"
#include <cstddef>
#include <functional>
#include <string>
#include <vector>
#include <optional>
#include <utility>

#include "nnue/network.h"
#include "position.h"
#include "search.h"
Expand All @@ -31,6 +37,10 @@ namespace Stockfish {

class Engine {
public:
using InfoShort = Search::InfoShort;
using InfoFull = Search::InfoFull;
using InfoIter = Search::InfoIteration;

Engine(std::string path = "");
~Engine() { wait_for_search_finished(); }

Expand All @@ -52,19 +62,22 @@ class Engine {
// clears the search
void search_clear();

void set_on_update_no_moves(std::function<void(const InfoShort&)>&&);
void set_on_update_full(std::function<void(const InfoFull&)>&&);
void set_on_iter(std::function<void(const InfoIter&)>&&);
void set_on_bestmove(std::function<void(const std::string&, const std::string&)>&&);

// network related

void verify_networks();
void verify_networks() const;
void load_networks();
void load_big_network(const std::string& file);
void load_small_network(const std::string& file);
void save_network(const std::pair<std::optional<std::string>, std::string> files[2]);

// utility functions

void trace_eval();
// nodes since last search clear
uint64_t nodes_searched() const;
void trace_eval() const;
OptionsMap& get_options();

private:
Expand All @@ -77,6 +90,8 @@ class Engine {
ThreadPool threads;
TranspositionTable tt;
Eval::NNUE::Networks networks;

Search::SearchManager::UpdateContext updateContext;
};

} // namespace Stockfish
Expand Down
5 changes: 4 additions & 1 deletion src/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
#include "bitboard.h"
#include "misc.h"
#include "position.h"
#include "tune.h"
#include "types.h"
#include "uci.h"
#include "tune.h"

using namespace Stockfish;

Expand All @@ -35,6 +35,9 @@ int main(int argc, char* argv[]) {
Position::init();

UCIEngine uci(argc, argv);

Tune::init(uci.engine_options());

uci.loop();

return 0;
Expand Down
48 changes: 48 additions & 0 deletions src/score.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
Stockfish, a UCI chess playing engine derived from Glaurung 2.1
Copyright (C) 2004-2024 The Stockfish developers (see AUTHORS file)
Stockfish is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Stockfish is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#include "score.h"

#include <cassert>
#include <cmath>
#include <cstdlib>

#include "uci.h"

namespace Stockfish {

Score::Score(Value v, const Position& pos) {
assert(-VALUE_INFINITE < v && v < VALUE_INFINITE);

if (std::abs(v) < VALUE_TB_WIN_IN_MAX_PLY)
{
score = InternalUnits{UCIEngine::to_cp(v, pos)};
}
else if (std::abs(v) <= VALUE_TB)
{
auto distance = VALUE_TB - std::abs(v);
score = (v > 0) ? TBWin{distance} : TBWin{-distance};
}
else
{
auto distance = VALUE_MATE - std::abs(v);
score = (v > 0) ? Mate{distance} : Mate{-distance};
}
}

}
69 changes: 69 additions & 0 deletions src/score.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
Stockfish, a UCI chess playing engine derived from Glaurung 2.1
Copyright (C) 2004-2024 The Stockfish developers (see AUTHORS file)
Stockfish is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
Stockfish is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef SCORE_H_INCLUDED
#define SCORE_H_INCLUDED

#include <variant>
#include <utility>

#include "types.h"

namespace Stockfish {

class Position;

class Score {
public:
struct Mate {
int plies;
};

struct TBWin {
int plies;
};

struct InternalUnits {
int value;
};

Score() = default;
Score(Value v, const Position& pos);

template<typename T>
bool is() const {
return std::holds_alternative<T>(score);
}

template<typename T>
T get() const {
return std::get<T>(score);
}

template<typename F>
decltype(auto) visit(F&& f) const {
return std::visit(std::forward<F>(f), score);
}

private:
std::variant<Mate, TBWin, InternalUnits> score;
};

}

#endif // #ifndef SCORE_H_INCLUDED
Loading

0 comments on commit cf62ca0

Please sign in to comment.