From 3a5a478d7c6529f4d4abc8aa317ffee8d3882a39 Mon Sep 17 00:00:00 2001 From: hassandigicatapult Date: Fri, 13 Sep 2024 14:20:36 +0100 Subject: [PATCH] fix_issue_7_blocking_on_linux_refactor: temp-changes --- src/dc_federated/backend/dcf_server.py | 15 ++-- src/dc_federated/backend/zmq_interface.py | 10 ++- tests/test_zmq_interface.py | 103 +++++++++++++--------- 3 files changed, 76 insertions(+), 52 deletions(-) diff --git a/src/dc_federated/backend/dcf_server.py b/src/dc_federated/backend/dcf_server.py index ab84d3f..5cbf41f 100644 --- a/src/dc_federated/backend/dcf_server.py +++ b/src/dc_federated/backend/dcf_server.py @@ -726,6 +726,7 @@ def __init__( self.context = zmq.Context() self.socket = self.context.socket(zmq.ROUTER) + print(f"binding to " + f"tcp://*:{self.socket_port}") self.socket.bind(f"tcp://*:{self.socket_port}") def start_server(self): @@ -740,19 +741,24 @@ def start_server(self): command = " ".join(["python", subprocess_file, str(self.socket_port)]) sp.Popen(command, stdout=sys.stdout, stderr=sys.stderr, shell=True) - self.wait_for_messages() + self.run() def run(self): """ Runs the main loop of the DCFServerHandler. """ + print(self.socket.send()) + if True: + return + poller = zmq.Poller() poller.register(self.socket, zmq.POLLIN) - + print("starting listening") while True: socks = dict(poller.poll()) - + print("something!") if self.socket in socks: + print("handling") self.handle_received_message(self.socket.recv_multipart()) else: raise LookupError(f"Unknown socket detected by zmq poller {self.socket}.") @@ -793,6 +799,3 @@ def receive(self, message): logger.error( f'ZQM messaging interface received unrecognised message type: "{message[0]}"' ) - - def __del__(self): - self.context.term() diff --git a/src/dc_federated/backend/zmq_interface.py b/src/dc_federated/backend/zmq_interface.py index 2f5af75..15b6280 100644 --- a/src/dc_federated/backend/zmq_interface.py +++ b/src/dc_federated/backend/zmq_interface.py @@ -88,6 +88,7 @@ def __init__(self, port) -> None: self.port = port self.context = zmq.Context() self.socket = self.context.socket(zmq.DEALER) + print(f"connecting to " + f"tcp://localhost:{self.port}") self.socket.connect(f"tcp://localhost:{self.port}") def server_args_request_send(self): @@ -105,9 +106,12 @@ def server_args_request_send(self): ssl_keyfile, ssl_certfile, model_check_interval, debug """ # socket = self._send([b"server_args_request"]) - self.socket.send_multipart([b"server_args_request"]) - output = self.socket.recv_pyobj() - return output + # self.socket.send_multipart([b"server_args_request"]) + self.socket.send_pyobj("hello") + + print("Sent") + # output = self.socket.recv_pyobj() + # return output def register_worker_send(self, worker_id): """ diff --git a/tests/test_zmq_interface.py b/tests/test_zmq_interface.py index 6454b1a..e615a3a 100644 --- a/tests/test_zmq_interface.py +++ b/tests/test_zmq_interface.py @@ -1,4 +1,6 @@ -from dc_federated.backend.zmq_interface import ZMQInterfaceModel, ZMQInterfaceServer +from cProfile import run +from dc_federated.backend.zmq_interface import ZMQInterfaceServer +from dc_federated.backend.dcf_server import DCFServerHandler import zmq from unittest.mock import Mock, patch import threading @@ -53,7 +55,7 @@ def mock_new_socket(): converts the relevant "recv" functions from blocking to non-blocking. """ context = zmq.Context() - socket = context.socket(zmq.REQ) + socket = context.socket(zmq.DEALER) socket.connect(f"tcp://localhost:{port}") mock_recv(context, socket, "recv_pyobj") @@ -62,74 +64,89 @@ def mock_new_socket(): return socket -@pytest.fixture(autouse=True) +# @pytest.fixture(autouse=True) def run_model_interface(): """ A function for initialising the REP socket that is initialised by the model-process. This is done in its own thread with a non-blocking `recv_multipart` so that the server-process can send/receive messages too. """ - context = zmq.Context() - socket = context.socket(zmq.REP) - socket.bind(f"tcp://*:{port}") - - mock_recv(context, socket, "recv_multipart", close=False) - - zmqM = ZMQInterfaceModel( - socket=socket, + # context = zmq.Context() + # socket = context.socket(zmq.REP) + # socket.bind(f"tcp://*:{port}") + + # mock_recv(context, socket, "recv_multipart", close=False) + + # zmqM = DCFServerHandler( + # socket=socket, + # register_worker_callback=mock_register_worker_callback, + # unregister_worker_callback=mock_unregister_worker_callback, + # return_global_model_callback=mock_return_global_model_callback, + # is_global_model_most_recent=mock_is_global_model_most_recent, + # receive_worker_update_callback=mock_receive_worker_update_callback, + # server_subprocess_args=server_subprocess_args, + # ) + + dcf_server_handler = DCFServerHandler( register_worker_callback=mock_register_worker_callback, unregister_worker_callback=mock_unregister_worker_callback, return_global_model_callback=mock_return_global_model_callback, is_global_model_most_recent=mock_is_global_model_most_recent, receive_worker_update_callback=mock_receive_worker_update_callback, - server_subprocess_args=server_subprocess_args, + key_list_file="", + server_mode_safe=False, + socket_port=port ) - thread = threading.Thread(target=zmqM.receive, daemon=True) - thread.start() - yield thread - socket.close() - context.term() + # dcf_server_handler.run() + # thread = threading.Thread(target=dcf_server_handler.run, daemon=True) + threading.Thread(target=lambda: dcf_server_handler.run()).start() + # thread.start() + # yield thread zmqS = ZMQInterfaceServer(port=port) # Test each of the interface's functions -@patch.object(zmqS, "_new_socket", mock_new_socket) +# @patch.object(zmqS, "_new_socket", mock_new_socket) def test_server_args_request(): + run_model_interface() + print("here") result = zmqS.server_args_request_send() assert result == server_subprocess_args -@patch.object(zmqS, "_new_socket", mock_new_socket) -def test_register_worker(): - result = zmqS.register_worker_send("test123") - assert result == b"1" - mock_register_worker_callback.assert_called_once_with("test123") +# @patch.object(zmqS, "_new_socket", mock_new_socket) +# def test_register_worker(): +# result = zmqS.register_worker_send("test123") +# assert result == b"1" +# mock_register_worker_callback.assert_called_once_with("test123") + +# @patch.object(zmqS, "_new_socket", mock_new_socket) +# def test_unregister_worker(): +# result = zmqS.unregister_worker_send("test123") +# assert result == b"1" +# mock_unregister_worker_callback.assert_called_once_with("test123") -@patch.object(zmqS, "_new_socket", mock_new_socket) -def test_unregister_worker(): - result = zmqS.unregister_worker_send("test123") - assert result == b"1" - mock_unregister_worker_callback.assert_called_once_with("test123") +# @patch.object(zmqS, "_new_socket", mock_new_socket) +# def test_return_global_model(): +# result = zmqS.return_global_model_send() +# assert result == GLOBAL_MODEL +# mock_return_global_model_callback.assert_called_once_with() -@patch.object(zmqS, "_new_socket", mock_new_socket) -def test_return_global_model(): - result = zmqS.return_global_model_send() - assert result == GLOBAL_MODEL - mock_return_global_model_callback.assert_called_once_with() +# @patch.object(zmqS, "_new_socket", mock_new_socket) +# def test_is_global_model_most_recent(): +# result = zmqS.is_global_model_most_recent_send(123) +# assert result == IS_MOST_RECENT +# mock_is_global_model_most_recent.assert_called_once_with(123) -@patch.object(zmqS, "_new_socket", mock_new_socket) -def test_is_global_model_most_recent(): - result = zmqS.is_global_model_most_recent_send(123) - assert result == IS_MOST_RECENT - mock_is_global_model_most_recent.assert_called_once_with(123) +# @patch.object(zmqS, "_new_socket", mock_new_socket) +# def test_receive_worker_update(): +# result = zmqS.receive_worker_update_send("test123", b"model_update") +# assert result == WORKER_UPDATE +# mock_receive_worker_update_callback.assert_called_once_with("test123", b"model_update") -@patch.object(zmqS, "_new_socket", mock_new_socket) -def test_receive_worker_update(): - result = zmqS.receive_worker_update_send("test123", b"model_update") - assert result == WORKER_UPDATE - mock_receive_worker_update_callback.assert_called_once_with("test123", b"model_update") +test_server_args_request() \ No newline at end of file