diff --git a/faststream/prometheus/middleware.py b/faststream/prometheus/middleware.py index 39a08828d0..c48618a5cd 100644 --- a/faststream/prometheus/middleware.py +++ b/faststream/prometheus/middleware.py @@ -102,15 +102,12 @@ async def consume_scope( self._metrics.received_messages.labels( broker=messaging_system, handler=destination_name, - ).inc() - - messages_sizes = consume_attrs["messages_sizes"] + ).inc(consume_attrs["messages_count"]) - for size in messages_sizes: - self._metrics.received_messages_size.labels( - broker=messaging_system, - handler=destination_name, - ).observe(size) + self._metrics.received_messages_size.labels( + broker=messaging_system, + handler=destination_name, + ).observe(consume_attrs["message_size"]) err: Optional[Exception] = None diff --git a/faststream/prometheus/provider.py b/faststream/prometheus/provider.py index d8b6f48ebc..e4b2399d3d 100644 --- a/faststream/prometheus/provider.py +++ b/faststream/prometheus/provider.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Protocol, Sequence, TypedDict +from typing import TYPE_CHECKING, Protocol, TypedDict from faststream.broker.message import MsgType, StreamMessage @@ -8,8 +8,9 @@ class ConsumeAttrs(TypedDict): - messages_sizes: Sequence[int] + message_size: int destination_name: str + messages_count: int class PublishAttrs(TypedDict): diff --git a/faststream/rabbit/prometheus/provider.py b/faststream/rabbit/prometheus/provider.py index ada5edaf56..de6b095d62 100644 --- a/faststream/rabbit/prometheus/provider.py +++ b/faststream/rabbit/prometheus/provider.py @@ -26,7 +26,8 @@ def get_consume_attrs_from_message( return { "destination_name": f"{exchange}.{routing_key}", - "messages_sizes": [len(msg.body)], + "message_size": len(msg.body), + "messages_count": 1, } def get_publish_destination_name_from_kwargs( diff --git a/faststream/redis/prometheus/__init__.py b/faststream/redis/prometheus/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/faststream/redis/prometheus/middleware.py b/faststream/redis/prometheus/middleware.py new file mode 100644 index 0000000000..1ccbfce248 --- /dev/null +++ b/faststream/redis/prometheus/middleware.py @@ -0,0 +1,19 @@ +from typing import TYPE_CHECKING + +from faststream.prometheus.middleware import BasePrometheusMiddleware +from faststream.redis.prometheus.provider import attributes_provider_factory + +if TYPE_CHECKING: + from prometheus_client import CollectorRegistry + + +class RedisPrometheusMiddleware(BasePrometheusMiddleware): + def __init__( + self, + *, + registry: "CollectorRegistry", + ): + super().__init__( + settings_provider_factory=attributes_provider_factory, + registry=registry, + ) diff --git a/faststream/redis/prometheus/provider.py b/faststream/redis/prometheus/provider.py new file mode 100644 index 0000000000..c32b1b056d --- /dev/null +++ b/faststream/redis/prometheus/provider.py @@ -0,0 +1,61 @@ +from typing import TYPE_CHECKING, Optional, Union + +from faststream.prometheus.provider import ( + ConsumeAttrs, + MetricsSettingsProvider, +) + +if TYPE_CHECKING: + from faststream.broker.message import StreamMessage + from faststream.types import AnyDict + + +class BaseRedisMetricsSettingsProvider(MetricsSettingsProvider["AnyDict"]): + def __init__(self): + self.messaging_system = "redis" + + def get_publish_destination_name_from_kwargs( + self, + kwargs: "AnyDict", + ) -> str: + return self._get_destination(kwargs) + + @staticmethod + def _get_destination(kwargs: "AnyDict") -> str: + return kwargs.get("channel") or kwargs.get("list") or kwargs.get("stream") or "" + + +class RedisMetricsSettingsProvider(BaseRedisMetricsSettingsProvider): + def get_consume_attrs_from_message( + self, + msg: "StreamMessage[AnyDict]", + ) -> ConsumeAttrs: + return { + "destination_name": self._get_destination(msg.raw_message), + "message_size": len(msg.body), + "messages_count": 1, + } + + +class BatchRedisMetricsSettingsProvider(BaseRedisMetricsSettingsProvider): + def get_consume_attrs_from_message( + self, + msg: "StreamMessage[AnyDict]", + ) -> ConsumeAttrs: + return { + "destination_name": self._get_destination(msg.raw_message), + "message_size": len(msg.body), + "messages_count": len(msg._decoded_body), + } + + +def attributes_provider_factory( + msg: Optional["AnyDict"], +) -> Union[ + RedisMetricsSettingsProvider, + BatchRedisMetricsSettingsProvider, +]: + if msg is not None and msg.get("type", "").startswith("b"): + return BatchRedisMetricsSettingsProvider() + else: + return RedisMetricsSettingsProvider() diff --git a/tests/prometheus/basic.py b/tests/prometheus/basic.py index 55105f5591..05eb59ab8c 100644 --- a/tests/prometheus/basic.py +++ b/tests/prometheus/basic.py @@ -1,6 +1,6 @@ import asyncio from typing import Any, Optional, Type -from unittest.mock import Mock, call, ANY +from unittest.mock import ANY, Mock, call import pytest from prometheus_client import CollectorRegistry @@ -11,7 +11,6 @@ PROCESSING_STATUS_BY_HANDLER_EXCEPTION_MAP, BasePrometheusMiddleware, ) -from faststream.prometheus.provider import MetricsSettingsProvider from tests.brokers.base.basic import BaseTestcaseConfig @@ -20,30 +19,30 @@ class LocalPrometheusTestcase(BaseTestcaseConfig): broker_class: Type[BrokerUsecase] middleware_class: Type[BasePrometheusMiddleware] message_class: Type[StreamMessage[Any]] - provider: MetricsSettingsProvider - - @pytest.fixture - def registry(self): - return CollectorRegistry() @staticmethod def consume_destination_name(queue: str) -> str: return queue + @property + def settings_provider_factory(self): + return self.middleware_class( + registry=CollectorRegistry() + )._settings_provider_factory + @pytest.mark.parametrize("status", AckStatus) @pytest.mark.parametrize( "exception_class", [*list(PROCESSING_STATUS_BY_HANDLER_EXCEPTION_MAP.keys()), Exception, None], ) - async def test_subscriber_metrics( + async def test_metrics( self, event: asyncio.Event, queue: str, - registry: CollectorRegistry, status: AckStatus, exception_class: Optional[Type[Exception]], ): - middleware = self.middleware_class(registry=registry) + middleware = self.middleware_class(registry=CollectorRegistry()) metrics_mock = Mock() middleware._metrics = metrics_mock @@ -92,41 +91,32 @@ def assert_consume_metrics( message: Any, exception_class: Optional[Type[Exception]], ): - consume_attrs = self.provider.get_consume_attrs_from_message(message) + settings_provider = self.settings_provider_factory(message.raw_message) + consume_attrs = settings_provider.get_consume_attrs_from_message(message) assert metrics.received_messages.labels.mock_calls == [ call( - broker=self.provider.messaging_system, + broker=settings_provider.messaging_system, handler=consume_attrs["destination_name"], ), - call().inc(), + call().inc(consume_attrs["messages_count"]), ] - received_messages_size_labels_mock_calls = [] - - for size in consume_attrs["messages_sizes"]: - received_messages_size_labels_mock_calls.extend( - [ - call( - broker=self.provider.messaging_system, - handler=consume_attrs["destination_name"], - ), - call().observe(size), - ] - ) - - assert ( - metrics.received_messages_size.labels.mock_calls - == received_messages_size_labels_mock_calls - ) + assert metrics.received_messages_size.labels.mock_calls == [ + call( + broker=settings_provider.messaging_system, + handler=consume_attrs["destination_name"], + ), + call().observe(consume_attrs["message_size"]), + ] assert metrics.received_messages_in_process.labels.mock_calls == [ call( - broker=self.provider.messaging_system, + broker=settings_provider.messaging_system, handler=consume_attrs["destination_name"], ), call().inc(), call( - broker=self.provider.messaging_system, + broker=settings_provider.messaging_system, handler=consume_attrs["destination_name"], ), call().dec(), @@ -134,7 +124,7 @@ def assert_consume_metrics( assert metrics.received_messages_processing_time.labels.mock_calls == [ call( - broker=self.provider.messaging_system, + broker=settings_provider.messaging_system, handler=consume_attrs["destination_name"], ), call().observe(ANY), @@ -150,7 +140,7 @@ def assert_consume_metrics( assert metrics.received_processed_messages.labels.mock_calls == [ call( - broker=self.provider.messaging_system, + broker=settings_provider.messaging_system, handler=consume_attrs["destination_name"], status=status, ), @@ -160,7 +150,7 @@ def assert_consume_metrics( if status == "error": assert metrics.messages_processing_exceptions.labels.mock_calls == [ call( - broker=self.provider.messaging_system, + broker=settings_provider.messaging_system, handler=consume_attrs["destination_name"], exception_type=exception_class.__name__, ), @@ -168,11 +158,16 @@ def assert_consume_metrics( ] def assert_publish_metrics(self, metrics: Any): + settings_provider = self.settings_provider_factory(None) assert metrics.messages_publish_time.labels.mock_calls == [ - call(broker=self.provider.messaging_system, destination=ANY), + call(broker=settings_provider.messaging_system, destination=ANY), call().observe(ANY), ] assert metrics.published_messages.labels.mock_calls == [ - call(broker=self.provider.messaging_system, destination=ANY, status=ANY), + call( + broker=settings_provider.messaging_system, + destination=ANY, + status="success", + ), call().inc(ANY), ] diff --git a/tests/prometheus/rabbit/test_rabbit.py b/tests/prometheus/rabbit/test_rabbit.py index aa6fc2a1da..74165af3a8 100644 --- a/tests/prometheus/rabbit/test_rabbit.py +++ b/tests/prometheus/rabbit/test_rabbit.py @@ -3,7 +3,6 @@ from faststream.rabbit import RabbitBroker, RabbitExchange, RabbitMessage from faststream.rabbit.prometheus.middleware import RabbitPrometheusMiddleware -from faststream.rabbit.prometheus.provider import RabbitMetricsSettingsProvider from tests.brokers.rabbit.test_consume import TestConsume from tests.brokers.rabbit.test_publish import TestPublish from tests.prometheus.basic import LocalPrometheusTestcase @@ -14,11 +13,11 @@ def exchange(queue): return RabbitExchange(name=queue) +@pytest.mark.rabbit class TestPrometheus(LocalPrometheusTestcase): broker_class = RabbitBroker middleware_class = RabbitPrometheusMiddleware message_class = RabbitMessage - provider = RabbitMetricsSettingsProvider() @pytest.mark.rabbit diff --git a/tests/prometheus/redis/__init__.py b/tests/prometheus/redis/__init__.py new file mode 100644 index 0000000000..ebec43fcd5 --- /dev/null +++ b/tests/prometheus/redis/__init__.py @@ -0,0 +1,3 @@ +import pytest + +pytest.importorskip("aio_pika") diff --git a/tests/prometheus/redis/test_redis.py b/tests/prometheus/redis/test_redis.py new file mode 100644 index 0000000000..c891c198ea --- /dev/null +++ b/tests/prometheus/redis/test_redis.py @@ -0,0 +1,73 @@ +import asyncio +from unittest.mock import Mock + +import pytest +from prometheus_client import CollectorRegistry + +from faststream.redis import ListSub, RedisBroker, RedisMessage +from faststream.redis.prometheus.middleware import RedisPrometheusMiddleware +from tests.brokers.redis.test_consume import TestConsume +from tests.brokers.redis.test_publish import TestPublish +from tests.prometheus.basic import LocalPrometheusTestcase + + +@pytest.mark.redis +class TestPrometheus(LocalPrometheusTestcase): + broker_class = RedisBroker + middleware_class = RedisPrometheusMiddleware + message_class = RedisMessage + + async def test_metrics_batch( + self, + event: asyncio.Event, + queue: str, + ): + middleware = self.middleware_class(registry=CollectorRegistry()) + metrics_mock = Mock() + middleware._metrics = metrics_mock + + broker = self.broker_class(middlewares=(middleware,)) + + args, kwargs = self.get_subscriber_params(list=ListSub(queue, batch=True)) + + message_class = self.message_class + message = None + + @broker.subscriber(*args, **kwargs) + async def handler(m: message_class): + event.set() + + nonlocal message + message = m + + async with broker: + await broker.start() + tasks = ( + asyncio.create_task(broker.publish_batch("hello", "world", list=queue)), + asyncio.create_task(event.wait()), + ) + await asyncio.wait(tasks, timeout=self.timeout) + + assert event.is_set() + self.assert_consume_metrics( + metrics=metrics_mock, message=message, exception_class=None + ) + self.assert_publish_metrics(metrics=metrics_mock) + + +@pytest.mark.redis +class TestPublishWithPrometheus(TestPublish): + def get_broker(self, apply_types: bool = False): + return RedisBroker( + middlewares=(RedisPrometheusMiddleware(registry=CollectorRegistry()),), + apply_types=apply_types, + ) + + +@pytest.mark.redis +class TestConsumeWithTelemetry(TestConsume): + def get_broker(self, apply_types: bool = False): + return RedisBroker( + middlewares=(RedisPrometheusMiddleware(registry=CollectorRegistry()),), + apply_types=apply_types, + )