Skip to content

Commit

Permalink
Fixed a race condition by sending fixed size packet
Browse files Browse the repository at this point in the history
  • Loading branch information
pariterre committed Aug 8, 2024
1 parent ab66fed commit 6dcd269
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 34 deletions.
115 changes: 81 additions & 34 deletions bioptim/gui/online_callback_server.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from enum import IntEnum, auto
from enum import IntEnum, StrEnum, auto
import json
import logging
import platform
Expand All @@ -13,6 +13,7 @@

from .online_callback_abstract import OnlineCallbackAbstract
from .plot import PlotOcp, OcpSerializable
from .utils import strstaticproperty, intstaticproperty
from ..optimization.optimization_vector import OptimizationVectorHelper


Expand All @@ -37,6 +38,45 @@ class _ServerMessages(IntEnum):
UNKNOWN = auto()


class _HeaderMessage(StrEnum):
_OK = "OK"
_NOK = "NOK"
_PLOT_READY = "PLOT_READY"
_READY_FOR_NEXT_DATA = "READY_FOR_NEXT_DATA"

@strstaticproperty
def OK() -> str:
return _HeaderMessage._to_str(_HeaderMessage._OK)

@strstaticproperty
def NOK() -> str:
return _HeaderMessage._to_str(_HeaderMessage._NOK)

@strstaticproperty
def PLOT_READY() -> str:
return _HeaderMessage._to_str(_HeaderMessage._PLOT_READY)

@strstaticproperty
def READY_FOR_NEXT_DATA() -> str:
return _HeaderMessage._to_str(_HeaderMessage._READY_FOR_NEXT_DATA)

@intstaticproperty
def longest() -> int:
return max(len(v) for v in _HeaderMessage.__members__.values())

@intstaticproperty
def header_len() -> int:
return _HeaderMessage.longest + 1

@intstaticproperty
def header_generic_len() -> int:
return 1024

@staticmethod
def _to_str(message: str) -> str:
return message.ljust(_HeaderMessage.header_len, "\0")


class PlottingServer:
def __init__(self, host: str = None, port: int = None, log_level: int | None = logging.INFO):
"""
Expand Down Expand Up @@ -141,8 +181,8 @@ def _recv_data(self, client_socket: socket.socket, send_confirmation: bool) -> t
client_socket: socket.socket
The client socket
send_confirmation: bool
If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will
not send anything. This is part of the communication protocol
If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data,
otherwise it will not send anything. This is part of the communication protocol
Returns
-------
Expand All @@ -167,8 +207,8 @@ def _recv_message_type_and_data_len(
client_socket: socket.socket
The client socket
send_confirmation: bool
If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will
not send anything. This is part of the communication protocol
If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data,
otherwise it will not send anything. This is part of the communication protocol
Returns
-------
Expand All @@ -177,22 +217,22 @@ def _recv_message_type_and_data_len(

# Receive the actual data
try:
data = client_socket.recv(1024)
data = client_socket.recv(_HeaderMessage.header_generic_len).decode().strip("\0")
if not data:
return _ServerMessages.EMPTY, None
except:
self._logger.info("Client closed connexion")
client_socket.close()
return _ServerMessages.CLOSE_CONNEXION, None

data_as_list = data.decode().split("\n")
data_as_list = data.split("\n")
try:
message_type = _ServerMessages(int(data_as_list[0]))
except ValueError:
self._logger.error("Unknown message type received")
# Sends failure
if send_confirmation:
client_socket.sendall("NOK".encode())
client_socket.sendall(_HeaderMessage.NOK.encode())
return _ServerMessages.UNKNOWN, None

if message_type == _ServerMessages.CLOSE_CONNEXION:
Expand All @@ -207,13 +247,13 @@ def _recv_message_type_and_data_len(
self._logger.debug(f"Error: {e}")
# Sends failure
if send_confirmation:
client_socket.sendall("NOK".encode())
client_socket.sendall(_HeaderMessage.NOK.encode())
return _ServerMessages.UNKNOWN, None

# If we are here, everything went well, so send confirmation
self._logger.debug(f"Received from client: {message_type} ({len_all_data} bytes)")
if send_confirmation:
client_socket.sendall("OK".encode())
client_socket.sendall(_HeaderMessage.OK.encode())

return message_type, len_all_data

Expand All @@ -226,8 +266,8 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation:
client_socket: socket.socket
The client socket
send_confirmation: bool
If True, the server will send a "OK" confirmation to the client after receiving the data, otherwise it will
not send anything. This is part of the communication protocol
If True, the server will send a _HeaderMessage.OK confirmation to the client after receiving the data,
otherwise it will not send anything. This is part of the communication protocol
len_all_data: list
The length of the data to receive
Expand All @@ -249,12 +289,12 @@ def _recv_serialize_data(self, client_socket: socket.socket, send_confirmation:
self._logger.debug(f"Error: {e}")
# Sends failure
if send_confirmation:
client_socket.sendall("NOK".encode())
client_socket.sendall(_HeaderMessage.NOK.encode())
return None

# If we are here, everything went well, so send confirmation
if send_confirmation:
client_socket.sendall("OK".encode())
client_socket.sendall(_HeaderMessage.OK.encode())

self._logger.debug(f"Received data from client: {[len(d) for d in data_out]} bytes")
return data_out
Expand All @@ -275,14 +315,14 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No
data_json = json.loads(ocp_raw[0])
except Exception as e:
self._logger.error("Error while converting data to json format, closing connexion")
client_socket.sendall("NOK".encode())
client_socket.sendall(_HeaderMessage.NOK.encode())
raise e

try:
self._should_send_ok_to_client_on_new_data = data_json["request_confirmation_on_new_data"]
except Exception as e:
self._logger.error("Did not receive if confirmation should be sent, closing connexion")
client_socket.sendall("NOK".encode())
client_socket.sendall(_HeaderMessage.NOK.encode())
raise e

try:
Expand All @@ -292,32 +332,32 @@ def _initialize_plotter(self, client_socket: socket.socket, ocp_raw: list) -> No
del data_json["dummy_phase_times"]
except Exception as e:
self._logger.error("Error while extracting dummy time vector from OCP data, closing connexion")
client_socket.sendall("NOK".encode())
client_socket.sendall(_HeaderMessage.NOK.encode())
raise e

try:
self.ocp = OcpSerializable.deserialize(data_json)
except Exception as e:
self._logger.error("Error while deserializing OCP data from client, closing connexion")
client_socket.sendall("NOK".encode())
client_socket.sendall(_HeaderMessage.NOK.encode())
raise e

try:
show_options = _deserialize_show_options(ocp_raw[1])
except Exception as e:
self._logger.error("Error while extracting show options, closing connexion")
client_socket.sendall("NOK".encode())
client_socket.sendall(_HeaderMessage.NOK.encode())
raise e

try:
self._plotter = PlotOcp(self.ocp, dummy_phase_times=dummy_time_vector, **show_options)
except Exception as e:
self._logger.error("Error while initializing the plotter, closing connexion")
client_socket.sendall("NOK".encode())
client_socket.sendall(_HeaderMessage.NOK.encode())
raise e

# Send the confirmation to the client
client_socket.sendall("PLOT_READY".encode())
client_socket.sendall(_HeaderMessage.PLOT_READY.encode())

# Start the callbacks
threading.Timer(self._get_data_interval, self._wait_for_new_data_to_plot, (client_socket,)).start()
Expand Down Expand Up @@ -355,9 +395,9 @@ def _redraw(self) -> None:

def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None:
"""
Waits for new data from the client, sends a "READY_FOR_NEXT_DATA" message to the client to signal that the server
is ready to receive new data. If the client sends new data, the server will update the plot, if client disconnects
the connexion will be closed
Waits for new data from the client, sends a _HeaderMessage.READY_FOR_NEXT_DATA message to the client to signal
that the server is ready to receive new data. If the client sends new data, the server will update the plot, if
client disconnects the connexion will be closed
Parameters
----------
Expand All @@ -371,7 +411,7 @@ def _wait_for_new_data_to_plot(self, client_socket: socket.socket) -> None:
time.sleep(self._update_plot_interval)

