Skip to content

Commit

Permalink
Add config param to pass additional parameters to confluent-kafka-python
Browse files Browse the repository at this point in the history
  • Loading branch information
kumaranvpl committed Jun 6, 2024
1 parent edc0f30 commit 190a58b
Show file tree
Hide file tree
Showing 6 changed files with 43 additions and 21 deletions.
18 changes: 18 additions & 0 deletions docs/docs/en/confluent/security.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,21 @@ If the user does not want to use SSL encryption without the warning getting logg
```python linenums="1"
{!> docs_src/confluent/security/sasl_scram512.py [ln:1-10.25,11-] !}
```

### 4. Other security related usecases

**Purpose**: If you want to pass additional values to `confluent-kafka-python`, you can pass a dictionary called `config` to `KafkaBroker`. For example, to pass your own certificate file:

```python
from faststream.confluent import KafkaBroker
from faststream.security import SASLPlaintext

security = SASLPlaintext(
username="admin",
password="password",
)

config = {"ssl.ca.location": "~/my_certs/CRT_cacerts.pem"}

broker = KafkaBroker("localhost:9092", security=security, config=config)
```
10 changes: 10 additions & 0 deletions faststream/confluent/broker/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ def __init__(
"""
),
] = SERVICE_NAME,
config: Annotated[
Optional[Dict[str, Any]],
Doc("""
Extra configuration for the confluent-kafka-python
producer/consumer. See `confluent_kafka.Config <https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#kafka-client-configuration>`_.
"""),
] = None,
# publisher args
acks: Annotated[
Union[Literal[0, 1, -1, "all"], object],
Expand Down Expand Up @@ -409,6 +416,7 @@ def __init__(
)
self.client_id = client_id
self._producer = None
self.config = config

async def _close(
self,
Expand Down Expand Up @@ -449,6 +457,7 @@ async def _connect( # type: ignore[override]
**kwargs,
client_id=client_id,
logger=self.logger,
config=self.config,
)

self._producer = AsyncConfluentFastProducer(
Expand All @@ -459,6 +468,7 @@ async def _connect( # type: ignore[override]
AsyncConfluentConsumer,
**filter_by_dict(ConsumerConnectionParams, kwargs),
logger=self.logger,
config=self.config,
)

async def start(self) -> None:
Expand Down
19 changes: 14 additions & 5 deletions faststream/confluent/client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
from ssl import SSLContext
from time import time
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -99,7 +98,6 @@ def __init__(
send_backoff_ms: int = 100,
retry_backoff_ms: int = 100,
security_protocol: str = "PLAINTEXT",
ssl_context: Optional[SSLContext] = None,
connections_max_idle_ms: int = 540000,
enable_idempotence: bool = False,
transactional_id: Optional[Union[str, int]] = None,
Expand All @@ -110,12 +108,16 @@ def __init__(
sasl_kerberos_service_name: str = "kafka",
sasl_kerberos_domain_name: Optional[str] = None,
sasl_oauth_token_provider: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
logger: Annotated[
Union["LoggerProto", None, object],
Doc("User specified logger to pass into Context and log service messages."),
] = logger,
) -> None:
self.logger = logger

self.config = {} if config is None else config

if isinstance(bootstrap_servers, Iterable) and not isinstance(
bootstrap_servers, str
):
Expand All @@ -127,7 +129,7 @@ def __init__(
if acks is _missing or acks == "all":
acks = -1

self.config = {
config_from_params = {
# "topic.metadata.refresh.interval.ms": 1000,
"bootstrap.servers": bootstrap_servers,
"client.id": client_id,
Expand All @@ -146,6 +148,8 @@ def __init__(
"connections.max.idle.ms": connections_max_idle_ms,
"sasl.kerberos.service.name": sasl_kerberos_service_name,
}
self.config = {**self.config, **config_from_params}

if sasl_mechanism:
self.config.update(
{
Expand Down Expand Up @@ -293,7 +297,6 @@ def __init__(
heartbeat_interval_ms: int = 3000,
consumer_timeout_ms: int = 200,
max_poll_records: Optional[int] = None,
ssl_context: Optional[SSLContext] = None,
security_protocol: str = "PLAINTEXT",
api_version: str = "auto",
exclude_internal_topics: bool = True,
Expand All @@ -305,12 +308,16 @@ def __init__(
sasl_kerberos_service_name: str = "kafka",
sasl_kerberos_domain_name: Optional[str] = None,
sasl_oauth_token_provider: Optional[str] = None,
config: Optional[Dict[str, Any]] = None,
logger: Annotated[
Union["LoggerProto", None, object],
Doc("User specified logger to pass into Context and log service messages."),
] = logger,
) -> None:
self.logger = logger

self.config = {} if config is None else config

if group_id is None:
group_id = "confluent-kafka-consumer-group"

Expand All @@ -328,7 +335,7 @@ def __init__(
for x in partition_assignment_strategy
]
)
self.config = {
config_from_params = {
"allow.auto.create.topics": True,
# "topic.metadata.refresh.interval.ms": 1000,
"bootstrap.servers": bootstrap_servers,
Expand All @@ -355,6 +362,8 @@ def __init__(
"isolation.level": isolation_level,
"sasl.kerberos.service.name": sasl_kerberos_service_name,
}
self.config = {**self.config, **config_from_params}

if sasl_mechanism:
self.config.update(
{
Expand Down
4 changes: 0 additions & 4 deletions faststream/confluent/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def parse_security(security: Optional[BaseSecurity]) -> "AnyDict":
def _parse_base_security(security: BaseSecurity) -> "AnyDict":
return {
"security_protocol": "SSL" if security.use_ssl else "PLAINTEXT",
"ssl_context": security.ssl_context,
}


Expand All @@ -45,7 +44,6 @@ def _parse_sasl_plaintext(security: SASLPlaintext) -> "AnyDict":

return {
"security_protocol": "SASL_SSL" if security.use_ssl else "SASL_PLAINTEXT",
"ssl_context": security.ssl_context,
"sasl_mechanism": "PLAIN",
"sasl_plain_username": security.username,
"sasl_plain_password": security.password,
Expand All @@ -55,7 +53,6 @@ def _parse_sasl_plaintext(security: SASLPlaintext) -> "AnyDict":
def _parse_sasl_scram256(security: SASLScram256) -> "AnyDict":
return {
"security_protocol": "SASL_SSL" if security.use_ssl else "SASL_PLAINTEXT",
"ssl_context": security.ssl_context,
"sasl_mechanism": "SCRAM-SHA-256",
"sasl_plain_username": security.username,
"sasl_plain_password": security.password,
Expand All @@ -65,7 +62,6 @@ def _parse_sasl_scram256(security: SASLScram256) -> "AnyDict":
def _parse_sasl_scram512(security: SASLScram512) -> "AnyDict":
return {
"security_protocol": "SASL_SSL" if security.use_ssl else "SASL_PLAINTEXT",
"ssl_context": security.ssl_context,
"sasl_mechanism": "SCRAM-SHA-512",
"sasl_plain_username": security.username,
"sasl_plain_password": security.password,
Expand Down
4 changes: 1 addition & 3 deletions faststream/rabbit/subscriber/asyncapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@ class AsyncAPISubscriber(LogicSubscriber):
"""AsyncAPI-compatible Rabbit Subscriber class."""

def get_name(self) -> str:
return (
f"{self.queue.name}:{getattr(self.exchange, 'name', None) or '_'}:{self.call_name}"
)
return f"{self.queue.name}:{getattr(self.exchange, 'name', None) or '_'}:{self.call_name}"

def get_schema(self) -> Dict[str, Channel]:
payloads = self.get_payloads()
Expand Down
9 changes: 0 additions & 9 deletions tests/brokers/confluent/test_security.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ssl
from contextlib import contextmanager
from typing import Tuple
from unittest.mock import AsyncMock, MagicMock, patch
Expand Down Expand Up @@ -43,8 +42,6 @@ async def test_base_security():
== call_kwargs["security_protocol"]
)

assert type(producer_call_kwargs["ssl_context"]) == ssl.SSLContext


@pytest.mark.asyncio()
@pytest.mark.confluent()
Expand All @@ -70,8 +67,6 @@ async def test_scram256():
== call_kwargs["security_protocol"]
)

assert type(producer_call_kwargs["ssl_context"]) == ssl.SSLContext


@pytest.mark.asyncio()
@pytest.mark.confluent()
Expand All @@ -97,8 +92,6 @@ async def test_scram512():
== call_kwargs["security_protocol"]
)

assert type(producer_call_kwargs["ssl_context"]) == ssl.SSLContext


@pytest.mark.asyncio()
@pytest.mark.confluent()
Expand All @@ -123,5 +116,3 @@ async def test_plaintext():
producer_call_kwargs["security_protocol"]
== call_kwargs["security_protocol"]
)

assert type(producer_call_kwargs["ssl_context"]) == ssl.SSLContext

0 comments on commit 190a58b

Please sign in to comment.