Separated thread for confluent kafka consumer client. (#2014)
* fix: Disabled excessive throttling for BatchSubscriber.

* fix: Use separate thread for confluent kafka consumer.

* refactor: Added run_in_executor function.

* fix: Stop consumer client after consumer tasks are stopped.
DABND19 authored Dec 30, 2024
1 parent d482f7c commit 18c166f
Showing 3 changed files with 43 additions and 36 deletions.
faststream/_internal/utils/functions.py (9 additions, 1 deletion)

@@ -1,9 +1,12 @@
+import asyncio
 from collections.abc import AsyncIterator, Awaitable, Iterator
+from concurrent.futures import Executor
 from contextlib import asynccontextmanager, contextmanager
-from functools import wraps
+from functools import partial, wraps
 from typing import (
     Any,
     Callable,
+    Optional,
     TypeVar,
     Union,
     cast,
@@ -80,3 +83,8 @@ def drop_response_type(model: CallModel) -> CallModel:
 
 async def return_input(x: Any) -> Any:
     return x
+
+
+async def run_in_executor(executor: Optional[Executor], func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
+    loop = asyncio.get_running_loop()
+    return await loop.run_in_executor(executor, partial(func, *args, **kwargs))
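Taken on its own, the new helper just binds positional and keyword arguments with `functools.partial` and dispatches the call to the given executor — stdlib `loop.run_in_executor` forwards only positional arguments, which is why the `partial` is needed. A minimal usage sketch (assumes a faststream checkout that includes this commit; `slow_io` and the pool are illustrative, not part of the change):

```python
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor

# Assumes faststream with this commit is installed.
from faststream._internal.utils.functions import run_in_executor


def slow_io(duration: float, *, label: str = "task") -> str:
    """A hypothetical stand-in for any blocking call (disk, network, librdkafka)."""
    time.sleep(duration)  # blocks its thread, but not the event loop
    return f"{label} done"


async def main() -> None:
    pool = ThreadPoolExecutor(max_workers=1)
    try:
        # Keyword arguments work because the helper binds them with
        # functools.partial before handing off to loop.run_in_executor.
        result = await run_in_executor(pool, slow_io, 0.1, label="poll")
        print(result)  # -> "poll done"
    finally:
        pool.shutdown(wait=False)


asyncio.run(main())
```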
faststream/confluent/client.py (26 additions, 23 deletions)

@@ -1,6 +1,7 @@
 import asyncio
 import logging
 from collections.abc import Iterable, Sequence
+from concurrent.futures import ThreadPoolExecutor
 from contextlib import suppress
 from time import time
 from typing import (
@@ -17,7 +18,7 @@
 from faststream._internal.constants import EMPTY
 from faststream._internal.log import logger as faststream_logger
-from faststream._internal.utils.functions import call_or_await
+from faststream._internal.utils.functions import call_or_await, run_in_executor
 from faststream.confluent import config as config_module
 from faststream.confluent.schemas import TopicPartition
 from faststream.exceptions import SetupError
@@ -314,9 +315,8 @@ def __init__(
         self.config = final_config
         self.consumer = Consumer(final_config, logger=self.logger_state.logger.logger)  # type: ignore[call-arg]
 
-        # We shouldn't read messages and close consumer concurrently
-        # https://github.com/airtai/faststream/issues/1904#issuecomment-2506990895
-        self._lock = anyio.Lock()
+        # A pool with single thread is used in order to execute the commands of the consumer sequentially:
+        self._thread_pool = ThreadPoolExecutor(max_workers=1)
 
     @property
     def topics_to_create(self) -> list[str]:
@@ -325,11 +325,12 @@ def topics_to_create(self) -> list[str]:
     async def start(self) -> None:
         """Starts the Kafka consumer and subscribes to the specified topics."""
         if self.allow_auto_create_topics:
-            await call_or_await(
+            await run_in_executor(
+                self._thread_pool,
                 create_topics,
-                self.topics_to_create,
-                self.config,
-                self.logger_state.logger.logger,
+                topics=self.topics_to_create,
+                config=self.config,
+                logger_=self.logger_state.logger.logger,
             )
 
         else:
@@ -339,10 +340,13 @@
             )
 
         if self.topics:
-            await call_or_await(self.consumer.subscribe, self.topics)
+            await run_in_executor(
+                self._thread_pool, self.consumer.subscribe, topics=self.topics
+            )
 
         elif self.partitions:
-            await call_or_await(
+            await run_in_executor(
+                self._thread_pool,
                 self.consumer.assign,
                 [p.to_confluent() for p in self.partitions],
             )
@@ -353,7 +357,7 @@ async def start(self) -> None:
 
     async def commit(self, asynchronous: bool = True) -> None:
         """Commits the offsets of all messages returned by the last poll operation."""
-        await call_or_await(self.consumer.commit, asynchronous=asynchronous)
+        await run_in_executor(self._thread_pool, self.consumer.commit, asynchronous=asynchronous)
 
     async def stop(self) -> None:
         """Stops the Kafka consumer and releases all resources."""
@@ -376,13 +380,13 @@ async def stop(self) -> None:
             )
 
         # Wrap calls to async to make method cancelable by timeout
-        async with self._lock:
-            await call_or_await(self.consumer.close)
+        await run_in_executor(self._thread_pool, self.consumer.close)
+
+        self._thread_pool.shutdown(wait=False)
 
     async def getone(self, timeout: float = 0.1) -> Optional[Message]:
         """Consumes a single message from Kafka."""
-        async with self._lock:
-            msg = await call_or_await(self.consumer.poll, timeout)
+        msg = await run_in_executor(self._thread_pool, self.consumer.poll, timeout)
         return check_msg_error(msg)
 
@@ -391,13 +395,12 @@ async def getmany(
         max_records: Optional[int] = 10,
     ) -> tuple[Message, ...]:
         """Consumes a batch of messages from Kafka and groups them by topic and partition."""
-        async with self._lock:
-            raw_messages: list[Optional[Message]] = await call_or_await(
-                self.consumer.consume,  # type: ignore[arg-type]
-                num_messages=max_records or 10,
-                timeout=timeout,
-            )
-
+        raw_messages: list[Optional[Message]] = await run_in_executor(
+            self._thread_pool,
+            self.consumer.consume,  # type: ignore[arg-type]
+            num_messages=max_records or 10,
+            timeout=timeout,
+        )
         return tuple(x for x in map(check_msg_error, raw_messages) if x is not None)
 
     async def seek(self, topic: str, partition: int, offset: int) -> None:
@@ -407,7 +410,7 @@ async def seek(self, topic: str, partition: int, offset: int) -> None:
             partition=partition,
             offset=offset,
         )
-        await call_or_await(self.consumer.seek, topic_partition.to_confluent())
+        await run_in_executor(self._thread_pool, self.consumer.seek, topic_partition.to_confluent())
 
 
 def check_msg_error(msg: Optional[Message]) -> Optional[Message]:
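The core change in this file is swapping the `anyio.Lock` for a `ThreadPoolExecutor(max_workers=1)`: a pool with a single worker runs submitted callables strictly in submission order, so `poll`/`consume` can never overlap with `close` — the same mutual exclusion the lock gave, but the blocking librdkafka call now lives on a dedicated thread instead of being awaited under an async lock. Note also that `stop()` shuts the pool down only after `consumer.close()` has run on it. A standalone sketch of the serialization property (no Kafka or faststream required; `step` is hypothetical):

```python
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import partial


def step(name: str, log: list[str]) -> None:
    # No locking here: exclusivity comes from the pool having one thread.
    log.append(f"start {name}")
    log.append(f"end {name}")


async def main() -> None:
    loop = asyncio.get_running_loop()
    pool = ThreadPoolExecutor(max_workers=1)
    log: list[str] = []

    # Submitted "concurrently" from the event loop, but the single worker
    # thread executes the callables strictly in submission order.
    await asyncio.gather(
        loop.run_in_executor(pool, partial(step, "poll", log)),
        loop.run_in_executor(pool, partial(step, "close", log)),
    )

    assert log == ["start poll", "end poll", "start close", "end close"]
    pool.shutdown(wait=False)


asyncio.run(main())
```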
faststream/confluent/subscriber/usecase.py (8 additions, 12 deletions)

@@ -131,12 +131,12 @@ async def start(self) -> None:
         self.add_task(self._consume())
 
     async def close(self) -> None:
+        await super().close()
+
         if self.consumer is not None:
             await self.consumer.stop()
             self.consumer = None
 
-        await super().close()
-
     @override
     async def get_one(
         self,
@@ -335,18 +335,14 @@ def __init__(
 
     async def get_msg(self) -> Optional[tuple["Message", ...]]:
         assert self.consumer, "You should setup subscriber at first."  # nosec B101
-
-        messages = await self.consumer.getmany(
-            timeout=self.polling_interval,
-            max_records=self.max_records,
+        return (
+            await self.consumer.getmany(
+                timeout=self.polling_interval,
+                max_records=self.max_records,
+            )
+            or None
         )
-
-        if not messages:  # TODO: why we are sleeping here?
-            await anyio.sleep(self.polling_interval)
-            return None
-
-        return messages
 
     def get_log_context(
         self,
         message: Optional["StreamMessage[tuple[Message, ...]]"],
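The reordered `close()` matches the last commit bullet: the consume tasks are stopped in `super().close()` first, and only then is the underlying consumer client stopped. The `get_msg` rewrite is the "excessive throttling" fix: `consumer.getmany` already blocks in the worker thread for up to `timeout=self.polling_interval`, so the extra `anyio.sleep(self.polling_interval)` after every empty poll presumably doubled the idle wait (the removed `TODO` suggests it was never intentional). The `or None` keeps the `Optional[...]` return contract because an empty tuple is falsy — a tiny sketch of the idiom with hypothetical values:

```python
# The `batch or None` idiom used in get_msg above:
empty_batch: tuple[int, ...] = ()
full_batch: tuple[int, ...] = (1, 2)

assert (empty_batch or None) is None   # empty poll -> None
assert (full_batch or None) == (1, 2)  # non-empty batch passes through
```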
