diff --git a/src/Features/Tas/TasProtocol.cpp b/src/Features/Tas/TasProtocol.cpp index 3b534332d..81a60896c 100644 --- a/src/Features/Tas/TasProtocol.cpp +++ b/src/Features/Tas/TasProtocol.cpp @@ -40,8 +40,6 @@ using namespace TasProtocol; -Variable sar_tas_protocol("sar_tas_protocol", "0", 0, "Enable the remote TAS controller connection protocol. Value higher than 1 replaces port to listen for connections on (6555 by default).\n"); - namespace TasProtocol { // Has to be defined here because something something MVSC suck my dick struct ConnectionData { @@ -54,11 +52,13 @@ namespace TasProtocol { static SOCKET g_listen_sock = INVALID_SOCKET; static std::vector g_connections; static std::atomic g_should_stop; +static bool should_run = false; +static std::atomic g_is_server; static std::string g_client_ip; static int g_client_port; -static bool g_attempt_client_connection; -static std::mutex g_client_data_mutex; +static int g_server_port; +static std::mutex g_conn_data_mutex; static Status g_last_status; static Status g_current_status; @@ -393,23 +393,59 @@ static bool processCommands(ConnectionData &cl) { } } -static void attemptConnectionToServer() { - g_client_data_mutex.lock(); +static bool attemptToInitializeServer() { + if (!g_is_server.load()) return false; - bool shouldConnect = g_attempt_client_connection; - g_attempt_client_connection = false; - std::string ip = g_client_ip; - int port = g_client_port; + g_conn_data_mutex.lock(); + auto server_port = g_server_port; + g_conn_data_mutex.unlock(); - g_client_data_mutex.unlock(); + g_listen_sock = socket(AF_INET6, SOCK_STREAM, 0); + if (g_listen_sock == INVALID_SOCKET) { + THREAD_PRINT("Could not initialize TAS server: socket creation failed\n"); + return false; + } - if (!shouldConnect) return; + // why tf is this enabled by default on Windows + int v6only = 0; + setsockopt(g_listen_sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char *)&v6only, sizeof v6only); + + struct sockaddr_in6 saddr { + AF_INET6, + htons(server_port), + 0, + in6addr_any, + 0, + }; + + if (bind(g_listen_sock, (struct sockaddr *)&saddr, sizeof saddr) == SOCKET_ERROR) { + THREAD_PRINT("Could not initialize TAS server: socket bind failed\n"); + closesocket(g_listen_sock); + return false; + } + + if (listen(g_listen_sock, 4) == SOCKET_ERROR) { + THREAD_PRINT("Could not initialize TAS server: socket listen failed\n"); + closesocket(g_listen_sock); + return false; + } + + return true; +} + +static bool attemptConnectionToServer() { + if (g_is_server.load()) return false; + + g_conn_data_mutex.lock(); + std::string ip = g_client_ip; + int port = g_client_port; + g_conn_data_mutex.unlock(); auto clientSocket = socket(AF_INET, SOCK_STREAM, 0); - if (clientSocket == -1) { + if (clientSocket == SOCKET_ERROR) { THREAD_PRINT("Could not connect to TAS protocol server: socket creation failed\n"); closesocket(clientSocket); - return; + return false; } sockaddr_in serverAddr; @@ -419,18 +455,20 @@ static void attemptConnectionToServer() { if (inet_pton(AF_INET, g_client_ip.c_str(), &serverAddr.sin_addr) <= 0) { THREAD_PRINT("Could not connect to TAS protocol server: invalid address\n"); closesocket(clientSocket); - return; + return false; } if (connect(clientSocket, reinterpret_cast(&serverAddr), sizeof(serverAddr)) == -1) { THREAD_PRINT("Could not connect to TAS protocol server: connection failed.\n"); closesocket(clientSocket); - return; + return false; } g_connections.push_back({clientSocket, {}}); fullUpdate(g_connections[g_connections.size() - 1], true); THREAD_PRINT("Successfully connected to TAS server %s:%d.\n", ip.c_str(), port); + + return true; } static bool receiveFromConnection(TasProtocol::ConnectionData &cl) { @@ -452,13 +490,15 @@ static bool receiveFromConnection(TasProtocol::ConnectionData &cl) { return true; } -static void processConnections() { +static void processConnections(bool is_server) { fd_set set; FD_ZERO(&set); SOCKET max = g_listen_sock; - FD_SET(g_listen_sock, &set); + if (is_server) { + FD_SET(g_listen_sock, &set); + } for (auto client : g_connections) { FD_SET(client.sock, &set); if (max < client.sock) max = client.sock; @@ -474,7 +514,7 @@ static void processConnections() { return; } - if (FD_ISSET(g_listen_sock, &set)) { + if (is_server && FD_ISSET(g_listen_sock, &set)) { SOCKET cl = accept(g_listen_sock, nullptr, nullptr); if (cl != INVALID_SOCKET) { g_connections.push_back({ cl, {} }); @@ -494,58 +534,43 @@ static void processConnections() { } } -static void mainThread(int tas_server_port) { - THREAD_PRINT("Starting TAS server\n"); +static void mainThread() { + THREAD_PRINT("Starting TAS protocol\n"); #ifdef _WIN32 WSADATA wsa_data; int err = WSAStartup(MAKEWORD(2,2), &wsa_data); if (err){ - THREAD_PRINT("Could not initialize TAS server: WSAStartup failed (%d)\n", err); + THREAD_PRINT("Could not initialize TAS protocol: WSAStartup failed (%d)\n", err); return; } #endif - g_listen_sock = socket(AF_INET6, SOCK_STREAM, 0); - if (g_listen_sock == INVALID_SOCKET) { - THREAD_PRINT("Could not initialize TAS server: socket creation failed\n"); - WSACleanup(); - return; - } - - // why tf is this enabled by default on Windows - int v6only = 0; - setsockopt(g_listen_sock, IPPROTO_IPV6, IPV6_V6ONLY, (const char *)&v6only, sizeof v6only); + bool is_server = g_is_server.load(); - struct sockaddr_in6 saddr{ - AF_INET6, - htons(tas_server_port), - 0, - in6addr_any, - 0, - }; - - if (bind(g_listen_sock, (struct sockaddr *)&saddr, sizeof saddr) == SOCKET_ERROR) { - THREAD_PRINT("Could not initialize TAS server: socket bind failed\n"); - closesocket(g_listen_sock); + if (is_server && !attemptToInitializeServer()) { WSACleanup(); return; } - - if (listen(g_listen_sock, 4) == SOCKET_ERROR) { - THREAD_PRINT("Could not initialize TAS server: socket listen failed\n"); - closesocket(g_listen_sock); + if (!is_server && !attemptConnectionToServer()) { WSACleanup(); return; } while (!g_should_stop.load()) { - attemptConnectionToServer(); - processConnections(); + processConnections(is_server); update(); + + if (g_connections.size() == 0 && !is_server) { + break; + } } - THREAD_PRINT("Stopping TAS server\n"); + if (is_server) { + THREAD_PRINT("Stopping TAS server\n"); + } else { + THREAD_PRINT("Stopping TAS client\n"); + } for (auto &cl : g_connections) { closesocket(cl.sock); @@ -558,23 +583,24 @@ static void mainThread(int tas_server_port) { static std::thread g_net_thread; static bool g_running; -ON_EVENT(FRAME) { - int tas_server_port = sar_tas_protocol.GetInt() == 1 ? DEFAULT_TAS_CLIENT_SOCKET : sar_tas_protocol.GetInt(); - bool should_run = sar_tas_protocol.GetBool(); +static void restart() { + g_should_stop.store(true); + should_run = true; +} - if (g_running && !should_run) { - g_should_stop.store(true); +ON_EVENT(FRAME) { + if (g_running && g_should_stop.load()) { if (g_net_thread.joinable()) g_net_thread.join(); g_running = false; } else if (!g_running && should_run) { g_should_stop.store(false); - g_net_thread = std::thread(mainThread, tas_server_port); + g_net_thread = std::thread(mainThread); g_running = true; + should_run = false; } } ON_EVENT_P(SAR_UNLOAD, -100) { - sar_tas_protocol.SetValue(false); g_should_stop.store(true); if (g_net_thread.joinable()) g_net_thread.join(); } @@ -619,13 +645,33 @@ CON_COMMAND(sar_tas_protocol_connect, return console->Print(sar_tas_protocol_connect.ThisPtr()->m_pszHelpString); } - sar_tas_protocol.SetValue(1); - - g_client_data_mutex.lock(); + g_conn_data_mutex.lock(); g_client_ip = args[1]; g_client_port = args.ArgC() >= 3 ? std::atoi(args[2]) : DEFAULT_TAS_SERVER_SOCKET; - g_attempt_client_connection = true; - g_client_data_mutex.unlock(); + g_conn_data_mutex.unlock(); + + g_is_server.store(false); + + restart(); +} + +CON_COMMAND(sar_tas_protocol_server, + "sar_tas_protocol_server [port] - starts a TAS protocol server. Port is 6555 by default.\n") { + if (args.ArgC() < 1 || args.ArgC() > 2) { + return console->Print(sar_tas_protocol_server.ThisPtr()->m_pszHelpString); + } + g_conn_data_mutex.lock(); + g_server_port = args.ArgC() >= 1 ? std::atoi(args[2]) : DEFAULT_TAS_SERVER_SOCKET; + g_conn_data_mutex.unlock(); + + g_is_server.store(true); + + restart(); +} + +CON_COMMAND(sar_tas_protocol_stop, + "sar_tas_protocol_stop - stops every TAS protocol related connection.\n") { + g_should_stop.store(true); }