diff --git a/bioptim/gui/online_callback_server.py b/bioptim/gui/online_callback_server.py index 08f422b18..43fb96542 100644 --- a/bioptim/gui/online_callback_server.py +++ b/bioptim/gui/online_callback_server.py @@ -13,12 +13,12 @@ from .online_callback_abstract import OnlineCallbackAbstract from .plot import PlotOcp, OcpSerializable -from .utils import strstaticproperty, intstaticproperty from ..optimization.optimization_vector import OptimizationVectorHelper _DEFAULT_HOST = "localhost" _DEFAULT_PORT = 3050 +_HEADER_GENERIC_LEN = 1024 def _serialize_show_options(show_options: dict) -> bytes: @@ -38,43 +38,31 @@ class _ServerMessages(IntEnum): UNKNOWN = auto() -class _HeaderMessage(StrEnum): - _OK = "OK" - _NOK = "NOK" - _PLOT_READY = "PLOT_READY" - _READY_FOR_NEXT_DATA = "READY_FOR_NEXT_DATA" +class _ResponseHeader(StrEnum): + OK = "OK" + NOK = "NOK" + PLOT_READY = "PLOT_READY" + READY_FOR_NEXT_DATA = "READY_FOR_NEXT_DATA" - @strstaticproperty - def OK() -> str: - return _HeaderMessage._to_str(_HeaderMessage._OK) - - @strstaticproperty - def NOK() -> str: - return _HeaderMessage._to_str(_HeaderMessage._NOK) - - @strstaticproperty - def PLOT_READY() -> str: - return _HeaderMessage._to_str(_HeaderMessage._PLOT_READY) + @staticmethod + def longest() -> int: + return max(len(v) for v in _ResponseHeader.__members__) - @strstaticproperty - def READY_FOR_NEXT_DATA() -> str: - return _HeaderMessage._to_str(_HeaderMessage._READY_FOR_NEXT_DATA) + def encode(self) -> str: + return self.ljust(len(self), "\0").encode() - @intstaticproperty - def longest() -> int: - return max(len(v) for v in _HeaderMessage.__members__.values()) + @staticmethod + def response_len() -> int: + return _ResponseHeader.longest() + 1 - @intstaticproperty - def header_len() -> int: - return _HeaderMessage.longest + 1 + def __len__(self) -> int: + return _ResponseHeader.response_len() - @intstaticproperty - def header_generic_len() -> int: - return 1024 + def __eq__(self, value: object) -> bool: + return self.split("\0")[0] == value.split("\0")[0] or super().__eq__(value) - @staticmethod - def _to_str(message: str) -> str: - return message.ljust(_HeaderMessage.header_len, "\0") + def __ne__(self, value: object) -> bool: + return not self.__eq__(value) class PlottingServer: @@ -181,7 +169,7 @@ def _recv_data(self, client_socket: socket.socket, send_confirmation: bool) -> t client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data, otherwise it will not send anything. This is part of the communication protocol Returns @@ -207,7 +195,7 @@ def _recv_message_type_and_data_len( client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data, otherwise it will not send anything. This is part of the communication protocol Returns @@ -217,7 +205,7 @@ def _recv_message_type_and_data_len( # Receive the actual data try: - data = client_socket.recv(_HeaderMessage.header_generic_len).decode().strip("\0") + data = client_socket.recv(_HEADER_GENERIC_LEN).decode().strip("\0") if not data: return _ServerMessages.EMPTY, None except: @@ -232,7 +220,7 @@ def _recv_message_type_and_data_len( self._logger.error("Unknown message type received") # Sends failure if send_confirmation: - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) return _ServerMessages.UNKNOWN, None if message_type == _ServerMessages.CLOSE_CONNEXION: @@ -247,13 +235,13 @@ def _recv_message_type_and_data_len( self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) return _ServerMessages.UNKNOWN, None # If we are here, everything went well, so send confirmation self._logger.debug(f"Received from client: {message_type} ({len_all_data} bytes)") if send_confirmation: - client_socket.sendall(_HeaderMessage.OK.encode()) + client_socket.sendall(_ResponseHeader.OK.encode()) return message_type, len_all_data @@ -266,7 +254,7 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: client_socket: socket.socket The client socket send_confirmation: bool - If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data, + If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data, otherwise it will not send anything. This is part of the communication protocol len_all_data: list The length of the data to receive @@ -289,12 +277,12 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation: self._logger.debug(f"Error: {e}") # Sends failure if send_confirmation: - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) return None # If we are here, everything went well, so send confirmation if send_confirmation: - client_socket.sendall(_HeaderMessage.OK.encode()) + client_socket.sendall(_ResponseHeader.OK.encode()) self._logger.debug(f"Received data from client: {[len(d) for d in data_out]} bytes") return data_out @@ -315,14 +303,14 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No data_json = json.loads(ocp_raw[0]) except Exception as e: self._logger.error("Error while converting data to json format, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e try: self._should_send_ok_to_client_on_new_data = data_json["request_confirmation_on_new_data"] except Exception as e: self._logger.error("Did not receive if confirmation should be sent, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e try: @@ -332,32 +320,32 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No del data_json["dummy_phase_times"] except Exception as e: self._logger.error("Error while extracting dummy time vector from OCP data, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e try: self.ocp = OcpSerializable.deserialize(data_json) except Exception as e: self._logger.error("Error while deserializing OCP data from client, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e try: show_options = _deserialize_show_options(ocp_raw[1]) except Exception as e: self._logger.error("Error while extracting show options, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e try: self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options) except Exception as e: self._logger.error("Error while initializing the plotter, closing connexion") - client_socket.sendall(_HeaderMessage.NOK.encode()) + client_socket.sendall(_ResponseHeader.NOK.encode()) raise e # Send the confirmation to the client - client_socket.sendall(_HeaderMessage.PLOT_READY.encode()) + client_socket.sendall(_ResponseHeader.PLOT_READY.encode()) # Start the callbacks threading.Timer(self._get_data_interval, self._wait_for_new_data_to_plot, (client_socket,)).start() @@ -395,7 +383,7 @@ def _redraw(self) -> None: def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: """ - Waits for new data from the client, sends a _HeaderMessage.READY_FOR_NEXT_DATA message to the client to signal + Waits for new data from the client, sends a _ResponseHeader.READY_FOR_NEXT_DATA message to the client to signal that the server is ready to receive new data. If the client sends new data, the server will update the plot, if client disconnects the connexion will be closed @@ -411,7 +399,7 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None: time.sleep(self._update_plot_interval) try: - client_socket.sendall(_HeaderMessage.READY_FOR_NEXT_DATA.encode()) + client_socket.sendall(_ResponseHeader.READY_FOR_NEXT_DATA.encode()) except Exception as e: self._logger.error("Error while sending READY_FOR_NEXT_DATA to client, closing connexion") self._logger.debug(f"Error: {e}") @@ -476,10 +464,10 @@ def __init__(self, ocp, opts: dict = None, host: str = None, port: int = None, * The port to connect to, by default 3050 """ + super().__init__(ocp, opts, **show_options) + self._host = host if host else _DEFAULT_HOST self._port = port if port else _DEFAULT_PORT - - super().__init__(ocp, opts, **show_options) self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._should_wait_ok_to_client_on_new_data = platform.system() == "Darwin" @@ -518,7 +506,7 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: raise RuntimeError( "Could not connect to the plotter server, make sure it is running by calling 'PlottingServer()' on " "another python instance or allowing for automatic start (Linux or Windows) of the server setting " - "the online_option to 'OnlineOptim.MULTIPROCESS_SERVER' when instantiating your solver" + "the online_option to 'OnlineOptim.MULTIPROCESS_SERVER' when instantiating your solver." ) else: time.sleep(1) @@ -538,25 +526,36 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None: # Sends message type and dimensions self._socket.sendall( f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".ljust( - _HeaderMessage.header_generic_len, "\0" + _HEADER_GENERIC_LEN, "\0" ).encode() ) - if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK: + if not self._has_received_ok(): raise RuntimeError("The server did not acknowledge the connexion") self._socket.sendall(serialized_ocp) self._socket.sendall(serialized_show_options) - if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK: + if not self._has_received_ok(): raise RuntimeError("The server did not acknowledge the connexion") # Wait for the server to be ready - if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.PLOT_READY: + if self._socket.recv(_ResponseHeader.response_len()).decode() != _ResponseHeader.PLOT_READY: raise RuntimeError("The server did not acknowledge the OCP data, this should not happen, please report") self._plotter = PlotOcp( self.ocp, only_initialize_variables=True, dummy_phase_times=dummy_phase_times, **show_options ) + def _has_received_ok(self) -> bool: + """ + Checks if the server has sent an OK message + + Returns + ------- + If the server has sent an OK message + """ + + return self._socket.recv(_ResponseHeader.response_len()).decode() == _ResponseHeader.OK + def close(self) -> None: """ Closes the connexion @@ -586,8 +585,8 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: self._socket.setblocking(False) try: - data = self._socket.recv(_HeaderMessage.header_len).decode() - if data != _HeaderMessage.READY_FOR_NEXT_DATA: + data = self._socket.recv(_ResponseHeader.response_len()).decode() + if data != _ResponseHeader.READY_FOR_NEXT_DATA: return [0] except BlockingIOError: # This is to prevent the solving to be blocked by the server if it is not ready to update the plots @@ -603,21 +602,15 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]: self._socket.sendall( f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".ljust( - _HeaderMessage.header_generic_len, "\0" + _HEADER_GENERIC_LEN, "\0" ).encode() ) - if ( - self._should_wait_ok_to_client_on_new_data - and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK - ): + if self._should_wait_ok_to_client_on_new_data and not self._has_received_ok(): raise RuntimeError("The server did not acknowledge the connexion") self._socket.sendall(header) self._socket.sendall(data_serialized) - if ( - self._should_wait_ok_to_client_on_new_data - and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK - ): + if self._should_wait_ok_to_client_on_new_data and not self._has_received_ok(): raise RuntimeError("The server did not acknowledge the connexion") return [0] diff --git a/bioptim/gui/utils.py b/bioptim/gui/utils.py deleted file mode 100644 index bcb9460a9..000000000 --- a/bioptim/gui/utils.py +++ /dev/null @@ -1,14 +0,0 @@ -class strstaticproperty: - def __init__(self, func): - self.func = func - - def __get__(self, instance, owner) -> str: - return self.func() - - -class intstaticproperty: - def __init__(self, func): - self.func = func - - def __get__(self, instance, owner) -> int: - return self.func() diff --git a/tests/shard5/test_plot_server.py b/tests/shard5/test_plot_server.py index af774b260..442726ad5 100644 --- a/tests/shard5/test_plot_server.py +++ b/tests/shard5/test_plot_server.py @@ -2,8 +2,9 @@ from bioptim.gui.online_callback_server import _serialize_xydata, _deserialize_xydata from bioptim.gui.plot import PlotOcp +from bioptim.gui.online_callback_server import _ResponseHeader from bioptim.optimization.optimization_vector import OptimizationVectorHelper -from casadi import DM, Function +from casadi import DM import numpy as np @@ -40,3 +41,24 @@ def test_serialize_deserialize(): else: for y_phase, deserialized_y_phase in zip(y_variable, deserialized_y_variable): assert np.allclose(y_phase, deserialized_y_phase) + + +def test_response_header(): + # Make sure all the response have the same length + response_len = _ResponseHeader.response_len() + for response in _ResponseHeader: + assert len(response) == response_len + # Make sure encoding provides a constant length + assert len(response.encode()) == response_len + + # Make sure equality works + assert _ResponseHeader.OK == _ResponseHeader.OK + assert _ResponseHeader.OK.value == _ResponseHeader.OK + assert _ResponseHeader.OK.encode().decode() == _ResponseHeader.OK + assert _ResponseHeader.OK == _ResponseHeader.OK.encode().decode() + assert not (_ResponseHeader.OK != _ResponseHeader.OK) + assert not (_ResponseHeader.OK.encode().decode() != _ResponseHeader.OK) + assert not (_ResponseHeader.OK != _ResponseHeader.OK.encode().decode()) + assert not (_ResponseHeader.OK.value == _ResponseHeader.OK.encode().decode()) + assert _ResponseHeader.OK != _ResponseHeader.NOK + assert _ResponseHeader.NOK == _ResponseHeader.NOK