Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: extented RMQ features #1522

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
2 changes: 2 additions & 0 deletions docs/docs/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,8 @@ search:
- [RabbitBroker](api/faststream/rabbit/broker/RabbitBroker.md)
- broker
- [RabbitBroker](api/faststream/rabbit/broker/broker/RabbitBroker.md)
- connection
- [ConnectionManager](api/faststream/rabbit/broker/connection/ConnectionManager.md)
- logging
- [RabbitLoggingBroker](api/faststream/rabbit/broker/logging/RabbitLoggingBroker.md)
- registrator
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
# 0.5 - API
# 2 - Release
# 3 - Contributing
# 5 - Template Page
# 10 - Default
search:
boost: 0.5
---

::: faststream.rabbit.broker.connection.ConnectionManager
6 changes: 4 additions & 2 deletions faststream/confluent/broker/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,12 @@ def __init__(
] = SERVICE_NAME,
config: Annotated[
Optional[ConfluentConfig],
Doc("""
Doc(
"""
Extra configuration for the confluent-kafka-python
producer/consumer. See `confluent_kafka.Config <https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#kafka-client-configuration>`_.
"""),
"""
),
] = None,
# publisher args
acks: Annotated[
Expand Down
6 changes: 0 additions & 6 deletions faststream/rabbit/annotations.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from aio_pika import RobustChannel, RobustConnection
from typing_extensions import Annotated

from faststream.annotations import ContextRepo, Logger, NoCast
Expand All @@ -14,17 +13,12 @@
"RabbitMessage",
"RabbitBroker",
"RabbitProducer",
"Channel",
"Connection",
)

RabbitMessage = Annotated[RM, Context("message")]
RabbitBroker = Annotated[RB, Context("broker")]
RabbitProducer = Annotated[AioPikaFastProducer, Context("broker._producer")]

Channel = Annotated[RobustChannel, Context("broker._channel")]
Connection = Annotated[RobustConnection, Context("broker._connection")]

# NOTE: transaction is not for the public usage yet
# async def _get_transaction(connection: Connection) -> RabbitTransaction:
# async with connection.channel(publisher_confirms=False) as channel:
Expand Down
106 changes: 41 additions & 65 deletions faststream/rabbit/broker/broker.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,19 @@
import logging
from inspect import Parameter
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
Optional,
Type,
Union,
cast,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, Optional, Type, Union, cast
from urllib.parse import urlparse

from aio_pika import connect_robust
from typing_extensions import Annotated, Doc, override

from faststream.__about__ import SERVICE_NAME
from faststream.broker.message import gen_cor_id
from faststream.exceptions import NOT_CONNECTED_YET
from faststream.rabbit.broker.connection import ConnectionManager
from faststream.rabbit.broker.logging import RabbitLoggingBroker
from faststream.rabbit.broker.registrator import RabbitRegistrator
from faststream.rabbit.helpers.declarer import RabbitDeclarer
from faststream.rabbit.publisher.producer import AioPikaFastProducer
from faststream.rabbit.schemas import (
RABBIT_REPLY,
RabbitExchange,
RabbitQueue,
)
from faststream.rabbit.schemas import RabbitExchange, RabbitQueue
from faststream.rabbit.security import parse_security
from faststream.rabbit.subscriber.asyncapi import AsyncAPISubscriber
from faststream.rabbit.utils import build_url
Expand All @@ -37,8 +24,6 @@

from aio_pika import (
IncomingMessage,
RobustChannel,
RobustConnection,
RobustExchange,
RobustQueue,
)
Expand All @@ -48,10 +33,7 @@
from yarl import URL

