Skip to content

Commit

Permalink
Add broker.subscriber.get_one() (#1726)
Browse files Browse the repository at this point in the history
* subscriber.get_one()

* remove _prepare

* ruff satisfied

* fixes

* fixes

* fixes

* fixes

* Kafka subscriber.get_one()

* Confluent subscriber.get_one()

* refactor: polist RMQ get_one method

* Small refactoring of get_one

* Rabbit get_one error fix

* Kafka get_one update

* Confluent get_one update

* Redis channel get_one

* Redis list get_one draft

* Redis batch list get_one draft

* Redis channel get_one update and list get_one message decoding

* Redis list batch get_one message decoding

* Redis stream get_one

* Redis batch stream get_one

* Redis channel get_one fix

* Update brokers start methods

* remove unnecessary code

* Nats CoreSubscriber.get_one

* Nats CoreSubscriber.get_one timeout support

* Nats PullStreamSubscriber get_one

* Nats KeyValueWatchSubscriber get_one prototype

* Nats ObjStoreWatchSubscriber get_one prototype

* Add Nats additional get_one methods

* refactor: polist RMQ get_one method

* Rabbit subscriber get_one tests

* Kafka subscriber get_one tests

* Confluent subscriber get_one tests

* Redis subscriber get_one tests

* Nats core and JS get_one tests

* Nats PoolSubscriber get_one tests

* Nats batch pull get_one tests + fixes

* Nats get_one with filter test

* Nats CoreSubscriber.get_one small refactoring

* refactor: polish get_one

* lint: fix redis mypy

* lint: fix rabbit mypy

* lint: fix kafka mypy

* lint: fix confluent mypy

* lint: fix kafka mypy

* lint: fix nats mypy

* lint: fix precommit

* refactor: fix nats unsub

* fix: correct redis timeout

* fix: correct redis channel sub

* lint: fix precommit

* tests: mv get_one tests to basic testcase

* tests: mv get_one tests to real testcase

* docs: generate API References

* refactor: use process_msg everywhere

* refactor: mv process_msg broker.utils

* lint: fix mypy

* Nats KV and Obj subscribers get_one tests

* Nats KV and Obj subscribers get_one timeout tests

* format fix

* tests: tefactor timeout tests

---------

Co-authored-by: Nikita Pastukhov <diementros@yandex.ru>
Co-authored-by: Pastukhov Nikita <nikita@pastukhov-dev.ru>
Co-authored-by: Lancetnik <Lancetnik@users.noreply.github.com>
  • Loading branch information
4 people authored Sep 8, 2024
1 parent 218ddcc commit 7aaafdd
Show file tree
Hide file tree
Showing 16 changed files with 1,049 additions and 75 deletions.
1 change: 1 addition & 0 deletions docs/docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ search:
- [MultiLock](api/faststream/broker/utils/MultiLock.md)
- [default_filter](api/faststream/broker/utils/default_filter.md)
- [get_watcher_context](api/faststream/broker/utils/get_watcher_context.md)
- [process_msg](api/faststream/broker/utils/process_msg.md)
- [resolve_custom_func](api/faststream/broker/utils/resolve_custom_func.md)
- wrapper
- call
Expand Down
11 changes: 11 additions & 0 deletions docs/docs/en/api/faststream/broker/utils/process_msg.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: faststream.broker.utils.process_msg
12 changes: 10 additions & 2 deletions faststream/broker/middlewares/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Tuple,
Type,
Union,
cast,
overload,
)

Expand All @@ -28,7 +29,9 @@
from faststream.types import AsyncFuncAny


GeneralExceptionHandler: TypeAlias = Union[Callable[..., None], Callable[..., Awaitable[None]]]
GeneralExceptionHandler: TypeAlias = Union[
Callable[..., None], Callable[..., Awaitable[None]]
]
PublishingExceptionHandler: TypeAlias = Callable[..., "Any"]

CastedGeneralExceptionHandler: TypeAlias = Callable[..., Awaitable[None]]
Expand Down Expand Up @@ -126,7 +129,12 @@ def __init__(
self._handlers: CastedHandlers = [
(IgnoredException, ignore_handler),
*(
(exc_type, apply_types(to_async(handler)))
(
exc_type,
apply_types(
cast(Callable[..., Awaitable[None]], to_async(handler))
),
)
for exc_type, handler in (handlers or {}).items()
),
]
Expand Down
5 changes: 5 additions & 0 deletions faststream/broker/subscriber/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,11 @@ async def consume(self, msg: MsgType) -> Any: ...
@abstractmethod
async def process_message(self, msg: MsgType) -> "Response": ...

@abstractmethod
async def get_one(
self, *, timeout: float = 5.0
) -> "Optional[StreamMessage[MsgType]]": ...

@abstractmethod
def add_call(
self,
Expand Down
19 changes: 9 additions & 10 deletions faststream/broker/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def __init__(
"""Initialize a new instance of the class."""
self.calls = []

self._default_parser = default_parser
self._default_decoder = default_decoder
self._parser = default_parser
self._decoder = default_decoder
self._no_reply = no_reply
# Watcher args
self._no_ack = no_ack
Expand Down Expand Up @@ -166,18 +166,17 @@ def setup( # type: ignore[override]

for call in self.calls:
if parser := call.item_parser or broker_parser:
async_parser = resolve_custom_func(
to_async(parser), self._default_parser
)
async_parser = resolve_custom_func(to_async(parser), self._parser)
else:
async_parser = self._default_parser
async_parser = self._parser

if decoder := call.item_decoder or broker_decoder:
async_decoder = resolve_custom_func(
to_async(decoder), self._default_decoder
)
async_decoder = resolve_custom_func(to_async(decoder), self._decoder)
else:
async_decoder = self._default_decoder
async_decoder = self._decoder

self._parser = async_parser
self._decoder = async_decoder

call.setup(
parser=async_parser,
Expand Down
55 changes: 52 additions & 3 deletions faststream/broker/utils.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,85 @@
import asyncio
import inspect
from contextlib import suppress
from contextlib import AsyncExitStack, suppress
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
AsyncContextManager,
Awaitable,
Callable,
Iterable,
Optional,
Type,
Union,
cast,
)

import anyio
from typing_extensions import Self
from typing_extensions import Literal, Self, overload

from faststream.broker.acknowledgement_watcher import WatcherContext, get_watcher
from faststream.utils.functions import fake_context, to_async
from faststream.broker.types import MsgType
from faststream.utils.functions import fake_context, return_input, to_async

if TYPE_CHECKING:
from types import TracebackType

from faststream.broker.message import StreamMessage
from faststream.broker.types import (
AsyncCallable,
BrokerMiddleware,
CustomCallable,
SyncCallable,
)
from faststream.types import LoggerProto


@overload
async def process_msg(
msg: Literal[None],
middlewares: Iterable["BrokerMiddleware[MsgType]"],
parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]],
decoder: Callable[["StreamMessage[MsgType]"], "Any"],
) -> None: ...


@overload
async def process_msg(
msg: MsgType,
middlewares: Iterable["BrokerMiddleware[MsgType]"],
parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]],
decoder: Callable[["StreamMessage[MsgType]"], "Any"],
) -> "StreamMessage[MsgType]": ...


