Skip to content

Commit

Permalink
Refactor PubSub, Server & Publishers
Browse files Browse the repository at this point in the history
  • Loading branch information
roekatz committed Sep 4, 2024
1 parent 065c539 commit fd22880
Show file tree
Hide file tree
Showing 17 changed files with 370 additions and 477 deletions.
12 changes: 11 additions & 1 deletion packages/opal-common/opal_common/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,23 @@ def __init__(self):
self._tasks: List[asyncio.Task] = []

def _cleanup_task(self, done_task):
self._tasks.remove(done_task)
try:
self._tasks.remove(done_task)
except KeyError:
...

def add_task(self, f):
t = asyncio.create_task(f)
self._tasks.append(t)
t.add_done_callback(self._cleanup_task)

async def join(self, cancel=False):
if cancel:
for t in self._tasks:
t.cancel()
await asyncio.gather(*self._tasks, return_exceptions=True)
self._tasks.clear()


async def repeated_call(
func: Coroutine,
Expand Down
2 changes: 1 addition & 1 deletion packages/opal-common/opal_common/confi/confi.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def wrapped_cast(value, *args, **kwargs):
return wrapped_cast


def load_conf_if_none(variable, conf):
def load_conf_if_none(variable: Any, conf: Any):
if variable is None:
return conf
else:
Expand Down
165 changes: 4 additions & 161 deletions packages/opal-common/opal_common/topics/publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
from typing import Any, Optional, Set

from ddtrace import tracer
from fastapi_websocket_pubsub import PubSubClient, PubSubEndpoint, Topic, TopicList
from fastapi_websocket_pubsub import PubSubEndpoint, Topic, TopicList
from opal_common.async_utils import TasksPool
from opal_common.logger import logger


Expand All @@ -12,8 +13,7 @@ class TopicPublisher:

def __init__(self):
"""inits the publisher's asyncio tasks list."""
self._tasks: Set[asyncio.Task] = set()
self._tasks_lock = asyncio.Lock()
self._pool = TasksPool()

async def publish(self, topics: TopicList, data: Any = None):
raise NotImplementedError()
Expand All @@ -29,95 +29,10 @@ def start(self):
"""starts the publisher."""
logger.debug("started topic publisher")

async def _add_task(self, task: asyncio.Task):
async with self._tasks_lock:
self._tasks.add(task)
task.add_done_callback(self._cleanup_task)

async def wait(self):
async with self._tasks_lock:
await asyncio.gather(*self._tasks, return_exceptions=True)
self._tasks.clear()

async def stop(self):
"""stops the publisher (cancels any running publishing tasks)"""
logger.debug("stopping topic publisher")
await self.wait()

def _cleanup_task(self, task: asyncio.Task):
try:
self._tasks.remove(task)
except KeyError:
...


class PeriodicPublisher:
"""Wrapper for a task that publishes to topic on fixed interval
periodically."""

def __init__(
self,
publisher: TopicPublisher,
time_interval: int,
topic: Topic,
message: Any = None,
task_name: str = "periodic publish task",
):
"""inits the publisher.
Args:
publisher (TopicPublisher): can publish messages on the pub/sub channel
interval (int): the time interval between publishing consecutive messages
topic (Topic): the topic to publish on
message (Any): the message to publish
"""
self._publisher = publisher
self._interval = time_interval
self._topic = topic
self._message = message
self._task_name = task_name
self._task: Optional[asyncio.Task] = None

async def __aenter__(self):
self.start()
return self

async def __aexit__(self, exc_type, exc, tb):
await self.stop()

def start(self):
"""starts the periodic publisher task."""
if self._task is not None:
logger.warning(f"{self._task_name} already started")
return

logger.info(
f"started {self._task_name}: topic is '{self._topic}', interval is {self._interval} seconds"
)
self._task = asyncio.create_task(self._publish_task())

async def stop(self):
"""stops the publisher (cancels any running publishing tasks)"""
if self._task is not None:
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
self._task = None
logger.info(f"cancelled {self._task_name} to topic: {self._topic}")

async def wait_until_done(self):
await self._task

async def _publish_task(self):
while True:
await asyncio.sleep(self._interval)
logger.info(
f"{self._task_name}: publishing message on topic '{self._topic}', next publish is scheduled in {self._interval} seconds"
)
async with self._publisher:
await self._publisher.publish(topics=[self._topic], data=self._message)
await self._pool.join()


class ServerSideTopicPublisher(TopicPublisher):
Expand All @@ -132,77 +47,5 @@ def __init__(self, endpoint: PubSubEndpoint):
self._endpoint = endpoint
super().__init__()

async def _publish_impl(self, topics: TopicList, data: Any = None):
with tracer.trace("topic_publisher.publish", resource=str(topics)):
await self._endpoint.publish(topics=topics, data=data)

async def publish(self, topics: TopicList, data: Any = None):
await self._add_task(asyncio.create_task(self._publish_impl(topics, data)))


class ClientSideTopicPublisher(TopicPublisher):
"""A simple wrapper around a PubSubClient that exposes publish().
Provides start() and stop() shortcuts that helps treat this client
as a separate "process" or task that runs in the background.
"""

def __init__(self, client: PubSubClient, server_uri: str):
"""inits the publisher.
Args:
client (PubSubClient): a configured not-yet-started pub sub client
server_uri (str): the URI of the pub sub server we publish to
"""
self._client = client
self._server_uri = server_uri
super().__init__()

def start(self):
"""starts the pub/sub client as a background asyncio task.
the client will attempt to connect to the pubsub server until
successful.
"""
super().start()
self._client.start_client(f"{self._server_uri}")

async def stop(self):
"""stops the pubsub client, and cancels any publishing tasks."""
await self._client.disconnect()
await super().stop()

async def wait_until_done(self):
"""When the publisher is a used as a context manager, this method waits
until the client is done (i.e: terminated) to prevent exiting the
context."""
return await self._client.wait_until_done()

async def publish(self, topics: TopicList, data: Any = None):
"""publish a message by launching a background task on the event loop.
Args:
topics (TopicList): a list of topics to publish the message to
data (Any): optional data to publish as part of the message
"""
await self._add_task(
asyncio.create_task(self._publish(topics=topics, data=data))
)

async def _publish(self, topics: TopicList, data: Any = None) -> bool:
"""Do not trigger directly, must be triggered via publish() in order to
run as a monitored background asyncio task."""
await self._client.wait_until_ready()
logger.info("Publishing to topics: {topics}", topics=topics)
return await self._client.publish(topics, data)


class ScopedServerSideTopicPublisher(ServerSideTopicPublisher):
def __init__(self, endpoint: PubSubEndpoint, scope_id: str):
super().__init__(endpoint)
self._scope_id = scope_id

async def publish(self, topics: TopicList, data: Any = None):
scoped_topics = [f"{self._scope_id}:{topic}" for topic in topics]
logger.info("Publishing to topics: {topics}", topics=scoped_topics)
await super().publish(scoped_topics, data)
18 changes: 6 additions & 12 deletions packages/opal-server/opal_server/data/data_update_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,18 @@
import os
from typing import List

from fastapi_utils.tasks import repeat_every
from opal_common.logger import logger
from opal_common.schemas.data import (
DataSourceEntryWithPollingInterval,
DataUpdate,
ServerDataSourceConfig,
)
from opal_common.topics.publisher import TopicPublisher
from opal_common.schemas.data import DataUpdate
from opal_server.pubsub import PubSub
from opal_server.scopes.scoped_pubsub import ScopedPubSub

TOPIC_DELIMITER = "/"
PREFIX_DELIMITER = ":"


class DataUpdatePublisher:
def __init__(self, publisher: TopicPublisher) -> None:
self._publisher = publisher
def __init__(self, pubsub: PubSub | ScopedPubSub) -> None:
self._pubsub = pubsub

@staticmethod
def get_topic_combos(topic: str) -> List[str]:
Expand Down Expand Up @@ -108,6 +104,4 @@ async def publish_data_updates(self, update: DataUpdate):
entries=logged_entries,
)

