Skip to content

Commit

Permalink
Simplified ResponseHeader
Browse files Browse the repository at this point in the history
  • Loading branch information
pariterre committed Aug 8, 2024
1 parent a688b91 commit ead0f76
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 83 deletions.
129 changes: 61 additions & 68 deletions bioptim/gui/online_callback_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@

from .online_callback_abstract import OnlineCallbackAbstract
from .plot import PlotOcp, OcpSerializable
from .utils import strstaticproperty, intstaticproperty
from ..optimization.optimization_vector import OptimizationVectorHelper


_DEFAULT_HOST = "localhost"
_DEFAULT_PORT = 3050
_HEADER_GENERIC_LEN = 1024


def _serialize_show_options(show_options: dict) -> bytes:
Expand All @@ -38,43 +38,31 @@ class _ServerMessages(IntEnum):
UNKNOWN = auto()


class _HeaderMessage(StrEnum):
_OK = "OK"
_NOK = "NOK"
_PLOT_READY = "PLOT_READY"
_READY_FOR_NEXT_DATA = "READY_FOR_NEXT_DATA"
class _ResponseHeader(StrEnum):
OK = "OK"
NOK = "NOK"
PLOT_READY = "PLOT_READY"
READY_FOR_NEXT_DATA = "READY_FOR_NEXT_DATA"

@strstaticproperty
def OK() -> str:
return _HeaderMessage._to_str(_HeaderMessage._OK)

@strstaticproperty
def NOK() -> str:
return _HeaderMessage._to_str(_HeaderMessage._NOK)

@strstaticproperty
def PLOT_READY() -> str:
return _HeaderMessage._to_str(_HeaderMessage._PLOT_READY)
@staticmethod
def longest() -> int:
return max(len(v) for v in _ResponseHeader.__members__)

@strstaticproperty
def READY_FOR_NEXT_DATA() -> str:
return _HeaderMessage._to_str(_HeaderMessage._READY_FOR_NEXT_DATA)
def encode(self) -> str:
return self.ljust(len(self), "\0").encode()

@intstaticproperty
def longest() -> int:
return max(len(v) for v in _HeaderMessage.__members__.values())
@staticmethod
def response_len() -> int:
return _ResponseHeader.longest() + 1

@intstaticproperty
def header_len() -> int:
return _HeaderMessage.longest + 1
def __len__(self) -> int:
return _ResponseHeader.response_len()

@intstaticproperty
def header_generic_len() -> int:
return 1024
def __eq__(self, value: object) -> bool:
return self.split("\0")[0] == value.split("\0")[0] or super().__eq__(value)

@staticmethod
def _to_str(message: str) -> str:
return message.ljust(_HeaderMessage.header_len, "\0")
def __ne__(self, value: object) -> bool:
return not self.__eq__(value)


