Skip to content

Commit

Permalink
Interface change from communicator -> coordinator
Browse files Browse the repository at this point in the history
  • Loading branch information
unkcpz committed Dec 17, 2024
1 parent c0a8bbd commit 41e7833
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 23 deletions.
13 changes: 13 additions & 0 deletions src/plumpy/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,16 @@ def remove_rpc_subscriber(self, identifier): ...
def remove_broadcast_subscriber(self, identifier): ...

def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ...

class Coordinator(Protocol):
def add_rpc_subscriber(self, subscriber: RpcSubscriber, identifier=None) -> Any: ...

def add_broadcast_subscriber(
self, subscriber: BroadcastSubscriber, subject_filter: str | Pattern[str] | None = None, identifier=None
) -> Any: ...

def remove_rpc_subscriber(self, identifier): ...

def remove_broadcast_subscriber(self, identifier): ...

def broadcast_send(self, body, sender=None, subject=None, correlation_id=None) -> bool: ...
14 changes: 7 additions & 7 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
cast,
)

from plumpy.coordinator import Communicator
from plumpy.coordinator import Coordinator

try:
from aiocontextvars import ContextVar
Expand Down Expand Up @@ -266,7 +266,7 @@ def __init__(
pid: Optional[PID_TYPE] = None,
logger: Optional[logging.Logger] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
communicator: Optional[Communicator] = None,
coordinator: Optional[Coordinator] = None,
) -> None:
"""
The signature of the constructor should not be changed by subclassing processes.
Expand Down Expand Up @@ -305,7 +305,7 @@ def __init__(
self._future = persistence.SavableFuture(loop=self._loop)
self._event_helper = EventHelper(ProcessListener)
self._logger = logger
self._communicator = communicator
self._communicator = coordinator

@super_check
def init(self) -> None:
Expand Down Expand Up @@ -449,7 +449,7 @@ def launch(
pid=pid,
logger=logger,
loop=self.loop,
communicator=self._communicator,
coordinator=self._communicator,
)
self.loop.create_task(process.step_until_terminated())
return process
Expand Down Expand Up @@ -654,7 +654,7 @@ def load_instance_state(self, saved_state: SAVED_STATE_TYPE, load_context: persi

self._state: process_states.State = self.recreate_state(saved_state['_state'])

if 'communicator' in load_context:
if 'coordinator' in load_context:
self._communicator = load_context.communicator

if 'logger' in load_context:
Expand Down Expand Up @@ -938,7 +938,7 @@ def _fire_event(self, evt: Callable[..., Any], *args: Any, **kwargs: Any) -> Non

# region Communication

def message_receive(self, _comm: Communicator, msg: Dict[str, Any]) -> Any:
def message_receive(self, _comm: Coordinator, msg: Dict[str, Any]) -> Any:
"""
Coroutine called when the process receives a message from the communicator
Expand Down Expand Up @@ -970,7 +970,7 @@ def message_receive(self, _comm: Communicator, msg: Dict[str, Any]) -> Any:
raise RuntimeError('Unknown intent')

