Skip to content

Commit

Permalink
Implement name and description discovery over nats and fastapi (#353)
Browse files Browse the repository at this point in the history
* Implement name and description discovery over nats and fastapi

* Fix: rename self.wf to self.provider
  • Loading branch information
kumaranvpl authored Oct 9, 2024
1 parent d60bd81 commit 23c1a47
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 11 deletions.
32 changes: 24 additions & 8 deletions fastagency/adapters/fastapi/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(
password (Optional[str], optional): The password. Defaults to None.
super_conversation (Optional["FastAPIProvider"], optional): The super conversation. Defaults to None.
"""
self.wf = provider
self.provider = provider

self.user = user
self.password = password
Expand Down Expand Up @@ -77,7 +77,9 @@ async def initiate_chat(
)

nc = await nats.connect(
self.wf.nats_url, user=self.wf.user, password=self.wf.password
self.provider.nats_url,
user=self.provider.user,
password=self.provider.password,
)
await nc.publish(
"chat.server.initiate_chat",
Expand All @@ -86,10 +88,10 @@ async def initiate_chat(

return init_msg

@router.get("/discover")
async def discover() -> list[WorkflowInfo]:
names = self.wf.names
descriptions = [self.wf.get_description(name) for name in names]
@router.get("/discovery")
def discovery() -> list[WorkflowInfo]:
names = self.provider.names
descriptions = [self.provider.get_description(name) for name in names]
return [
WorkflowInfo(name=name, description=description)
for name, description in zip(names, descriptions)
Expand Down Expand Up @@ -274,9 +276,23 @@ async def run_lifespan() -> None:

return "FastAPIWorkflows.run() completed"

def _get_workflow_info(self) -> list[dict[str, str]]:
resp = requests.get(f"{self.fastapi_url}/discovery", timeout=5)
return resp.json() # type: ignore [no-any-return]

def _get_names(self) -> list[str]:
return [workflow["name"] for workflow in self._get_workflow_info()]

def _get_description(self, name: str) -> str:
return next(
workflow["description"]
for workflow in self._get_workflow_info()
if workflow["name"] == name
)

@property
def names(self) -> list[str]:
return ["simple_learning"]
return self._get_names()

def get_description(self, name: str) -> str:
return "Student and teacher learning chat"
return self._get_description(name)
53 changes: 50 additions & 3 deletions fastagency/adapters/nats/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
from asyncer import asyncify, syncify
from faststream import FastStream, Logger
from faststream.nats import JStream, NatsBroker, NatsMessage
from nats.js import api
from nats.aio.client import Client as NatsClient
from nats.js import JetStreamContext, api
from nats.js.kv import KeyValue
from pydantic import BaseModel

from ...base import UI, ProviderProtocol, run_workflow
Expand Down Expand Up @@ -55,6 +57,8 @@ class InitiateModel(BaseModel):
# server prints message to client; chat.server.messages.<user_uuid>.<chat_uuid>
# we create this topic dynamically and subscribe to it => worker is fixed
"chat.server.messages.*.*",
# discovery subject
"discovery",
],
)

Expand Down Expand Up @@ -198,11 +202,21 @@ def callback(t: asyncio.Task[Any]) -> None:
except Exception as e:
await self._send_error_msg(e, logger)

async def _publish_discovery(self) -> None:
"""Publish the discovery message."""
jetstream_key_value = await self.broker.key_value(bucket="discovery")

names = self.provider.names
for name in names:
description = self.provider.get_description(name)
await jetstream_key_value.put(name, description.encode())

# todo: make it a router
@asynccontextmanager
async def lifespan(self, app: Any) -> AsyncIterator[None]:
async with self.broker:
await self.broker.start()
await self._publish_discovery()
try:
yield
finally:
Expand Down Expand Up @@ -430,9 +444,42 @@ async def run_lifespan() -> None:

return "NatsWorkflows.run() completed"

@asynccontextmanager
async def _get_jetstream_context(self) -> AsyncIterator[JetStreamContext]:
nc = NatsClient()
await nc.connect(self.nats_url, user=self.user, password=self.password)
js = nc.jetstream()
try:
yield js
finally:
await nc.close()

@asynccontextmanager
async def _get_jetstream_key_value(
self, bucket: str = "discovery"
) -> AsyncIterator[KeyValue]:
async with self._get_jetstream_context() as js:
kv = await js.create_key_value(bucket=bucket)
yield kv

async def _get_names(self) -> list[str]:
async with self._get_jetstream_key_value() as kv:
names = await kv.keys()
return names

async def _get_description(self, name: str) -> str:
async with self._get_jetstream_key_value() as kv:
description = await kv.get(name)
return description.value.decode() if description.value else ""

@property
def names(self) -> list[str]:
return ["simple_learning"]
names = syncify(self._get_names)()
logger.debug(f"Names: {names}")
return names

def get_description(self, name: str) -> str:
return "Student and teacher learning chat"
description = syncify(self._get_description)(name)
logger.debug(f"Description: {description}")
# return "Student and teacher learning chat"
return description

0 comments on commit 23c1a47

Please sign in to comment.