from faststream.asyncapi import schema as asyncapi
from faststream.broker.types import (
BrokerMiddleware,
CustomCallable,
)
from faststream.broker.types import BrokerMiddleware, CustomCallable
from faststream.rabbit.types import AioPikaSendableMessage
from faststream.security import BaseSecurity
from faststream.types import AnyDict, Decorator, LoggerProto
Expand All @@ -67,7 +49,6 @@ class RabbitBroker(
_producer: Optional["AioPikaFastProducer"]

declarer: Optional[RabbitDeclarer]
_channel: Optional["RobustChannel"]

def __init__(
self,
Expand Down Expand Up @@ -213,6 +194,14 @@ def __init__(
Iterable["Decorator"],
Doc("Any custom decorator to apply to wrapped functions."),
] = (),
max_connection_pool_size: Annotated[
int,
Doc("Max connection pool size"),
] = 1,
max_channel_pool_size: Annotated[
int,
Doc("Max channel pool size"),
] = 1,
) -> None:
security_args = parse_security(security)

Expand All @@ -234,6 +223,8 @@ def __init__(
# respect ascynapi_url argument scheme
builded_asyncapi_url = urlparse(asyncapi_url)
self.virtual_host = builded_asyncapi_url.path
self.max_connection_pool_size = max_connection_pool_size
self.max_channel_pool_size = max_channel_pool_size
if protocol is None:
protocol = builded_asyncapi_url.scheme

Expand Down Expand Up @@ -273,13 +264,13 @@ def __init__(

self.app_id = app_id

self._channel = None
self.declarer = None

@property
def _subscriber_setup_extra(self) -> "AnyDict":
return {
**super()._subscriber_setup_extra,
"max_consumers": self._max_consumers,
"app_id": self.app_id,
"virtual_host": self.virtual_host,
"declarer": self.declarer,
Expand Down Expand Up @@ -350,7 +341,7 @@ async def connect( # type: ignore[override]
"when mandatory message will be returned"
),
] = Parameter.empty,
) -> "RobustConnection":
) -> "ConnectionManager":
"""Connect broker object to RabbitMQ.

To startup subscribers too you should use `broker.start()` after/instead this method.
Expand Down Expand Up @@ -405,65 +396,50 @@ async def _connect( # type: ignore[override]
channel_number: Optional[int],
publisher_confirms: bool,
on_return_raises: bool,
) -> "RobustConnection":
connection = cast(
"RobustConnection",
await connect_robust(
url,
timeout=timeout,
ssl_context=ssl_context,
),
)

if self._channel is None: # pragma: no branch
max_consumers = self._max_consumers
channel = self._channel = cast(
"RobustChannel",
await connection.channel(
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
),
) -> "ConnectionManager":
if self._max_consumers:
c = AsyncAPISubscriber.build_log_context(
None,
RabbitQueue(""),
RabbitExchange(""),
)
self._log(f"Set max consumers to {self._max_consumers}", extra=c)

declarer = self.declarer = RabbitDeclarer(channel)
await declarer.declare_queue(RABBIT_REPLY)
connection_manager = ConnectionManager(
url=url,
timeout=timeout,
ssl_context=ssl_context,
connection_pool_size=self.max_connection_pool_size,
channel_pool_size=self.max_channel_pool_size,
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
)

if self.declarer is None:
self.declarer = RabbitDeclarer(connection_manager)

if self._producer is None:
self._producer = AioPikaFastProducer(
declarer=declarer,
declarer=self.declarer,
decoder=self._decoder,
parser=self._parser,
)

if max_consumers:
c = AsyncAPISubscriber.build_log_context(
None,
RabbitQueue(""),
RabbitExchange(""),
)
self._log(f"Set max consumers to {max_consumers}", extra=c)
await channel.set_qos(prefetch_count=int(max_consumers))

return connection
return connection_manager

async def _close(
self,
exc_type: Optional[Type[BaseException]] = None,
exc_val: Optional[BaseException] = None,
exc_tb: Optional["TracebackType"] = None,
) -> None:
if self._channel is not None:
if not self._channel.is_closed:
await self._channel.close()

self._channel = None
if self._connection is not None:
await self._connection.close()

self.declarer = None
self._producer = None

if self._connection is not None:
await self._connection.close()

await super()._close(exc_type, exc_val, exc_tb)

async def start(self) -> None:
Expand Down
87 changes: 87 additions & 0 deletions faststream/rabbit/broker/connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, AsyncIterator, Optional, cast

from aio_pika import connect_robust
from aio_pika.pool import Pool

if TYPE_CHECKING:
from ssl import SSLContext

from aio_pika import (
RobustChannel,
RobustConnection,
)
from aio_pika.abc import TimeoutType


class ConnectionManager:
def __init__(
self,
*,
url: str,
timeout: "TimeoutType",
ssl_context: Optional["SSLContext"],
connection_pool_size: Optional[int],
channel_pool_size: Optional[int],
channel_number: Optional[int],
publisher_confirms: bool,
on_return_raises: bool,
) -> None:
self._connection_pool: "Pool[RobustConnection]" = Pool(
lambda: connect_robust(
url=url,
timeout=timeout,
ssl_context=ssl_context,
),
max_size=connection_pool_size,
)

self._channel_pool: "Pool[RobustChannel]" = Pool(
lambda: self._get_channel(
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
),
max_size=channel_pool_size,
)

async def get_connection(self) -> "RobustConnection":
return await self._connection_pool._get()

@asynccontextmanager
async def acquire_connection(self) -> AsyncIterator["RobustConnection"]:
async with self._connection_pool.acquire() as connection:
yield connection

async def get_channel(self) -> "RobustChannel":
return await self._channel_pool._get()

@asynccontextmanager
async def acquire_channel(self) -> AsyncIterator["RobustChannel"]:
async with self._channel_pool.acquire() as channel:
yield channel

async def _get_channel(
self,
channel_number: Optional[int] = None,
publisher_confirms: bool = True,
on_return_raises: bool = False,
) -> "RobustChannel":
async with self.acquire_connection() as connection:
channel = cast(
"RobustChannel",
await connection.channel(
channel_number=channel_number,
publisher_confirms=publisher_confirms,
on_return_raises=on_return_raises,
),
)

return channel

async def close(self) -> None:
if not self._channel_pool.is_closed:
await self._channel_pool.close()

if not self._connection_pool.is_closed:
await self._connection_pool.close()
5 changes: 3 additions & 2 deletions faststream/rabbit/broker/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@
from inspect import Parameter
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union

from aio_pika import IncomingMessage, RobustConnection
from aio_pika import IncomingMessage

from faststream.broker.core.usecase import BrokerUsecase
from faststream.log.logging import get_broker_logger
from faststream.rabbit.broker.connection import ConnectionManager

if TYPE_CHECKING:
from faststream.types import LoggerProto


class RabbitLoggingBroker(BrokerUsecase[IncomingMessage, RobustConnection]):
class RabbitLoggingBroker(BrokerUsecase[IncomingMessage, ConnectionManager]):
"""A class that extends the LoggingMixin class and adds additional functionality for logging RabbitMQ related information."""

_max_queue_len: int
Expand Down
Loading