diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index ee54bb91e..ceffc89af 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -1,4 +1,4 @@ -from enum import IntEnum, auto +from enum import IntEnum, StrEnum, auto import json import logging import platform @@ -13,6 +13,7 @@ from .online_callback_abstract import OnlineCallbackAbstract from .plot import PlotOcp, OcpSerializable +from .utils import strstaticproperty, intstaticproperty from ..optimization.optimization_vector import OptimizationVectorHelper @@ -37,6 +38,45 @@ class _ServerMessages(IntEnum): UNKNOWN = auto() +class _HeaderMessage(StrEnum): + _OK = "OK" + _NOK = "NOK" + _PLOT_READY = "PLOT_READY" + _READY_FOR_NEXT_DATA = "READY_FOR_NEXT_DATA" + + @strstaticproperty + def OK() -> str: + return _HeaderMessage._to_str(_HeaderMessage._OK) + + @strstaticproperty + def NOK() -> str: + return _HeaderMessage._to_str(_HeaderMessage._NOK) + + @strstaticproperty + def PLOT_READY() -> str: + return _HeaderMessage._to_str(_HeaderMessage._PLOT_READY) + + @strstaticproperty + def READY_FOR_NEXT_DATA() -> str: + return _HeaderMessage._to_str(_HeaderMessage._READY_FOR_NEXT_DATA) + + @intstaticproperty + def longest() -> int: + return max(len(v) for v in _HeaderMessage.__members__.values()) + + @intstaticproperty + def header_len() -> int: + return _HeaderMessage.longest + 1 + + @intstaticproperty + def header_generic_len() -> int: + return 1024 + + @staticmethod + def _to_str(message: str) -> str: + return message.ljust(_HeaderMessage.header_len, "\0") + + class PlottingServer: def __init__(self, host: str = None, port: int = None, log_level: int | None = logging.INFO): """ @@ -141,8 +181,8 @@ def _recv_data(self, client_socket: socket.socket, send_confirmation: bool) -> t client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will - not send anything. This is part of the communication protocol + If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + otherwise it will not send anything. This is part of the communication protocol Returns ------- @@ -167,8 +207,8 @@ def _recv_message_type_and_data_len( client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will - not send anything. This is part of the communication protocol + If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + otherwise it will not send anything. This is part of the communication protocol Returns ------- @@ -177,7 +217,7 @@ def _recv_message_type_and_data_len( # Receive the actual data try: - data = client_socket.recv(1024) + data = client_socket.recv(_HeaderMessage.header_generic_len).decode().strip("\0") if not data: return _ServerMessages.EMPTY, None except: @@ -185,14 +225,14 @@ def _recv_message_type_and_data_len( client_socket.close() return _ServerMessages.CLOSE_CONNEXION, None - data_as_list = data.decode().split("\n") + data_as_list = data.split("\n") try: message_type = _ServerMessages(int(data_as_list[0])) except ValueError: self._logger.error("Unknown message type received") # Sends failure if send_confirmation: - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) return _ServerMessages.UNKNOWN, None if message_type == _ServerMessages.CLOSE_CONNEXION: @@ -207,13 +247,13 @@ def _recv_message_type_and_data_len( self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) return _ServerMessages.UNKNOWN, None # If we are here, everything went well, so send confirmation self._logger.debug(f"Received from client: {message_type} ({len_all_data} bytes)") if send_confirmation: - client_socket.sendall("OK".encode()) + client_socket.sendall(_HeaderMessage.OK.encode()) return message_type, len_all_data @@ -226,8 +266,8 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will - not send anything. This is part of the communication protocol + If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + otherwise it will not send anything. This is part of the communication protocol len_all_data: list The length of the data to receive @@ -249,12 +289,12 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) return None # If we are here, everything went well, so send confirmation if send_confirmation: - client_socket.sendall("OK".encode()) + client_socket.sendall(_HeaderMessage.OK.encode()) self._logger.debug(f"Received data from client: {[len(d) for d in data_out]} bytes") return data_out @@ -275,14 +315,14 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No data_json = json.loads(ocp_raw[0]) except Exception as e: self._logger.error("Error while converting data to json format, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e try: self._should_send_ok_to_client_on_new_data = data_json["request_confirmation_on_new_data"] except Exception as e: self._logger.error("Did not receive if confirmation should be sent, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e try: @@ -292,32 +332,32 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No del data_json["dummy_phase_times"] except Exception as e: self._logger.error("Error while extracting dummy time vector from OCP data, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e try: self.ocp = OcpSerializable.deserialize(data_json) except Exception as e: self._logger.error("Error while deserializing OCP data from client, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e try: show_options = _deserialize_show_options(ocp_raw[1]) except Exception as e: self._logger.error("Error while extracting show options, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e try: self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) except Exception as e: self._logger.error("Error while initializing the plotter, closing connexion") - client_socket.sendall("NOK".encode()) + client_socket.sendall(_HeaderMessage.NOK.encode()) raise e # Send the confirmation to the client - client_socket.sendall("PLOT_READY".encode()) + client_socket.sendall(_HeaderMessage.PLOT_READY.encode()) # Start the callbacks threading.Timer(self._get_data_interval, self._wait_for_new_data_to_plot, (client_socket,)).start() @@ -355,9 +395,9 @@ def _redraw(self) -> None: def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: """ - Waits for new data from the client, sends a "READY_FOR_NEXT_DATA" message to the client to signal that the server - is ready to receive new data. If the client sends new data, the server will update the plot, if client disconnects - the connexion will be closed + Waits for new data from the client, sends a _HeaderMessage.READY_FOR_NEXT_DATA message to the client to signal + that the server is ready to receive new data. If the client sends new data, the server will update the plot, if + client disconnects the connexion will be closed Parameters ---------- @@ -371,7 +411,7 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: time.sleep(self._update_plot_interval) try: - client_socket.sendall("READY_FOR_NEXT_DATA".encode()) + client_socket.sendall(_HeaderMessage.READY_FOR_NEXT_DATA.encode()) except Exception as e: self._logger.error("Error while sending READY_FOR_NEXT_DATA to client, closing connexion") self._logger.debug(f"Error: {e}") @@ -497,19 +537,20 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: # Sends message type and dimensions self._socket.sendall( - f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".encode() + f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".ljust( + _HeaderMessage.header_generic_len, "\0" + ).encode() ) - if self._socket.recv(1024).decode() != "OK": + if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK: raise RuntimeError("The server did not acknowledge the connexion") self._socket.sendall(serialized_ocp) self._socket.sendall(serialized_show_options) - if self._socket.recv(1024).decode() != "OK": + if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK: raise RuntimeError("The server did not acknowledge the connexion") # Wait for the server to be ready - data = self._socket.recv(1024).decode().split("\n") - if data[0] != "PLOT_READY": + if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.PLOT_READY: raise RuntimeError("The server did not acknowledge the OCP data, this should not happen, please report") self._plotter = PlotOcp( @@ -545,8 +586,8 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: self._socket.setblocking(False) try: - data = self._socket.recv(1024).decode() - if data != "READY_FOR_NEXT_DATA": + data = self._socket.recv(_HeaderMessage.header_len).decode() + if data != _HeaderMessage.READY_FOR_NEXT_DATA: return [0] except BlockingIOError: # This is to prevent the solving to be blocked by the server if it is not ready to update the plots @@ -561,12 +602,18 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: header, data_serialized = _serialize_xydata(xdata, ydata) self._socket.sendall(f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode()) - if self._should_wait_ok_to_client_on_new_data and self._socket.recv(1024).decode() != "OK": + if ( + self._should_wait_ok_to_client_on_new_data + and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK + ): raise RuntimeError("The server did not acknowledge the connexion") self._socket.sendall(header) self._socket.sendall(data_serialized) - if self._should_wait_ok_to_client_on_new_data and self._socket.recv(1024).decode() != "OK": + if ( + self._should_wait_ok_to_client_on_new_data + and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK + ): raise RuntimeError("The server did not acknowledge the connexion") return [0] diff --git a/bioptim/gui/utils.py b/bioptim/gui/utils.py new file mode 100644 index 000000000..bcb9460a9 --- /dev/null +++ b/bioptim/gui/utils.py @@ -0,0 +1,14 @@ +class strstaticproperty: + def __init__(self, func): + self.func = func + + def __get__(self, instance, owner) -> str: + return self.func() + + +class intstaticproperty: + def __init__(self, func): + self.func = func + + def __get__(self, instance, owner) -> int: + return self.func()