Skip to content

Commit

Permalink
small refactoring, redis metrics
Browse files Browse the repository at this point in the history
  • Loading branch information
roma-frolov committed Sep 16, 2024
1 parent 1840ac8 commit 119dc8a
Show file tree
Hide file tree
Showing 10 changed files with 199 additions and 50 deletions.
13 changes: 5 additions & 8 deletions faststream/prometheus/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,12 @@ async def consume_scope(
self._metrics.received_messages.labels(
broker=messaging_system,
handler=destination_name,
).inc()

messages_sizes = consume_attrs["messages_sizes"]
).inc(consume_attrs["messages_count"])

for size in messages_sizes:
self._metrics.received_messages_size.labels(
broker=messaging_system,
handler=destination_name,
).observe(size)
self._metrics.received_messages_size.labels(
broker=messaging_system,
handler=destination_name,
).observe(consume_attrs["message_size"])

err: Optional[Exception] = None

Expand Down
5 changes: 3 additions & 2 deletions faststream/prometheus/provider.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Protocol, Sequence, TypedDict
from typing import TYPE_CHECKING, Protocol, TypedDict

from faststream.broker.message import MsgType, StreamMessage

Expand All @@ -8,8 +8,9 @@


class ConsumeAttrs(TypedDict):
messages_sizes: Sequence[int]
message_size: int
destination_name: str
messages_count: int


