diff --git a/src/Makefile b/src/Makefile index 672171bcd59..6315bda82df 100644 --- a/src/Makefile +++ b/src/Makefile @@ -55,7 +55,7 @@ PGOBENCH = $(WINE_PATH) ./$(EXE) bench SRCS = benchmark.cpp bitboard.cpp evaluate.cpp main.cpp \ misc.cpp movegen.cpp movepick.cpp position.cpp \ search.cpp thread.cpp timeman.cpp tt.cpp uci.cpp ucioption.cpp tune.cpp syzygy/tbprobe.cpp \ - nnue/nnue_misc.cpp nnue/features/half_ka_v2_hm.cpp nnue/network.cpp + nnue/nnue_misc.cpp nnue/features/half_ka_v2_hm.cpp nnue/network.cpp engine.cpp HEADERS = benchmark.h bitboard.h evaluate.h misc.h movegen.h movepick.h \ nnue/nnue_misc.h nnue/features/half_ka_v2_hm.h nnue/layers/affine_transform.h \ @@ -63,7 +63,7 @@ HEADERS = benchmark.h bitboard.h evaluate.h misc.h movegen.h movepick.h \ nnue/layers/sqr_clipped_relu.h nnue/nnue_accumulator.h nnue/nnue_architecture.h \ nnue/nnue_common.h nnue/nnue_feature_transformer.h position.h \ search.h syzygy/tbprobe.h thread.h thread_win32_osx.h timeman.h \ - tt.h tune.h types.h uci.h ucioption.h perft.h nnue/network.h + tt.h tune.h types.h uci.h ucioption.h perft.h nnue/network.h engine.h OBJS = $(notdir $(SRCS:.cpp=.o)) diff --git a/src/engine.cpp b/src/engine.cpp new file mode 100644 index 00000000000..79a2c604742 --- /dev/null +++ b/src/engine.cpp @@ -0,0 +1,153 @@ +/* + Stockfish, a UCI chess playing engine derived from Glaurung 2.1 + Copyright (C) 2004-2024 The Stockfish developers (see AUTHORS file) + + Stockfish is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Stockfish is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +*/ + +#include "engine.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "benchmark.h" +#include "evaluate.h" +#include "movegen.h" +#include "nnue/network.h" +#include "nnue/nnue_common.h" +#include "perft.h" +#include "position.h" +#include "search.h" +#include "syzygy/tbprobe.h" +#include "types.h" +#include "ucioption.h" + +namespace Stockfish { + +namespace NN = Eval::NNUE; + +constexpr auto StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"; + +Engine::Engine(std::string path) : + binaryDirectory(CommandLine::get_binary_directory(path)), + states(new std::deque(1)), + networks(NN::Networks( + NN::NetworkBig({EvalFileDefaultNameBig, "None", ""}, NN::EmbeddedNNUEType::BIG), + NN::NetworkSmall({EvalFileDefaultNameSmall, "None", ""}, NN::EmbeddedNNUEType::SMALL))) { + Tune::init(options); + pos.set(StartFEN, false, &states->back()); +} + +void Engine::go(const Search::LimitsType& limits) { + verify_networks(); + + if (limits.perft) + { + perft(pos.fen(), limits.perft, options["UCI_Chess960"]); + return; + } + + threads.start_thinking(options, pos, states, limits); +} +void Engine::stop() { threads.stop = true; } + +void Engine::search_clear() { + wait_for_search_finished(); + + tt.clear(options["Threads"]); + threads.clear(); + + // @TODO wont work multiple instances + Tablebases::init(options["SyzygyPath"]); // Free mapped files +} + +void Engine::wait_for_search_finished() { threads.main_thread()->wait_for_search_finished(); } + +void Engine::set_position(const std::string& fen, const std::vector& moves) { + // Drop the old state and create a new one + states = StateListPtr(new std::deque(1)); + pos.set(fen, options["UCI_Chess960"], &states->back()); + + for (const auto& move : moves) + { + auto m = UCIEngine::to_move(pos, move); + + if (m == Move::none()) + break; + + states->emplace_back(); + pos.do_move(m, states->back()); + } +} + +// modifiers + +void Engine::resize_threads() { threads.set({options, threads, tt, networks}); } + +void Engine::set_tt_size(size_t mb) { + wait_for_search_finished(); + tt.resize(mb, options["Threads"]); +} + +void Engine::set_ponderhit(bool b) { threads.main_manager()->ponder = b; } + +// network related + +void Engine::verify_networks() { + networks.big.verify(options["EvalFile"]); + networks.small.verify(options["EvalFileSmall"]); +} + +void Engine::load_networks() { + networks.big.load(binaryDirectory, options["EvalFile"]); + networks.small.load(binaryDirectory, options["EvalFileSmall"]); +} + +void Engine::load_big_network(const std::string& file) { networks.big.load(binaryDirectory, file); } + +void Engine::load_small_network(const std::string& file) { + networks.small.load(binaryDirectory, file); +} + +void Engine::save_network(const std::pair, std::string> files[2]) { + networks.big.save(files[0].first); + networks.small.save(files[1].first); +} + +// utility functions + +OptionsMap& Engine::get_options() { return options; } + +uint64_t Engine::nodes_searched() const { return threads.nodes_searched(); } + +void Engine::trace_eval() { + StateListPtr trace_states(new std::deque(1)); + Position p; + p.set(pos.fen(), options["UCI_Chess960"], &trace_states->back()); + + verify_networks(); + + sync_cout << "\n" << Eval::trace(p, networks) << sync_endl; +} + +} \ No newline at end of file diff --git a/src/engine.h b/src/engine.h new file mode 100644 index 00000000000..166c6cf1607 --- /dev/null +++ b/src/engine.h @@ -0,0 +1,85 @@ +/* + Stockfish, a UCI chess playing engine derived from Glaurung 2.1 + Copyright (C) 2004-2024 The Stockfish developers (see AUTHORS file) + + Stockfish is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + Stockfish is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . +*/ + +#ifndef ENGINE_H_INCLUDED +#define ENGINE_H_INCLUDED + +#include "misc.h" +#include "nnue/network.h" +#include "position.h" +#include "search.h" +#include "thread.h" +#include "tt.h" +#include "ucioption.h" + +namespace Stockfish { + +class Engine { + public: + Engine(std::string path = ""); + ~Engine() { wait_for_search_finished(); } + + // non blocking call to start searching + void go(const Search::LimitsType&); + // non blocking call to stop searching + void stop(); + + // blocking call to wait for search to finish + void wait_for_search_finished(); + // set a new position + void set_position(const std::string& fen, const std::vector& moves); + + // modifiers + + void resize_threads(); + void set_tt_size(size_t mb); + void set_ponderhit(bool); + // clears the search + void search_clear(); + + // network related + + void verify_networks(); + void load_networks(); + void load_big_network(const std::string& file); + void load_small_network(const std::string& file); + void save_network(const std::pair, std::string> files[2]); + + // utility functions + + void trace_eval(); + // nodes since last search clear + uint64_t nodes_searched() const; + OptionsMap& get_options(); + + private: + const std::string binaryDirectory; + + Position pos; + StateListPtr states; + + OptionsMap options; + ThreadPool threads; + TranspositionTable tt; + Eval::NNUE::Networks networks; +}; + +} // namespace Stockfish + + +#endif // #ifndef ENGINE_H_INCLUDED \ No newline at end of file diff --git a/src/evaluate.cpp b/src/evaluate.cpp index bc705b857df..dcbfedb499a 100644 --- a/src/evaluate.cpp +++ b/src/evaluate.cpp @@ -105,11 +105,11 @@ std::string Eval::trace(Position& pos, const Eval::NNUE::Networks& networks) { Value v = networks.big.evaluate(pos, false); v = pos.side_to_move() == WHITE ? v : -v; - ss << "NNUE evaluation " << 0.01 * UCI::to_cp(v, pos) << " (white side)\n"; + ss << "NNUE evaluation " << 0.01 * UCIEngine::to_cp(v, pos) << " (white side)\n"; v = evaluate(networks, pos, VALUE_ZERO); v = pos.side_to_move() == WHITE ? v : -v; - ss << "Final evaluation " << 0.01 * UCI::to_cp(v, pos) << " (white side)"; + ss << "Final evaluation " << 0.01 * UCIEngine::to_cp(v, pos) << " (white side)"; ss << " [with scaled NNUE, ...]"; ss << "\n"; diff --git a/src/main.cpp b/src/main.cpp index 33d5d375fca..4e72c00398a 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -34,10 +34,7 @@ int main(int argc, char* argv[]) { Bitboards::init(); Position::init(); - UCI uci(argc, argv); - - Tune::init(uci.options); - + UCIEngine uci(argc, argv); uci.loop(); return 0; diff --git a/src/misc.cpp b/src/misc.cpp index 270d25ad4bc..1abb81b14c2 100644 --- a/src/misc.cpp +++ b/src/misc.cpp @@ -723,13 +723,9 @@ void bind_this_thread(size_t idx) { #define GETCWD getcwd #endif -CommandLine::CommandLine(int _argc, char** _argv) : - argc(_argc), - argv(_argv) { - std::string pathSeparator; - // Extract the path+name of the executable binary - std::string argv0 = argv[0]; +std::string CommandLine::get_binary_directory(std::string argv0) { + std::string pathSeparator; #ifdef _WIN32 pathSeparator = "\\"; @@ -745,15 +741,11 @@ CommandLine::CommandLine(int _argc, char** _argv) : #endif // Extract the working directory - workingDirectory = ""; - char buff[40000]; - char* cwd = GETCWD(buff, 40000); - if (cwd) - workingDirectory = cwd; + auto workingDirectory = CommandLine::get_working_directory(); // Extract the binary directory path from argv0 - binaryDirectory = argv0; - size_t pos = binaryDirectory.find_last_of("\\/"); + auto binaryDirectory = argv0; + size_t pos = binaryDirectory.find_last_of("\\/"); if (pos == std::string::npos) binaryDirectory = "." + pathSeparator; else @@ -762,6 +754,19 @@ CommandLine::CommandLine(int _argc, char** _argv) : // Pattern replacement: "./" at the start of path is replaced by the working directory if (binaryDirectory.find("." + pathSeparator) == 0) binaryDirectory.replace(0, 1, workingDirectory); + + return binaryDirectory; } +std::string CommandLine::get_working_directory() { + std::string workingDirectory = ""; + char buff[40000]; + char* cwd = GETCWD(buff, 40000); + if (cwd) + workingDirectory = cwd; + + return workingDirectory; +} + + } // namespace Stockfish diff --git a/src/misc.h b/src/misc.h index de34ee111f7..d75b236ff71 100644 --- a/src/misc.h +++ b/src/misc.h @@ -206,13 +206,15 @@ void bind_this_thread(size_t idx); struct CommandLine { public: - CommandLine(int, char**); + CommandLine(int _argc, char** _argv) : + argc(_argc), + argv(_argv) {} + + static std::string get_binary_directory(std::string argv0); + static std::string get_working_directory(); int argc; char** argv; - - std::string binaryDirectory; // path of the executable directory - std::string workingDirectory; // path of the working directory }; namespace Utility { diff --git a/src/nnue/nnue_misc.cpp b/src/nnue/nnue_misc.cpp index 725d90d27d6..3fa6e1b6180 100644 --- a/src/nnue/nnue_misc.cpp +++ b/src/nnue/nnue_misc.cpp @@ -58,7 +58,7 @@ void format_cp_compact(Value v, char* buffer, const Position& pos) { buffer[0] = (v < 0 ? '-' : v > 0 ? '+' : ' '); - int cp = std::abs(UCI::to_cp(v, pos)); + int cp = std::abs(UCIEngine::to_cp(v, pos)); if (cp >= 10000) { buffer[1] = '0' + cp / 10000; @@ -92,7 +92,7 @@ void format_cp_compact(Value v, char* buffer, const Position& pos) { // Converts a Value into pawns, always keeping two decimals void format_cp_aligned_dot(Value v, std::stringstream& stream, const Position& pos) { - const double pawns = std::abs(0.01 * UCI::to_cp(v, pos)); + const double pawns = std::abs(0.01 * UCIEngine::to_cp(v, pos)); stream << (v < 0 ? '-' : v > 0 ? '+' diff --git a/src/perft.h b/src/perft.h index 2edc3ad0a6e..2dbab828a18 100644 --- a/src/perft.h +++ b/src/perft.h @@ -51,7 +51,7 @@ uint64_t perft(Position& pos, Depth depth) { pos.undo_move(m); } if (Root) - sync_cout << UCI::move(m, pos.is_chess960()) << ": " << cnt << sync_endl; + sync_cout << UCIEngine::move(m, pos.is_chess960()) << ": " << cnt << sync_endl; } return nodes; } diff --git a/src/position.cpp b/src/position.cpp index 2263afe7669..fd1678959d5 100644 --- a/src/position.cpp +++ b/src/position.cpp @@ -78,7 +78,7 @@ std::ostream& operator<<(std::ostream& os, const Position& pos) { << std::setw(16) << pos.key() << std::setfill(' ') << std::dec << "\nCheckers: "; for (Bitboard b = pos.checkers(); b;) - os << UCI::square(pop_lsb(b)) << " "; + os << UCIEngine::square(pop_lsb(b)) << " "; if (int(Tablebases::MaxCardinality) >= popcount(pos.pieces()) && !pos.can_castle(ANY_CASTLING)) { @@ -431,8 +431,8 @@ string Position::fen() const { if (!can_castle(ANY_CASTLING)) ss << '-'; - ss << (ep_square() == SQ_NONE ? " - " : " " + UCI::square(ep_square()) + " ") << st->rule50 - << " " << 1 + (gamePly - (sideToMove == BLACK)) / 2; + ss << (ep_square() == SQ_NONE ? " - " : " " + UCIEngine::square(ep_square()) + " ") + << st->rule50 << " " << 1 + (gamePly - (sideToMove == BLACK)) / 2; return ss.str(); } diff --git a/src/search.cpp b/src/search.cpp index 3f882aabdf5..efc00750613 100644 --- a/src/search.cpp +++ b/src/search.cpp @@ -158,7 +158,7 @@ void Search::Worker::start_searching() { { rootMoves.emplace_back(Move::none()); sync_cout << "info depth 0 score " - << UCI::to_score(rootPos.checkers() ? -VALUE_MATE : VALUE_DRAW, rootPos) + << UCIEngine::to_score(rootPos.checkers() ? -VALUE_MATE : VALUE_DRAW, rootPos) << sync_endl; } else @@ -204,11 +204,13 @@ void Search::Worker::start_searching() { sync_cout << main_manager()->pv(*bestThread, threads, tt, bestThread->completedDepth) << sync_endl; - sync_cout << "bestmove " << UCI::move(bestThread->rootMoves[0].pv[0], rootPos.is_chess960()); + sync_cout << "bestmove " + << UCIEngine::move(bestThread->rootMoves[0].pv[0], rootPos.is_chess960()); if (bestThread->rootMoves[0].pv.size() > 1 || bestThread->rootMoves[0].extract_ponder_from_tt(tt, rootPos)) - std::cout << " ponder " << UCI::move(bestThread->rootMoves[0].pv[1], rootPos.is_chess960()); + std::cout << " ponder " + << UCIEngine::move(bestThread->rootMoves[0].pv[1], rootPos.is_chess960()); std::cout << sync_endl; } @@ -933,7 +935,7 @@ Value Search::Worker::search( if (rootNode && is_mainthread() && main_manager()->tm.elapsed(threads.nodes_searched()) > 3000) sync_cout << "info depth " << depth << " currmove " - << UCI::move(move, pos.is_chess960()) << " currmovenumber " + << UCIEngine::move(move, pos.is_chess960()) << " currmovenumber " << moveCount + thisThread->pvIdx << sync_endl; if (PvNode) (ss + 1)->pv = nullptr; @@ -1904,10 +1906,10 @@ std::string SearchManager::pv(const Search::Worker& worker, ss << "info" << " depth " << d << " seldepth " << rootMoves[i].selDepth << " multipv " << i + 1 - << " score " << UCI::to_score(v, pos); + << " score " << UCIEngine::to_score(v, pos); if (worker.options["UCI_ShowWDL"]) - ss << UCI::wdl(v, pos); + ss << UCIEngine::wdl(v, pos); if (i == pvIdx && !tb && updated) // tablebase- and previous-scores are exact ss << (rootMoves[i].scoreLowerbound @@ -1918,7 +1920,7 @@ std::string SearchManager::pv(const Search::Worker& worker, << " tbhits " << tbHits << " time " << time << " pv"; for (Move m : rootMoves[i].pv) - ss << " " << UCI::move(m, pos.is_chess960()); + ss << " " << UCIEngine::move(m, pos.is_chess960()); } return ss.str(); diff --git a/src/tune.h b/src/tune.h index b88c085fd4b..079614db28a 100644 --- a/src/tune.h +++ b/src/tune.h @@ -158,7 +158,7 @@ class Tune { for (auto& e : instance().list) e->init_option(); read_options(); - } // Deferred, due to UCI::Options access + } // Deferred, due to UCIEngine::Options access static void read_options() { for (auto& e : instance().list) e->read_option(); diff --git a/src/uci.cpp b/src/uci.cpp index ee95d5be5e6..ed23c00a49d 100644 --- a/src/uci.cpp +++ b/src/uci.cpp @@ -32,6 +32,7 @@ #include #include "benchmark.h" +#include "engine.h" #include "evaluate.h" #include "movegen.h" #include "nnue/network.h" @@ -49,27 +50,19 @@ constexpr auto StartFEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - constexpr int MaxHashMB = Is64Bit ? 33554432 : 2048; -namespace NN = Eval::NNUE; - - -UCI::UCI(int argc, char** argv) : - networks(NN::Networks( - NN::NetworkBig({EvalFileDefaultNameBig, "None", ""}, NN::EmbeddedNNUEType::BIG), - NN::NetworkSmall({EvalFileDefaultNameSmall, "None", ""}, NN::EmbeddedNNUEType::SMALL))), +UCIEngine::UCIEngine(int argc, char** argv) : + engine(argv[0]), cli(argc, argv) { + auto& options = engine.get_options(); + options["Debug Log File"] << Option("", [](const Option& o) { start_logger(o); }); - options["Threads"] << Option(1, 1, 1024, [this](const Option&) { - threads.set({options, threads, tt, networks}); - }); + options["Threads"] << Option(1, 1, 1024, [this](const Option&) { engine.resize_threads(); }); - options["Hash"] << Option(16, 1, MaxHashMB, [this](const Option& o) { - threads.main_thread()->wait_for_search_finished(); - tt.resize(o, options["Threads"]); - }); + options["Hash"] << Option(16, 1, MaxHashMB, [this](const Option& o) { engine.set_tt_size(o); }); - options["Clear Hash"] << Option([this](const Option&) { search_clear(); }); + options["Clear Hash"] << Option([this](const Option&) { engine.search_clear(); }); options["Ponder"] << Option(false); options["MultiPV"] << Option(1, 1, MAX_MOVES); options["Skill Level"] << Option(20, 0, 20); @@ -83,22 +76,17 @@ UCI::UCI(int argc, char** argv) : options["SyzygyProbeDepth"] << Option(1, 1, 100); options["Syzygy50MoveRule"] << Option(true); options["SyzygyProbeLimit"] << Option(7, 0, 7); - options["EvalFile"] << Option(EvalFileDefaultNameBig, [this](const Option& o) { - networks.big.load(cli.binaryDirectory, o); - }); - options["EvalFileSmall"] << Option(EvalFileDefaultNameSmall, [this](const Option& o) { - networks.small.load(cli.binaryDirectory, o); - }); - - networks.big.load(cli.binaryDirectory, options["EvalFile"]); - networks.small.load(cli.binaryDirectory, options["EvalFileSmall"]); - - threads.set({options, threads, tt, networks}); - - search_clear(); // After threads are up + options["EvalFile"] << Option(EvalFileDefaultNameBig, + [this](const Option& o) { engine.load_big_network(o); }); + options["EvalFileSmall"] << Option(EvalFileDefaultNameSmall, + [this](const Option& o) { engine.load_small_network(o); }); + + engine.load_networks(); + engine.resize_threads(); + engine.search_clear(); // After threads are up } -void UCI::loop() { +void UCIEngine::loop() { Position pos; std::string token, cmd; @@ -121,27 +109,27 @@ void UCI::loop() { is >> std::skipws >> token; if (token == "quit" || token == "stop") - threads.stop = true; + engine.stop(); // The GUI sends 'ponderhit' to tell that the user has played the expected move. // So, 'ponderhit' is sent if pondering was done on the same move that the user // has played. The search should continue, but should also switch from pondering // to the normal search. else if (token == "ponderhit") - threads.main_manager()->ponder = false; // Switch to the normal search + engine.set_ponderhit(false); else if (token == "uci") sync_cout << "id name " << engine_info(true) << "\n" - << options << "\nuciok" << sync_endl; + << engine.get_options() << "\nuciok" << sync_endl; else if (token == "setoption") setoption(is); else if (token == "go") - go(pos, is, states); + go(pos, is); else if (token == "position") - position(pos, is, states); + position(is); else if (token == "ucinewgame") - search_clear(); + engine.search_clear(); else if (token == "isready") sync_cout << "readyok" << sync_endl; @@ -150,11 +138,11 @@ void UCI::loop() { else if (token == "flip") pos.flip(); else if (token == "bench") - bench(pos, is, states); + bench(pos, is); else if (token == "d") sync_cout << pos << sync_endl; else if (token == "eval") - trace_eval(pos); + engine.trace_eval(); else if (token == "compiler") sync_cout << compiler_info() << sync_endl; else if (token == "export_net") @@ -167,8 +155,7 @@ void UCI::loop() { if (is >> std::skipws >> files[1].second) files[1].first = files[1].second; - networks.big.save(files[0].first); - networks.small.save(files[1].first); + engine.save_network(files); } else if (token == "--help" || token == "help" || token == "--license" || token == "license") sync_cout @@ -186,7 +173,7 @@ void UCI::loop() { } while (token != "quit" && cli.argc == 1); // The command-line arguments are one-shot } -Search::LimitsType UCI::parse_limits(const Position& pos, std::istream& is) { +Search::LimitsType UCIEngine::parse_limits(const Position& pos, std::istream& is) { Search::LimitsType limits; std::string token; @@ -225,23 +212,13 @@ Search::LimitsType UCI::parse_limits(const Position& pos, std::istream& is) { return limits; } -void UCI::go(Position& pos, std::istringstream& is, StateListPtr& states) { +void UCIEngine::go(Position& pos, std::istringstream& is) { Search::LimitsType limits = parse_limits(pos, is); - - networks.big.verify(options["EvalFile"]); - networks.small.verify(options["EvalFileSmall"]); - - if (limits.perft) - { - perft(pos.fen(), limits.perft, options["UCI_Chess960"]); - return; - } - - threads.start_thinking(options, pos, states, limits); + engine.go(limits); } -void UCI::bench(Position& pos, std::istream& args, StateListPtr& states) { +void UCIEngine::bench(Position& pos, std::istream& args) { std::string token; uint64_t num, nodes = 0, cnt = 1; @@ -263,20 +240,20 @@ void UCI::bench(Position& pos, std::istream& args, StateListPtr& states) { << std::endl; if (token == "go") { - go(pos, is, states); - threads.main_thread()->wait_for_search_finished(); - nodes += threads.nodes_searched(); + go(pos, is); + engine.wait_for_search_finished(); + nodes += engine.nodes_searched(); } else - trace_eval(pos); + engine.trace_eval(); } else if (token == "setoption") setoption(is); else if (token == "position") - position(pos, is, states); + position(is); else if (token == "ucinewgame") { - search_clear(); // Search::clear() may take a while + engine.search_clear(); // search_clear may take a while elapsed = now(); } } @@ -290,33 +267,13 @@ void UCI::bench(Position& pos, std::istream& args, StateListPtr& states) { << "\nNodes/second : " << 1000 * nodes / elapsed << std::endl; } -void UCI::trace_eval(Position& pos) { - StateListPtr states(new std::deque(1)); - Position p; - p.set(pos.fen(), options["UCI_Chess960"], &states->back()); - - networks.big.verify(options["EvalFile"]); - networks.small.verify(options["EvalFileSmall"]); - - sync_cout << "\n" << Eval::trace(p, networks) << sync_endl; +void UCIEngine::setoption(std::istringstream& is) { + engine.wait_for_search_finished(); + engine.get_options().setoption(is); } -void UCI::search_clear() { - threads.main_thread()->wait_for_search_finished(); - - tt.clear(options["Threads"]); - threads.clear(); - Tablebases::init(options["SyzygyPath"]); // Free mapped files -} - -void UCI::setoption(std::istringstream& is) { - threads.main_thread()->wait_for_search_finished(); - options.setoption(is); -} - -void UCI::position(Position& pos, std::istringstream& is, StateListPtr& states) { - Move m; +void UCIEngine::position(std::istringstream& is) { std::string token, fen; is >> token; @@ -332,15 +289,14 @@ void UCI::position(Position& pos, std::istringstream& is, StateListPtr& states) else return; - states = StateListPtr(new std::deque(1)); // Drop the old state and create a new one - pos.set(fen, options["UCI_Chess960"], &states->back()); + std::vector moves; - // Parse the move list, if any - while (is >> token && (m = to_move(pos, token)) != Move::none()) + while (is >> token) { - states->emplace_back(); - pos.do_move(m, states->back()); + moves.push_back(token); } + + engine.set_position(fen, moves); } namespace { @@ -379,7 +335,7 @@ int win_rate_model(Value v, const Position& pos) { } } -std::string UCI::to_score(Value v, const Position& pos) { +std::string UCIEngine::to_score(Value v, const Position& pos) { assert(-VALUE_INFINITE < v && v < VALUE_INFINITE); std::stringstream ss; @@ -399,7 +355,7 @@ std::string UCI::to_score(Value v, const Position& pos) { // Turns a Value to an integer centipawn number, // without treatment of mate and similar special scores. -int UCI::to_cp(Value v, const Position& pos) { +int UCIEngine::to_cp(Value v, const Position& pos) { // In general, the score can be defined via the the WDL as // (log(1/L - 1) - log(1/W - 1)) / ((log(1/L - 1) + log(1/W - 1)) @@ -410,7 +366,7 @@ int UCI::to_cp(Value v, const Position& pos) { return std::round(100 * int(v) / a); } -std::string UCI::wdl(Value v, const Position& pos) { +std::string UCIEngine::wdl(Value v, const Position& pos) { std::stringstream ss; int wdl_w = win_rate_model(v, pos); @@ -421,11 +377,11 @@ std::string UCI::wdl(Value v, const Position& pos) { return ss.str(); } -std::string UCI::square(Square s) { +std::string UCIEngine::square(Square s) { return std::string{char('a' + file_of(s)), char('1' + rank_of(s))}; } -std::string UCI::move(Move m, bool chess960) { +std::string UCIEngine::move(Move m, bool chess960) { if (m == Move::none()) return "(none)"; @@ -447,7 +403,7 @@ std::string UCI::move(Move m, bool chess960) { } -Move UCI::to_move(const Position& pos, std::string& str) { +Move UCIEngine::to_move(const Position& pos, std::string str) { if (str.length() == 5) str[4] = char(tolower(str[4])); // The promotion piece character must be lowercased diff --git a/src/uci.h b/src/uci.h index 237928d9abc..c4e90b48d35 100644 --- a/src/uci.h +++ b/src/uci.h @@ -22,6 +22,7 @@ #include #include +#include "engine.h" #include "misc.h" #include "nnue/network.h" #include "position.h" @@ -36,9 +37,9 @@ class Move; enum Square : int; using Value = int; -class UCI { +class UCIEngine { public: - UCI(int argc, char** argv); + UCIEngine(int argc, char** argv); void loop(); @@ -47,25 +48,17 @@ class UCI { static std::string square(Square s); static std::string move(Move m, bool chess960); static std::string wdl(Value v, const Position& pos); - static Move to_move(const Position& pos, std::string& str); + static Move to_move(const Position& pos, std::string str); static Search::LimitsType parse_limits(const Position& pos, std::istream& is); - const std::string& working_directory() const { return cli.workingDirectory; } - - OptionsMap options; - Eval::NNUE::Networks networks; - private: - TranspositionTable tt; - ThreadPool threads; - CommandLine cli; - - void go(Position& pos, std::istringstream& is, StateListPtr& states); - void bench(Position& pos, std::istream& args, StateListPtr& states); - void position(Position& pos, std::istringstream& is, StateListPtr& states); - void trace_eval(Position& pos); - void search_clear(); + Engine engine; + CommandLine cli; + + void go(Position& pos, std::istringstream& is); + void bench(Position& pos, std::istream& args); + void position(std::istringstream& is); void setoption(std::istringstream& is); };