Skip to content

Commit

Permalink
fix_issue_7_blocking_on_linux_refactor: temp-changes
Browse files Browse the repository at this point in the history
  • Loading branch information
hassan-digicatapult committed Sep 13, 2024
1 parent 041bf74 commit 3a5a478
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 52 deletions.
15 changes: 9 additions & 6 deletions src/dc_federated/backend/dcf_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,6 +726,7 @@ def __init__(

self.context = zmq.Context()
self.socket = self.context.socket(zmq.ROUTER)
print(f"binding to " + f"tcp://*:{self.socket_port}")
self.socket.bind(f"tcp://*:{self.socket_port}")

def start_server(self):
Expand All @@ -740,19 +741,24 @@ def start_server(self):
command = " ".join(["python", subprocess_file, str(self.socket_port)])
sp.Popen(command, stdout=sys.stdout, stderr=sys.stderr, shell=True)

self.wait_for_messages()
self.run()

def run(self):
"""
Runs the main loop of the DCFServerHandler.
"""
print(self.socket.send())
if True:
return

poller = zmq.Poller()
poller.register(self.socket, zmq.POLLIN)

print("starting listening")
while True:
socks = dict(poller.poll())

print("something!")
if self.socket in socks:
print("handling")
self.handle_received_message(self.socket.recv_multipart())
else:
raise LookupError(f"Unknown socket detected by zmq poller {self.socket}.")
Expand Down Expand Up @@ -793,6 +799,3 @@ def receive(self, message):
logger.error(
f'ZQM messaging interface received unrecognised message type: "{message[0]}"'
)

def __del__(self):
self.context.term()
10 changes: 7 additions & 3 deletions src/dc_federated/backend/zmq_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(self, port) -> None:
self.port = port
self.context = zmq.Context()
self.socket = self.context.socket(zmq.DEALER)
print(f"connecting to " + f"tcp://localhost:{self.port}")
self.socket.connect(f"tcp://localhost:{self.port}")

def server_args_request_send(self):
Expand All @@ -105,9 +106,12 @@ def server_args_request_send(self):
ssl_keyfile, ssl_certfile, model_check_interval, debug
"""
# socket = self._send([b"server_args_request"])
self.socket.send_multipart([b"server_args_request"])
output = self.socket.recv_pyobj()
return output
# self.socket.send_multipart([b"server_args_request"])
self.socket.send_pyobj("hello")

print("Sent")
# output = self.socket.recv_pyobj()
# return output

def register_worker_send(self, worker_id):
"""
Expand Down
103 changes: 60 additions & 43 deletions tests/test_zmq_interface.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from dc_federated.backend.zmq_interface import ZMQInterfaceModel, ZMQInterfaceServer
from cProfile import run
from dc_federated.backend.zmq_interface import ZMQInterfaceServer
from dc_federated.backend.dcf_server import DCFServerHandler
import zmq
from unittest.mock import Mock, patch
import threading
Expand Down Expand Up @@ -53,7 +55,7 @@ def mock_new_socket():
converts the relevant "recv" functions from blocking to non-blocking.
"""
context = zmq.Context()
socket = context.socket(zmq.REQ)
socket = context.socket(zmq.DEALER)
socket.connect(f"tcp://localhost:{port}")

mock_recv(context, socket, "recv_pyobj")
Expand All @@ -62,74 +64,89 @@ def mock_new_socket():
return socket


@pytest.fixture(autouse=True)
# @pytest.fixture(autouse=True)
def run_model_interface():
"""
A function for initialising the REP socket that is initialised by the
model-process. This is done in its own thread with a non-blocking
`recv_multipart` so that the server-process can send/receive messages too.
"""
context = zmq.Context()
socket = context.socket(zmq.REP)
socket.bind(f"tcp://*:{port}")

mock_recv(context, socket, "recv_multipart", close=False)

zmqM = ZMQInterfaceModel(
socket=socket,
# context = zmq.Context()
# socket = context.socket(zmq.REP)
# socket.bind(f"tcp://*:{port}")

# mock_recv(context, socket, "recv_multipart", close=False)

# zmqM = DCFServerHandler(
# socket=socket,
# register_worker_callback=mock_register_worker_callback,
# unregister_worker_callback=mock_unregister_worker_callback,
# return_global_model_callback=mock_return_global_model_callback,
# is_global_model_most_recent=mock_is_global_model_most_recent,
# receive_worker_update_callback=mock_receive_worker_update_callback,
# server_subprocess_args=server_subprocess_args,
# )

dcf_server_handler = DCFServerHandler(
register_worker_callback=mock_register_worker_callback,
unregister_worker_callback=mock_unregister_worker_callback,
return_global_model_callback=mock_return_global_model_callback,
is_global_model_most_recent=mock_is_global_model_most_recent,
receive_worker_update_callback=mock_receive_worker_update_callback,
server_subprocess_args=server_subprocess_args,
key_list_file="",
server_mode_safe=False,
socket_port=port
)
thread = threading.Thread(target=zmqM.receive, daemon=True)
thread.start()
yield thread
socket.close()
context.term()
# dcf_server_handler.run()
# thread = threading.Thread(target=dcf_server_handler.run, daemon=True)
threading.Thread(target=lambda: dcf_server_handler.run()).start()
# thread.start()
# yield thread


zmqS = ZMQInterfaceServer(port=port)

# Test each of the interface's functions
@patch.object(zmqS, "_new_socket", mock_new_socket)
# @patch.object(zmqS, "_new_socket", mock_new_socket)
def test_server_args_request():
run_model_interface()
print("here")
result = zmqS.server_args_request_send()
assert result == server_subprocess_args


@patch.object(zmqS, "_new_socket", mock_new_socket)
def test_register_worker():
result = zmqS.register_worker_send("test123")
assert result == b"1"
mock_register_worker_callback.assert_called_once_with("test123")
# @patch.object(zmqS, "_new_socket", mock_new_socket)
# def test_register_worker():
# result = zmqS.register_worker_send("test123")
# assert result == b"1"
# mock_register_worker_callback.assert_called_once_with("test123")


# @patch.object(zmqS, "_new_socket", mock_new_socket)
# def test_unregister_worker():
# result = zmqS.unregister_worker_send("test123")
# assert result == b"1"
# mock_unregister_worker_callback.assert_called_once_with("test123")

@patch.object(zmqS, "_new_socket", mock_new_socket)
def test_unregister_worker():
result = zmqS.unregister_worker_send("test123")
assert result == b"1"
mock_unregister_worker_callback.assert_called_once_with("test123")

# @patch.object(zmqS, "_new_socket", mock_new_socket)
# def test_return_global_model():
# result = zmqS.return_global_model_send()
# assert result == GLOBAL_MODEL
# mock_return_global_model_callback.assert_called_once_with()

@patch.object(zmqS, "_new_socket", mock_new_socket)
def test_return_global_model():
result = zmqS.return_global_model_send()
assert result == GLOBAL_MODEL
mock_return_global_model_callback.assert_called_once_with()

# @patch.object(zmqS, "_new_socket", mock_new_socket)
# def test_is_global_model_most_recent():
# result = zmqS.is_global_model_most_recent_send(123)
# assert result == IS_MOST_RECENT
# mock_is_global_model_most_recent.assert_called_once_with(123)

@patch.object(zmqS, "_new_socket", mock_new_socket)
def test_is_global_model_most_recent():
result = zmqS.is_global_model_most_recent_send(123)
assert result == IS_MOST_RECENT
mock_is_global_model_most_recent.assert_called_once_with(123)

# @patch.object(zmqS, "_new_socket", mock_new_socket)
# def test_receive_worker_update():
# result = zmqS.receive_worker_update_send("test123", b"model_update")
# assert result == WORKER_UPDATE
# mock_receive_worker_update_callback.assert_called_once_with("test123", b"model_update")

@patch.object(zmqS, "_new_socket", mock_new_socket)
def test_receive_worker_update():
result = zmqS.receive_worker_update_send("test123", b"model_update")
assert result == WORKER_UPDATE
mock_receive_worker_update_callback.assert_called_once_with("test123", b"model_update")
test_server_args_request()

0 comments on commit 3a5a478

Please sign in to comment.