await self._publisher.publish(
list(all_topic_combos), update.dict(by_alias=True)
)
await self._pubsub.publish(list(all_topic_combos), update.dict(by_alias=True))
9 changes: 3 additions & 6 deletions packages/opal-server/opal_server/policy/watcher/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
PolicyUpdateMessage,
PolicyUpdateMessageNotification,
)
from opal_common.topics.publisher import TopicPublisher
from opal_common.topics.utils import policy_topics
from opal_server.pubsub import PubSub


async def create_update_all_directories_in_repo(
Expand Down Expand Up @@ -104,7 +104,7 @@ def is_path_affected(path: Path) -> bool:
async def publish_changed_directories(
old_commit: Commit,
new_commit: Commit,
publisher: TopicPublisher,
pubsub: PubSub,
file_extensions: Optional[List[str]] = None,
bundle_ignore: Optional[List[str]] = None,
):
Expand All @@ -116,7 +116,4 @@ async def publish_changed_directories(
)

if notification:
async with publisher:
await publisher.publish(
topics=notification.topics, data=notification.update.dict()
)
await pubsub.publish_sync(notification.topics, notification.update.dict())
38 changes: 18 additions & 20 deletions packages/opal-server/opal_server/policy/watcher/factory.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,39 @@
from functools import partial
from typing import Any, List, Optional
from typing import List, Optional

from fastapi_websocket_pubsub.pub_sub_server import PubSubEndpoint
from opal_common.confi.confi import load_conf_if_none
from opal_common.git_utils.repo_cloner import RepoClonePathFinder
from opal_common.logger import logger
from opal_common.sources.api_policy_source import ApiPolicySource
from opal_common.sources.git_policy_source import GitPolicySource
from opal_common.topics.publisher import TopicPublisher
from opal_server.config import PolicySourceTypes, opal_server_config
from opal_server.policy.watcher.callbacks import publish_changed_directories
from opal_server.policy.watcher.task import BasePolicyWatcherTask, PolicyWatcherTask
from opal_server.pubsub import PubSub
from opal_server.scopes.task import ScopesPolicyWatcherTask


def setup_watcher_task(
publisher: TopicPublisher,
pubsub_endpoint: PubSubEndpoint,
source_type: str = None,
remote_source_url: str = None,
clone_path_finder: RepoClonePathFinder = None,
branch_name: str = None,
pubsub: PubSub,
source_type: Optional[str] = None,
remote_source_url: Optional[str] = None,
clone_path_finder: Optional[RepoClonePathFinder] = None,
branch_name: Optional[str] = None,
ssh_key: Optional[str] = None,
polling_interval: int = None,
request_timeout: int = None,
policy_bundle_token: str = None,
policy_bundle_token_id: str = None,
policy_bundle_server_type: str = None,
policy_bundle_aws_region: str = None,
polling_interval: Optional[int] = None,
request_timeout: Optional[int] = None,
policy_bundle_token: Optional[str] = None,
policy_bundle_token_id: Optional[str] = None,
policy_bundle_server_type: Optional[str] = None,
policy_bundle_aws_region: Optional[str] = None,
extensions: Optional[List[str]] = None,
bundle_ignore: Optional[List[str]] = None,
) -> BasePolicyWatcherTask:
"""Create a PolicyWatcherTask with Git / API policy source defined by env
vars Load all the defaults from config if called without params.
Args:
publisher(TopicPublisher): server side publisher to publish changes in policy
pubsub(PubSub): server side pubsub client to publish changes in policy
source_type(str): policy source type, can be Git / Api to opa bundle server
remote_source_url(str): the base address to request the policy from
clone_path_finder(RepoClonePathFinder): from which the local dir path for the repo clone would be retrieved
Expand All @@ -46,11 +44,11 @@ def setup_watcher_task(
policy_bundle_token(int): auth token to include in connections to OPAL server. Defaults to POLICY_BUNDLE_SERVER_TOKEN.
policy_bundle_token_id(int): id token to include in connections to OPAL server. Defaults to POLICY_BUNDLE_SERVER_TOKEN_ID.
policy_bundle_server_type (str): type of policy bundle server (HTTP S3). Defaults to POLICY_BUNDLE_SERVER_TYPE
extensions(list(str), optional): list of extantions to check when new policy arrive default is FILTER_FILE_EXTENSIONS
extensions(list(str), optional): list of extensions to check when new policy arrive default is FILTER_FILE_EXTENSIONS
bundle_ignore(list(str), optional): list of glob paths to use for excluding files from bundle default is OPA_BUNDLE_IGNORE
"""
if opal_server_config.SCOPES:
return ScopesPolicyWatcherTask(pubsub_endpoint)
return ScopesPolicyWatcherTask(pubsub)

# load defaults
source_type = load_conf_if_none(source_type, opal_server_config.POLICY_SOURCE_TYPE)
Expand Down Expand Up @@ -135,9 +133,9 @@ def setup_watcher_task(
watcher.add_on_new_policy_callback(
partial(
publish_changed_directories,
publisher=publisher,
pubsub=pubsub,
file_extensions=extensions,
bundle_ignore=bundle_ignore,
)
)
return PolicyWatcherTask(watcher, pubsub_endpoint)
return PolicyWatcherTask(watcher, pubsub)
Loading

0 comments on commit fd22880

Please sign in to comment.