def broadcast_receive(
self, _comm: Communicator, body: Any, sender: Any, subject: Any, correlation_id: Any
self, _comm: Coordinator, body: Any, sender: Any, subject: Any, correlation_id: Any
) -> Optional[concurrent.futures.Future]:
"""
Coroutine called when the process receives a message from the communicator
Expand Down
6 changes: 3 additions & 3 deletions src/plumpy/workchains.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
cast,
)

from plumpy.coordinator import Communicator
from plumpy.coordinator import Coordinator

from . import lang, mixins, persistence, process_states, processes
from .utils import PID_TYPE, SAVED_STATE_TYPE
Expand Down Expand Up @@ -128,9 +128,9 @@ def __init__(
pid: Optional[PID_TYPE] = None,
logger: Optional[logging.Logger] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
communicator: Optional[Communicator] = None,
coordinator: Optional[Coordinator] = None,
) -> None:
super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, communicator=communicator)
super().__init__(inputs=inputs, pid=pid, logger=logger, loop=loop, coordinator=coordinator)
self._stepper: Optional[Stepper] = None
self._awaitables: Dict[Union[asyncio.Future, processes.Process], str] = {}

Expand Down
24 changes: 12 additions & 12 deletions tests/rmq/test_process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def sync_controller(thread_communicator: rmq.RmqThreadCommunicator):
class TestRemoteProcessController:
@pytest.mark.asyncio
async def test_pause(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())
# Send a pause message
Expand All @@ -57,7 +57,7 @@ async def test_pause(self, thread_communicator, async_controller):

@pytest.mark.asyncio
async def test_play(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())
assert proc.pause()
Expand All @@ -75,7 +75,7 @@ async def test_play(self, thread_communicator, async_controller):

@pytest.mark.asyncio
async def test_kill(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
# Run the process in the event loop
asyncio.ensure_future(proc.step_until_terminated())

Expand All @@ -88,7 +88,7 @@ async def test_kill(self, thread_communicator, async_controller):

@pytest.mark.asyncio
async def test_status(self, thread_communicator, async_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())

Expand All @@ -108,7 +108,7 @@ def on_broadcast_receive(**msg):

thread_communicator.add_broadcast_subscriber(on_broadcast_receive)

proc = utils.DummyProcess(communicator=thread_communicator)
proc = utils.DummyProcess(coordinator=thread_communicator)
proc.execute()

expected_subjects = []
Expand All @@ -123,7 +123,7 @@ def on_broadcast_receive(**msg):
class TestRemoteProcessThreadController:
@pytest.mark.asyncio
async def test_pause(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)

# Send a pause message
pause_future = sync_controller.pause_process(proc.pid)
Expand All @@ -140,7 +140,7 @@ async def test_pause_all(self, thread_communicator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
for _ in range(10):
procs.append(utils.WaitForSignalProcess(communicator=thread_communicator))
procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator))

sync_controller.pause_all("Slow yo' roll")
# Wait until they are all paused
Expand All @@ -151,7 +151,7 @@ async def test_play_all(self, thread_communicator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
for _ in range(10):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
procs.append(proc)
proc.pause('hold tight')

Expand All @@ -162,7 +162,7 @@ async def test_play_all(self, thread_communicator, sync_controller):

@pytest.mark.asyncio
async def test_play(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
assert proc.pause()

# Send a play message
Expand All @@ -176,7 +176,7 @@ async def test_play(self, thread_communicator, sync_controller):

@pytest.mark.asyncio
async def test_kill(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)

# Send a kill message
kill_future = sync_controller.kill_process(proc.pid)
Expand All @@ -193,7 +193,7 @@ async def test_kill_all(self, thread_communicator, sync_controller):
"""Test pausing all processes on a communicator"""
procs = []
for _ in range(10):
procs.append(utils.WaitForSignalProcess(communicator=thread_communicator))
procs.append(utils.WaitForSignalProcess(coordinator=thread_communicator))

msg = process_comms.MessageBuilder.kill(text='bang bang, I shot you down')

Expand All @@ -203,7 +203,7 @@ async def test_kill_all(self, thread_communicator, sync_controller):

@pytest.mark.asyncio
async def test_status(self, thread_communicator, sync_controller):
proc = utils.WaitForSignalProcess(communicator=thread_communicator)
proc = utils.WaitForSignalProcess(coordinator=thread_communicator)
# Run the process in the background
asyncio.ensure_future(proc.step_until_terminated())

Expand Down
2 changes: 1 addition & 1 deletion tests/test_processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def on_broadcast_receive(_comm, body, sender, subject, correlation_id):
messages.append({'body': body, 'subject': subject, 'sender': sender, 'correlation_id': correlation_id})

communicator.add_broadcast_subscriber(on_broadcast_receive)
proc = utils.DummyProcess(communicator=communicator)
proc = utils.DummyProcess(coordinator=communicator)
proc.execute()

expected_subjects = []
Expand Down

0 comments on commit 41e7833

Please sign in to comment.