try:
client_socket.sendall("READY_FOR_NEXT_DATA".encode())
client_socket.sendall(_HeaderMessage.READY_FOR_NEXT_DATA.encode())
except Exception as e:
self._logger.error("Error while sending READY_FOR_NEXT_DATA to client, closing connexion")
self._logger.debug(f"Error: {e}")
Expand Down Expand Up @@ -497,19 +537,20 @@ def _initialize_connexion(self, retries: int = 0, **show_options) -> None:

# Sends message type and dimensions
self._socket.sendall(
f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".encode()
f"{_ServerMessages.INITIATE_CONNEXION.value}\n{[len(serialized_ocp), len(serialized_show_options)]}".ljust(
_HeaderMessage.header_generic_len, "\0"
).encode()
)
if self._socket.recv(1024).decode() != "OK":
if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK:
raise RuntimeError("The server did not acknowledge the connexion")

self._socket.sendall(serialized_ocp)
self._socket.sendall(serialized_show_options)
if self._socket.recv(1024).decode() != "OK":
if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK:
raise RuntimeError("The server did not acknowledge the connexion")

# Wait for the server to be ready
data = self._socket.recv(1024).decode().split("\n")
if data[0] != "PLOT_READY":
if self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.PLOT_READY:
raise RuntimeError("The server did not acknowledge the OCP data, this should not happen, please report")

self._plotter = PlotOcp(
Expand Down Expand Up @@ -545,8 +586,8 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]:
self._socket.setblocking(False)

try:
data = self._socket.recv(1024).decode()
if data != "READY_FOR_NEXT_DATA":
data = self._socket.recv(_HeaderMessage.header_len).decode()
if data != _HeaderMessage.READY_FOR_NEXT_DATA:
return [0]
except BlockingIOError:
# This is to prevent the solving to be blocked by the server if it is not ready to update the plots
Expand All @@ -561,12 +602,18 @@ def eval(self, arg: list | tuple, enforce: bool = False) -> list[int]:
header, data_serialized = _serialize_xydata(xdata, ydata)

self._socket.sendall(f"{_ServerMessages.NEW_DATA.value}\n{[len(header), len(data_serialized)]}".encode())
if self._should_wait_ok_to_client_on_new_data and self._socket.recv(1024).decode() != "OK":
if (
self._should_wait_ok_to_client_on_new_data
and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK
):
raise RuntimeError("The server did not acknowledge the connexion")

self._socket.sendall(header)
self._socket.sendall(data_serialized)
if self._should_wait_ok_to_client_on_new_data and self._socket.recv(1024).decode() != "OK":
if (
self._should_wait_ok_to_client_on_new_data
and self._socket.recv(_HeaderMessage.header_len).decode() != _HeaderMessage.OK
):
raise RuntimeError("The server did not acknowledge the connexion")

return [0]
Expand Down
14 changes: 14 additions & 0 deletions bioptim/gui/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
class strstaticproperty:
def __init__(self, func):
self.func = func

def __get__(self, instance, owner) -> str:
return self.func()


class intstaticproperty:
def __init__(self, func):
self.func = func

def __get__(self, instance, owner) -> int:
return self.func()

0 comments on commit 6dcd269

Please sign in to comment.