class PublishAttrs(TypedDict):
Expand Down
3 changes: 2 additions & 1 deletion faststream/rabbit/prometheus/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def get_consume_attrs_from_message(

return {
"destination_name": f"{exchange}.{routing_key}",
"messages_sizes": [len(msg.body)],
"message_size": len(msg.body),
"messages_count": 1,
}

def get_publish_destination_name_from_kwargs(
Expand Down
Empty file.
19 changes: 19 additions & 0 deletions faststream/redis/prometheus/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from typing import TYPE_CHECKING

from faststream.prometheus.middleware import BasePrometheusMiddleware
from faststream.redis.prometheus.provider import attributes_provider_factory

if TYPE_CHECKING:
from prometheus_client import CollectorRegistry


class RedisPrometheusMiddleware(BasePrometheusMiddleware):
def __init__(
self,
*,
registry: "CollectorRegistry",
):
super().__init__(
settings_provider_factory=attributes_provider_factory,
registry=registry,
)
61 changes: 61 additions & 0 deletions faststream/redis/prometheus/provider.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import TYPE_CHECKING, Optional, Union

from faststream.prometheus.provider import (
ConsumeAttrs,
MetricsSettingsProvider,
)

if TYPE_CHECKING:
from faststream.broker.message import StreamMessage
from faststream.types import AnyDict


class BaseRedisMetricsSettingsProvider(MetricsSettingsProvider["AnyDict"]):
def __init__(self):
self.messaging_system = "redis"

def get_publish_destination_name_from_kwargs(
self,
kwargs: "AnyDict",
) -> str:
return self._get_destination(kwargs)

@staticmethod
def _get_destination(kwargs: "AnyDict") -> str:
return kwargs.get("channel") or kwargs.get("list") or kwargs.get("stream") or ""


class RedisMetricsSettingsProvider(BaseRedisMetricsSettingsProvider):
def get_consume_attrs_from_message(
self,
msg: "StreamMessage[AnyDict]",
) -> ConsumeAttrs:
return {
"destination_name": self._get_destination(msg.raw_message),
"message_size": len(msg.body),
"messages_count": 1,
}


class BatchRedisMetricsSettingsProvider(BaseRedisMetricsSettingsProvider):
def get_consume_attrs_from_message(
self,
msg: "StreamMessage[AnyDict]",
) -> ConsumeAttrs:
return {
"destination_name": self._get_destination(msg.raw_message),
"message_size": len(msg.body),
"messages_count": len(msg._decoded_body),
}


def attributes_provider_factory(
msg: Optional["AnyDict"],
) -> Union[
RedisMetricsSettingsProvider,
BatchRedisMetricsSettingsProvider,
]:
if msg is not None and msg.get("type", "").startswith("b"):
return BatchRedisMetricsSettingsProvider()
else:
return RedisMetricsSettingsProvider()
69 changes: 32 additions & 37 deletions tests/prometheus/basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from typing import Any, Optional, Type
from unittest.mock import Mock, call, ANY
from unittest.mock import ANY, Mock, call

import pytest
from prometheus_client import CollectorRegistry
Expand All @@ -11,7 +11,6 @@
PROCESSING_STATUS_BY_HANDLER_EXCEPTION_MAP,
BasePrometheusMiddleware,
)
from faststream.prometheus.provider import MetricsSettingsProvider
from tests.brokers.base.basic import BaseTestcaseConfig


Expand All @@ -20,30 +19,30 @@ class LocalPrometheusTestcase(BaseTestcaseConfig):
broker_class: Type[BrokerUsecase]
middleware_class: Type[BasePrometheusMiddleware]
message_class: Type[StreamMessage[Any]]
provider: MetricsSettingsProvider

@pytest.fixture
def registry(self):
return CollectorRegistry()

@staticmethod
def consume_destination_name(queue: str) -> str:
return queue

@property
def settings_provider_factory(self):
return self.middleware_class(
registry=CollectorRegistry()
)._settings_provider_factory

@pytest.mark.parametrize("status", AckStatus)
@pytest.mark.parametrize(
"exception_class",
[*list(PROCESSING_STATUS_BY_HANDLER_EXCEPTION_MAP.keys()), Exception, None],
)
async def test_subscriber_metrics(
async def test_metrics(
self,
event: asyncio.Event,
queue: str,
registry: CollectorRegistry,
status: AckStatus,
exception_class: Optional[Type[Exception]],
):
middleware = self.middleware_class(registry=registry)
middleware = self.middleware_class(registry=CollectorRegistry())
metrics_mock = Mock()
middleware._metrics = metrics_mock

Expand Down Expand Up @@ -92,49 +91,40 @@ def assert_consume_metrics(
message: Any,
exception_class: Optional[Type[Exception]],
):
consume_attrs = self.provider.get_consume_attrs_from_message(message)
settings_provider = self.settings_provider_factory(message.raw_message)
consume_attrs = settings_provider.get_consume_attrs_from_message(message)
assert metrics.received_messages.labels.mock_calls == [
call(
broker=self.provider.messaging_system,
broker=settings_provider.messaging_system,
handler=consume_attrs["destination_name"],
),
call().inc(),
call().inc(consume_attrs["messages_count"]),
]

received_messages_size_labels_mock_calls = []

for size in consume_attrs["messages_sizes"]:
received_messages_size_labels_mock_calls.extend(
[
call(
broker=self.provider.messaging_system,
handler=consume_attrs["destination_name"],
),
call().observe(size),
]
)

assert (
metrics.received_messages_size.labels.mock_calls
== received_messages_size_labels_mock_calls
)
assert metrics.received_messages_size.labels.mock_calls == [
call(
broker=settings_provider.messaging_system,
handler=consume_attrs["destination_name"],
),
call().observe(consume_attrs["message_size"]),
]

assert metrics.received_messages_in_process.labels.mock_calls == [
call(
broker=self.provider.messaging_system,
broker=settings_provider.messaging_system,
handler=consume_attrs["destination_name"],
),
call().inc(),
call(
broker=self.provider.messaging_system,
broker=settings_provider.messaging_system,
handler=consume_attrs["destination_name"],
),
call().dec(),
]

assert metrics.received_messages_processing_time.labels.mock_calls == [
call(
broker=self.provider.messaging_system,
broker=settings_provider.messaging_system,
handler=consume_attrs["destination_name"],
),
call().observe(ANY),
Expand All @@ -150,7 +140,7 @@ def assert_consume_metrics(

assert metrics.received_processed_messages.labels.mock_calls == [
call(
broker=self.provider.messaging_system,
broker=settings_provider.messaging_system,
handler=consume_attrs["destination_name"],
status=status,
),
Expand All @@ -160,19 +150,24 @@ def assert_consume_metrics(
if status == "error":
assert metrics.messages_processing_exceptions.labels.mock_calls == [
call(
broker=self.provider.messaging_system,
broker=settings_provider.messaging_system,
handler=consume_attrs["destination_name"],
exception_type=exception_class.__name__,
),
call().inc(),
]

def assert_publish_metrics(self, metrics: Any):
settings_provider = self.settings_provider_factory(None)
assert metrics.messages_publish_time.labels.mock_calls == [
call(broker=self.provider.messaging_system, destination=ANY),
call(broker=settings_provider.messaging_system, destination=ANY),
call().observe(ANY),
]
assert metrics.published_messages.labels.mock_calls == [
call(broker=self.provider.messaging_system, destination=ANY, status=ANY),
call(
broker=settings_provider.messaging_system,
destination=ANY,
status="success",
),
call().inc(ANY),
]
3 changes: 1 addition & 2 deletions tests/prometheus/rabbit/test_rabbit.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

from faststream.rabbit import RabbitBroker, RabbitExchange, RabbitMessage
from faststream.rabbit.prometheus.middleware import RabbitPrometheusMiddleware
from faststream.rabbit.prometheus.provider import RabbitMetricsSettingsProvider
from tests.brokers.rabbit.test_consume import TestConsume
from tests.brokers.rabbit.test_publish import TestPublish
from tests.prometheus.basic import LocalPrometheusTestcase
Expand All @@ -14,11 +13,11 @@ def exchange(queue):
return RabbitExchange(name=queue)


@pytest.mark.rabbit
class TestPrometheus(LocalPrometheusTestcase):
broker_class = RabbitBroker
middleware_class = RabbitPrometheusMiddleware
message_class = RabbitMessage
provider = RabbitMetricsSettingsProvider()


@pytest.mark.rabbit
Expand Down
3 changes: 3 additions & 0 deletions tests/prometheus/redis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import pytest

pytest.importorskip("aio_pika")
73 changes: 73 additions & 0 deletions tests/prometheus/redis/test_redis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import asyncio
from unittest.mock import Mock

import pytest
from prometheus_client import CollectorRegistry

from faststream.redis import ListSub, RedisBroker, RedisMessage
from faststream.redis.prometheus.middleware import RedisPrometheusMiddleware
from tests.brokers.redis.test_consume import TestConsume
from tests.brokers.redis.test_publish import TestPublish
from tests.prometheus.basic import LocalPrometheusTestcase


@pytest.mark.redis
class TestPrometheus(LocalPrometheusTestcase):
broker_class = RedisBroker
middleware_class = RedisPrometheusMiddleware
message_class = RedisMessage

async def test_metrics_batch(
self,
event: asyncio.Event,
queue: str,
):
middleware = self.middleware_class(registry=CollectorRegistry())
metrics_mock = Mock()
middleware._metrics = metrics_mock

broker = self.broker_class(middlewares=(middleware,))

args, kwargs = self.get_subscriber_params(list=ListSub(queue, batch=True))

message_class = self.message_class
message = None

@broker.subscriber(*args, **kwargs)
async def handler(m: message_class):
event.set()

nonlocal message
message = m

async with broker:
await broker.start()
tasks = (
asyncio.create_task(broker.publish_batch("hello", "world", list=queue)),
asyncio.create_task(event.wait()),
)
await asyncio.wait(tasks, timeout=self.timeout)

assert event.is_set()
self.assert_consume_metrics(
metrics=metrics_mock, message=message, exception_class=None
)
self.assert_publish_metrics(metrics=metrics_mock)


@pytest.mark.redis
class TestPublishWithPrometheus(TestPublish):
def get_broker(self, apply_types: bool = False):
return RedisBroker(
middlewares=(RedisPrometheusMiddleware(registry=CollectorRegistry()),),
apply_types=apply_types,
)


@pytest.mark.redis
class TestConsumeWithTelemetry(TestConsume):
def get_broker(self, apply_types: bool = False):
return RedisBroker(
middlewares=(RedisPrometheusMiddleware(registry=CollectorRegistry()),),
apply_types=apply_types,
)

0 comments on commit 119dc8a

Please sign in to comment.