async def process_msg(
msg: Optional[MsgType],
middlewares: Iterable["BrokerMiddleware[MsgType]"],
parser: Callable[[MsgType], Awaitable["StreamMessage[MsgType]"]],
decoder: Callable[["StreamMessage[MsgType]"], "Any"],
) -> Optional["StreamMessage[MsgType]"]:
if msg is None:
return None

async with AsyncExitStack() as stack:
return_msg: Callable[
[StreamMessage[MsgType]],
Awaitable[StreamMessage[MsgType]],
] = return_input

for m in middlewares:
mid = m(msg)
await stack.enter_async_context(mid)
return_msg = partial(mid.consume_scope, return_msg)

parsed_msg = await parser(msg)
parsed_msg._decoded_body = await decoder(parsed_msg)
return await return_msg(parsed_msg)

raise AssertionError("unreachable")


async def default_filter(msg: "StreamMessage[Any]") -> bool:
"""A function to filter stream messages."""
return not msg.processed
Expand Down
25 changes: 24 additions & 1 deletion faststream/confluent/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from faststream.broker.publisher.fake import FakePublisher
from faststream.broker.subscriber.usecase import SubscriberUsecase
from faststream.broker.types import MsgType
from faststream.broker.utils import process_msg
from faststream.confluent.parser import AsyncConfluentParser
from faststream.confluent.schemas import TopicPartition

Expand Down Expand Up @@ -152,7 +153,8 @@ async def start(self) -> None:

await super().start()

self.task = asyncio.create_task(self._consume())
if self.calls:
self.task = asyncio.create_task(self._consume())

async def close(self) -> None:
await super().close()
Expand All @@ -166,6 +168,27 @@ async def close(self) -> None:

self.task = None

@override
async def get_one(
self,
*,
timeout: float = 5.0,
) -> "Optional[StreamMessage[Message]]":
assert self.consumer, "You should start subscriber at first." # nosec B101
assert ( # nosec B101
not self.calls
), "You can't use `get_one` method if subscriber has registered handlers."

raw_message = await self.consumer.getone(timeout=timeout)

msg = await process_msg(
msg=raw_message,
middlewares=self._broker_middlewares,
parser=self._parser,
decoder=self._decoder,
)
return msg

def _make_response_publisher(
self,
message: "StreamMessage[Any]",
Expand Down
32 changes: 31 additions & 1 deletion faststream/kafka/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
CustomCallable,
MsgType,
)
from faststream.broker.utils import process_msg
from faststream.kafka.message import KafkaAckableMessage, KafkaMessage
from faststream.kafka.parser import AioKafkaBatchParser, AioKafkaParser
from faststream.utils.path import compile_path
Expand Down Expand Up @@ -164,7 +165,8 @@ async def start(self) -> None:
await consumer.start()
await super().start()

self.task = asyncio.create_task(self._consume())
if self.calls:
self.task = asyncio.create_task(self._consume())

async def close(self) -> None:
await super().close()
Expand All @@ -178,6 +180,34 @@ async def close(self) -> None:

self.task = None

@override
async def get_one(
self,
*,
timeout: float = 5.0,
) -> "Optional[StreamMessage[MsgType]]":
assert self.consumer, "You should start subscriber at first." # nosec B101
assert ( # nosec B101
not self.calls
), "You can't use `get_one` method if subscriber has registered handlers."

raw_messages = await self.consumer.getmany(
timeout_ms=timeout * 1000, max_records=1
)

if not raw_messages:
return None

((raw_message,),) = raw_messages.values()

msg: StreamMessage[MsgType] = await process_msg(
msg=raw_message,
middlewares=self._broker_middlewares,
parser=self._parser,
decoder=self._decoder,
)
return msg

def _make_response_publisher(
self,
message: "StreamMessage[Any]",
Expand Down
2 changes: 1 addition & 1 deletion faststream/nats/subscriber/asyncapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)


class AsyncAPISubscriber(LogicSubscriber[Any]):
class AsyncAPISubscriber(LogicSubscriber[Any, Any]):
"""A class to represent a NATS handler."""

def get_name(self) -> str:
Expand Down
Loading

0 comments on commit 7aaafdd

Please sign in to comment.