class PlottingServer:
Expand Down Expand Up @@ -181,7 +169,7 @@ def _recv_data(self, client_socket: socket.socket, send_confirmation: bool) -> t
client_socket: socket.socket
The client socket
send_confirmation: bool
If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data,
If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data,
otherwise it will not send anything. This is part of the communication protocol
Returns
Expand All @@ -207,7 +195,7 @@ def _recv_message_type_and_data_len(
client_socket: socket.socket
The client socket
send_confirmation: bool
If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data,
If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data,
otherwise it will not send anything. This is part of the communication protocol
Returns
Expand All @@ -217,7 +205,7 @@ def _recv_message_type_and_data_len(

# Receive the actual data
try:
data = client_socket.recv(_HeaderMessage.header_generic_len).decode().strip("\0")
data = client_socket.recv(_HEADER_GENERIC_LEN).decode().strip("\0")
if not data:
return _ServerMessages.EMPTY, None
except:
Expand All @@ -232,7 +220,7 @@ def _recv_message_type_and_data_len(
self._logger.error("Unknown message type received")
# Sends failure
if send_confirmation:
client_socket.sendall(_HeaderMessage.NOK.encode())
client_socket.sendall(_ResponseHeader.NOK.encode())
return _ServerMessages.UNKNOWN, None

if message_type == _ServerMessages.CLOSE_CONNEXION:
Expand All @@ -247,13 +235,13 @@ def _recv_message_type_and_data_len(
self._logger.debug(f"Error: {e}")
# Sends failure
if send_confirmation:
client_socket.sendall(_HeaderMessage.NOK.encode())
client_socket.sendall(_ResponseHeader.NOK.encode())
return _ServerMessages.UNKNOWN, None

# If we are here, everything went well, so send confirmation
self._logger.debug(f"Received from client: {message_type} ({len_all_data} bytes)")
if send_confirmation:
client_socket.sendall(_HeaderMessage.OK.encode())
client_socket.sendall(_ResponseHeader.OK.encode())

return message_type, len_all_data

Expand All @@ -266,7 +254,7 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation:
client_socket: socket.socket
The client socket
send_confirmation: bool
If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data,
If True, the server will send a _ResponseHeader.OK confirmation to the client after receiving the data,
otherwise it will not send anything. This is part of the communication protocol
len_all_data: list
The length of the data to receive
Expand All @@ -289,12 +277,12 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation:
self._logger.debug(f"Error: {e}")
# Sends failure
if send_confirmation:
client_socket.sendall(_HeaderMessage.NOK.encode())
client_socket.sendall(_ResponseHeader.NOK.encode())
return None

# If we are here, everything went well, so send confirmation
if send_confirmation:
client_socket.sendall(_HeaderMessage.OK.encode())
client_socket.sendall(_ResponseHeader.OK.encode())

self._logger.debug(f"Received data from client: {[len(d) for d in data_out]} bytes")
return data_out
Expand All @@ -315,14 +303,14 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No
data_json = json.loads(ocp_raw[0])
except Exception as e:
self._logger.error("Error while converting data to json format, closing connexion")
client_socket.sendall(_HeaderMessage.NOK.encode())
client_socket.sendall(_ResponseHeader.NOK.encode())
raise e

try:
self._should_send_ok_to_client_on_new_data = data_json["request_confirmation_on_new_data"]
except Exception as e:
self._logger.error("Did not receive if confirmation should be sent, closing connexion")
client_socket.sendall(_HeaderMessage.NOK.encode())
client_socket.sendall(_ResponseHeader.NOK.encode())
raise e

try:
Expand All @@ -332,32 +320,32 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No
del data_json["dummy_phase_times"]
except Exception as e:
self._logger.error("Error while extracting dummy time vector from OCP data, closing connexion")
client_socket.sendall(_HeaderMessage.NOK.encode())
client_socket.sendall(_ResponseHeader.NOK.encode())
raise e

try:
self.ocp = OcpSerializable.deserialize(data_json)
except Exception as e:
self._logger.error("Error while deserializing OCP data from client, closing connexion")
client_socket.sendall(_HeaderMessage.NOK.encode())
client_socket.sendall(_ResponseHeader.NOK.encode())
raise e

try:
show_options = _deserialize_show_options(ocp_raw[1])
except Exception as e:
self._logger.error("Error while extracting show options, closing connexion")
client_socket.sendall(_HeaderMessage.NOK.encode())
client_socket.sendall(_ResponseHeader.NOK.encode())
raise e

try:
self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options)
except Exception as e:
self._logger.error("Error while initializing the plotter, closing connexion")
client_socket.sendall(_HeaderMessage.NOK.encode())
client_socket.sendall(_ResponseHeader.NOK.encode())
raise e

# Send the confirmation to the client
client_socket.sendall(_HeaderMessage.PLOT_READY.encode())
client_socket.sendall(_ResponseHeader.PLOT_READY.encode())

# Start the callbacks
threading.Timer(self._get_data_interval, self._wait_for_new_data_to_plot, (client_socket,)).start()
Expand Down Expand Up @@ -395,7 +383,7 @@ def _redraw(self) -> None:

def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None:
"""
Waits for new data from the client, sends a _HeaderMessage.READY_FOR_NEXT_DATA message to the client to signal
Waits for new data from the client, sends a _ResponseHeader.READY_FOR_NEXT_DATA message to the client to signal
that the server is ready to receive new data. If the client sends new data, the server will update the plot, if
client disconnects the connexion will be closed
Expand All @@ -411,7 +399,7 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None:
time.sleep(self._update_plot_interval)

try:
client_socket.sendall(_HeaderMessage.READY_FOR_NEXT_DATA.encode())
client_socket.sendall(_ResponseHeader.READY_FOR_NEXT_DATA.encode())
except Exception as e:
self._logger.error("Error while sending READY_FOR_NEXT_DATA to client, closing connexion")
self._logger.debug(f"Error: {e}")
Expand Down Expand Up @@ -476,10 +464,10 @@ def __init__(self, ocp, opts: dict = None, host: str = None, port: int = None, *
The port to connect to, by default 3050
"""

super().__init__(ocp, opts, **show_options)

self._host = host if host else _DEFAULT_HOST
self._port = port if port else _DEFAULT_PORT

super().__init__(ocp, opts, **show_options)
self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)

self._should_wait_ok_to_client_on_new_data = platform.system() == "Darwin"
Expand Down Expand Up @@ -518,7 +506,7 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None:
raise RuntimeError(
"Could not connect to the plotter server, make sure it is running by calling 'PlottingServer()' on "
"another python instance or allowing for automatic start (Linux or Windows) of the server setting "
"the online_option to 'OnlineOptim.MULTIPROCESS_SERVER' when instantiating your solver"
"the online_option to 'OnlineOptim.MULTIPROCESS_SERVER' when instantiating your solver."
)
else:
time.sleep(1)
Expand All @@ -538,25 +526,36 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None:
# Sends message type and dimensions
self._socket.sendall(
f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".ljust(
_HeaderMessage.header_generic_len, "\0"
_HEADER_GENERIC_LEN, "\0"
).encode()
)
if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK:
if not self._has_received_ok():
raise RuntimeError("The server did not acknowledge the connexion")

self._socket.sendall(serialized_ocp)
self._socket.sendall(serialized_show_options)
if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK:
if not self._has_received_ok():
raise RuntimeError("The server did not acknowledge the connexion")

# Wait for the server to be ready
if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.PLOT_READY:
if self._socket.recv(_ResponseHeader.response_len()).decode() != _ResponseHeader.PLOT_READY:
raise RuntimeError("The server did not acknowledge the OCP data, this should not happen, please report")

self._plotter = PlotOcp(
self.ocp, only_initialize_variables=True, dummy_phase_times=dummy_phase_times, **show_options
)

def _has_received_ok(self) -> bool:
"""
Checks if the server has sent an OK message
Returns
-------
If the server has sent an OK message
"""

return self._socket.recv(_ResponseHeader.response_len()).decode() == _ResponseHeader.OK

def close(self) -> None:
"""
Closes the connexion
Expand Down Expand Up @@ -586,8 +585,8 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]:
self._socket.setblocking(False)

try:
data = self._socket.recv(_HeaderMessage.header_len).decode()
if data != _HeaderMessage.READY_FOR_NEXT_DATA:
data = self._socket.recv(_ResponseHeader.response_len()).decode()
if data != _ResponseHeader.READY_FOR_NEXT_DATA:
return [0]
except BlockingIOError:
# This is to prevent the solving to be blocked by the server if it is not ready to update the plots
Expand All @@ -603,21 +602,15 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]:

self._socket.sendall(
f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".ljust(
_HeaderMessage.header_generic_len, "\0"
_HEADER_GENERIC_LEN, "\0"
).encode()
)
if (
self._should_wait_ok_to_client_on_new_data
and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK
):
if self._should_wait_ok_to_client_on_new_data and not self._has_received_ok():
raise RuntimeError("The server did not acknowledge the connexion")

self._socket.sendall(header)
self._socket.sendall(data_serialized)
if (
self._should_wait_ok_to_client_on_new_data
and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK
):
if self._should_wait_ok_to_client_on_new_data and not self._has_received_ok():
raise RuntimeError("The server did not acknowledge the connexion")

return [0]
Expand Down
14 changes: 0 additions & 14 deletions bioptim/gui/utils.py

This file was deleted.

24 changes: 23 additions & 1 deletion tests/shard5/test_plot_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from bioptim.gui.online_callback_server import _serialize_xydata, _deserialize_xydata
from bioptim.gui.plot import PlotOcp
from bioptim.gui.online_callback_server import _ResponseHeader
from bioptim.optimization.optimization_vector import OptimizationVectorHelper
from casadi import DM, Function
from casadi import DM
import numpy as np


Expand Down Expand Up @@ -40,3 +41,24 @@ def test_serialize_deserialize():
else:
for y_phase, deserialized_y_phase in zip(y_variable, deserialized_y_variable):
assert np.allclose(y_phase, deserialized_y_phase)


def test_response_header():
# Make sure all the response have the same length
response_len = _ResponseHeader.response_len()
for response in _ResponseHeader:
assert len(response) == response_len
# Make sure encoding provides a constant length
assert len(response.encode()) == response_len

# Make sure equality works
assert _ResponseHeader.OK == _ResponseHeader.OK
assert _ResponseHeader.OK.value == _ResponseHeader.OK
assert _ResponseHeader.OK.encode().decode() == _ResponseHeader.OK
assert _ResponseHeader.OK == _ResponseHeader.OK.encode().decode()
assert not (_ResponseHeader.OK != _ResponseHeader.OK)
assert not (_ResponseHeader.OK.encode().decode() != _ResponseHeader.OK)
assert not (_ResponseHeader.OK != _ResponseHeader.OK.encode().decode())
assert not (_ResponseHeader.OK.value == _ResponseHeader.OK.encode().decode())
assert _ResponseHeader.OK != _ResponseHeader.NOK
assert _ResponseHeader.NOK == _ResponseHeader.NOK

0 comments on commit ead0f76

Please sign in to comment.