From 190a58bc42f988526a7e38d45c5ade62f6c5c1e4 Mon Sep 17 00:00:00 2001 From: Kumaran Rajendhiran Date: Thu, 6 Jun 2024 17:29:19 +0530 Subject: [PATCH] Add config param to pass additional parameters to confluent-kafka-python --- docs/docs/en/confluent/security.md | 18 ++++++++++++++++++ faststream/confluent/broker/broker.py | 10 ++++++++++ faststream/confluent/client.py | 19 ++++++++++++++----- faststream/confluent/security.py | 4 ---- faststream/rabbit/subscriber/asyncapi.py | 4 +--- tests/brokers/confluent/test_security.py | 9 --------- 6 files changed, 43 insertions(+), 21 deletions(-) diff --git a/docs/docs/en/confluent/security.md b/docs/docs/en/confluent/security.md index bb9d960eca..8bfc32a92b 100644 --- a/docs/docs/en/confluent/security.md +++ b/docs/docs/en/confluent/security.md @@ -65,3 +65,21 @@ If the user does not want to use SSL encryption without the warning getting logg ```python linenums="1" {!> docs_src/confluent/security/sasl_scram512.py [ln:1-10.25,11-] !} ``` + +### 4. Other security related usecases + +**Purpose**: If you want to pass additional values to `confluent-kafka-python`, you can pass a dictionary called `config` to `KafkaBroker`. For example, to pass your own certificate file: + +```python +from faststream.confluent import KafkaBroker +from faststream.security import SASLPlaintext + +security = SASLPlaintext( + username="admin", + password="password", +) + +config = {"ssl.ca.location": "~/my_certs/CRT_cacerts.pem"} + +broker = KafkaBroker("localhost:9092", security=security, config=config) +``` diff --git a/faststream/confluent/broker/broker.py b/faststream/confluent/broker/broker.py index 960b2606ad..78ffc3c5d0 100644 --- a/faststream/confluent/broker/broker.py +++ b/faststream/confluent/broker/broker.py @@ -128,6 +128,13 @@ def __init__( """ ), ] = SERVICE_NAME, + config: Annotated[ + Optional[Dict[str, Any]], + Doc(""" + Extra configuration for the confluent-kafka-python + producer/consumer. See `confluent_kafka.Config `_. + """), + ] = None, # publisher args acks: Annotated[ Union[Literal[0, 1, -1, "all"], object], @@ -409,6 +416,7 @@ def __init__( ) self.client_id = client_id self._producer = None + self.config = config async def _close( self, @@ -449,6 +457,7 @@ async def _connect( # type: ignore[override] **kwargs, client_id=client_id, logger=self.logger, + config=self.config, ) self._producer = AsyncConfluentFastProducer( @@ -459,6 +468,7 @@ async def _connect( # type: ignore[override] AsyncConfluentConsumer, **filter_by_dict(ConsumerConnectionParams, kwargs), logger=self.logger, + config=self.config, ) async def start(self) -> None: diff --git a/faststream/confluent/client.py b/faststream/confluent/client.py index f1703c3694..027ade73d3 100644 --- a/faststream/confluent/client.py +++ b/faststream/confluent/client.py @@ -1,5 +1,4 @@ import asyncio -from ssl import SSLContext from time import time from typing import ( TYPE_CHECKING, @@ -99,7 +98,6 @@ def __init__( send_backoff_ms: int = 100, retry_backoff_ms: int = 100, security_protocol: str = "PLAINTEXT", - ssl_context: Optional[SSLContext] = None, connections_max_idle_ms: int = 540000, enable_idempotence: bool = False, transactional_id: Optional[Union[str, int]] = None, @@ -110,12 +108,16 @@ def __init__( sasl_kerberos_service_name: str = "kafka", sasl_kerberos_domain_name: Optional[str] = None, sasl_oauth_token_provider: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, logger: Annotated[ Union["LoggerProto", None, object], Doc("User specified logger to pass into Context and log service messages."), ] = logger, ) -> None: self.logger = logger + + self.config = {} if config is None else config + if isinstance(bootstrap_servers, Iterable) and not isinstance( bootstrap_servers, str ): @@ -127,7 +129,7 @@ def __init__( if acks is _missing or acks == "all": acks = -1 - self.config = { + config_from_params = { # "topic.metadata.refresh.interval.ms": 1000, "bootstrap.servers": bootstrap_servers, "client.id": client_id, @@ -146,6 +148,8 @@ def __init__( "connections.max.idle.ms": connections_max_idle_ms, "sasl.kerberos.service.name": sasl_kerberos_service_name, } + self.config = {**self.config, **config_from_params} + if sasl_mechanism: self.config.update( { @@ -293,7 +297,6 @@ def __init__( heartbeat_interval_ms: int = 3000, consumer_timeout_ms: int = 200, max_poll_records: Optional[int] = None, - ssl_context: Optional[SSLContext] = None, security_protocol: str = "PLAINTEXT", api_version: str = "auto", exclude_internal_topics: bool = True, @@ -305,12 +308,16 @@ def __init__( sasl_kerberos_service_name: str = "kafka", sasl_kerberos_domain_name: Optional[str] = None, sasl_oauth_token_provider: Optional[str] = None, + config: Optional[Dict[str, Any]] = None, logger: Annotated[ Union["LoggerProto", None, object], Doc("User specified logger to pass into Context and log service messages."), ] = logger, ) -> None: self.logger = logger + + self.config = {} if config is None else config + if group_id is None: group_id = "confluent-kafka-consumer-group" @@ -328,7 +335,7 @@ def __init__( for x in partition_assignment_strategy ] ) - self.config = { + config_from_params = { "allow.auto.create.topics": True, # "topic.metadata.refresh.interval.ms": 1000, "bootstrap.servers": bootstrap_servers, @@ -355,6 +362,8 @@ def __init__( "isolation.level": isolation_level, "sasl.kerberos.service.name": sasl_kerberos_service_name, } + self.config = {**self.config, **config_from_params} + if sasl_mechanism: self.config.update( { diff --git a/faststream/confluent/security.py b/faststream/confluent/security.py index dea4a0bc98..e4d4e788b3 100644 --- a/faststream/confluent/security.py +++ b/faststream/confluent/security.py @@ -31,7 +31,6 @@ def parse_security(security: Optional[BaseSecurity]) -> "AnyDict": def _parse_base_security(security: BaseSecurity) -> "AnyDict": return { "security_protocol": "SSL" if security.use_ssl else "PLAINTEXT", - "ssl_context": security.ssl_context, } @@ -45,7 +44,6 @@ def _parse_sasl_plaintext(security: SASLPlaintext) -> "AnyDict": return { "security_protocol": "SASL_SSL" if security.use_ssl else "SASL_PLAINTEXT", - "ssl_context": security.ssl_context, "sasl_mechanism": "PLAIN", "sasl_plain_username": security.username, "sasl_plain_password": security.password, @@ -55,7 +53,6 @@ def _parse_sasl_plaintext(security: SASLPlaintext) -> "AnyDict": def _parse_sasl_scram256(security: SASLScram256) -> "AnyDict": return { "security_protocol": "SASL_SSL" if security.use_ssl else "SASL_PLAINTEXT", - "ssl_context": security.ssl_context, "sasl_mechanism": "SCRAM-SHA-256", "sasl_plain_username": security.username, "sasl_plain_password": security.password, @@ -65,7 +62,6 @@ def _parse_sasl_scram256(security: SASLScram256) -> "AnyDict": def _parse_sasl_scram512(security: SASLScram512) -> "AnyDict": return { "security_protocol": "SASL_SSL" if security.use_ssl else "SASL_PLAINTEXT", - "ssl_context": security.ssl_context, "sasl_mechanism": "SCRAM-SHA-512", "sasl_plain_username": security.username, "sasl_plain_password": security.password, diff --git a/faststream/rabbit/subscriber/asyncapi.py b/faststream/rabbit/subscriber/asyncapi.py index 2b0cb4cd5b..05313a6247 100644 --- a/faststream/rabbit/subscriber/asyncapi.py +++ b/faststream/rabbit/subscriber/asyncapi.py @@ -18,9 +18,7 @@ class AsyncAPISubscriber(LogicSubscriber): """AsyncAPI-compatible Rabbit Subscriber class.""" def get_name(self) -> str: - return ( - f"{self.queue.name}:{getattr(self.exchange, 'name', None) or '_'}:{self.call_name}" - ) + return f"{self.queue.name}:{getattr(self.exchange, 'name', None) or '_'}:{self.call_name}" def get_schema(self) -> Dict[str, Channel]: payloads = self.get_payloads() diff --git a/tests/brokers/confluent/test_security.py b/tests/brokers/confluent/test_security.py index 7a11bb7119..176e641e29 100644 --- a/tests/brokers/confluent/test_security.py +++ b/tests/brokers/confluent/test_security.py @@ -1,4 +1,3 @@ -import ssl from contextlib import contextmanager from typing import Tuple from unittest.mock import AsyncMock, MagicMock, patch @@ -43,8 +42,6 @@ async def test_base_security(): == call_kwargs["security_protocol"] ) - assert type(producer_call_kwargs["ssl_context"]) == ssl.SSLContext - @pytest.mark.asyncio() @pytest.mark.confluent() @@ -70,8 +67,6 @@ async def test_scram256(): == call_kwargs["security_protocol"] ) - assert type(producer_call_kwargs["ssl_context"]) == ssl.SSLContext - @pytest.mark.asyncio() @pytest.mark.confluent() @@ -97,8 +92,6 @@ async def test_scram512(): == call_kwargs["security_protocol"] ) - assert type(producer_call_kwargs["ssl_context"]) == ssl.SSLContext - @pytest.mark.asyncio() @pytest.mark.confluent() @@ -123,5 +116,3 @@ async def test_plaintext(): producer_call_kwargs["security_protocol"] == call_kwargs["security_protocol"] ) - - assert type(producer_call_kwargs["ssl_context"]) == ssl.SSLContext