diff --git a/faststream/_internal/utils/functions.py b/faststream/_internal/utils/functions.py index e8cb60d696..efea5541d6 100644 --- a/faststream/_internal/utils/functions.py +++ b/faststream/_internal/utils/functions.py @@ -1,9 +1,12 @@ +import asyncio from collections.abc import AsyncIterator, Awaitable, Iterator +from concurrent.futures import Executor from contextlib import asynccontextmanager, contextmanager -from functools import wraps +from functools import partial, wraps from typing import ( Any, Callable, + Optional, TypeVar, Union, cast, @@ -80,3 +83,8 @@ def drop_response_type(model: CallModel) -> CallModel: async def return_input(x: Any) -> Any: return x + + +async def run_in_executor(executor: Optional[Executor], func: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: + loop = asyncio.get_running_loop() + return await loop.run_in_executor(executor, partial(func, *args, **kwargs)) diff --git a/faststream/confluent/client.py b/faststream/confluent/client.py index f34f55e6ac..b4ffad370e 100644 --- a/faststream/confluent/client.py +++ b/faststream/confluent/client.py @@ -18,7 +18,7 @@ from faststream._internal.constants import EMPTY from faststream._internal.log import logger as faststream_logger -from faststream._internal.utils.functions import call_or_await +from faststream._internal.utils.functions import call_or_await, run_in_executor from faststream.confluent import config as config_module from faststream.confluent.schemas import TopicPartition from faststream.exceptions import SetupError @@ -324,15 +324,13 @@ def topics_to_create(self) -> list[str]: async def start(self) -> None: """Starts the Kafka consumer and subscribes to the specified topics.""" - loop = asyncio.get_running_loop() - if self.allow_auto_create_topics: - await loop.run_in_executor( + await run_in_executor( self._thread_pool, create_topics, - self.topics_to_create, - self.config, - self.logger_state.logger.logger, + topics=self.topics_to_create, + config=self.config, + logger_=self.logger_state.logger.logger, ) else: @@ -342,15 +340,15 @@ async def start(self) -> None: ) if self.topics: - await loop.run_in_executor( - self._thread_pool, self.consumer.subscribe, self.topics + await run_in_executor( + self._thread_pool, self.consumer.subscribe, topics=self.topics ) elif self.partitions: - await loop.run_in_executor( + await run_in_executor( self._thread_pool, self.consumer.assign, - [p.to_confluent() for p in self.partitions], + partitions=[p.to_confluent() for p in self.partitions], ) else: @@ -359,19 +357,10 @@ async def start(self) -> None: async def commit(self, asynchronous: bool = True) -> None: """Commits the offsets of all messages returned by the last poll operation.""" - loop = asyncio.get_running_loop() - await loop.run_in_executor( - self._thread_pool, - self.consumer.commit, - None, - None, - asynchronous, - ) + await run_in_executor(self._thread_pool, self.consumer.commit, asynchronous=asynchronous) async def stop(self) -> None: """Stops the Kafka consumer and releases all resources.""" - loop = asyncio.get_running_loop() - # NOTE: If we don't explicitly call commit and then close the consumer, the confluent consumer gets stuck. # We are doing this to avoid the issue. enable_auto_commit = self.config["enable.auto.commit"] @@ -391,14 +380,13 @@ async def stop(self) -> None: ) # Wrap calls to async to make method cancelable by timeout - await loop.run_in_executor(self._thread_pool, self.consumer.close) + await run_in_executor(self._thread_pool, self.consumer.close) self._thread_pool.shutdown(wait=False) async def getone(self, timeout: float = 0.1) -> Optional[Message]: """Consumes a single message from Kafka.""" - loop = asyncio.get_running_loop() - msg = await loop.run_in_executor(self._thread_pool, self.consumer.poll, timeout) + msg = await run_in_executor(self._thread_pool, self.consumer.poll, timeout) return check_msg_error(msg) async def getmany( @@ -407,26 +395,22 @@ async def getmany( max_records: Optional[int] = 10, ) -> tuple[Message, ...]: """Consumes a batch of messages from Kafka and groups them by topic and partition.""" - loop = asyncio.get_running_loop() - raw_messages: list[Optional[Message]] = await loop.run_in_executor( + raw_messages: list[Optional[Message]] = await run_in_executor( self._thread_pool, self.consumer.consume, # type: ignore[arg-type] - max_records or 10, - timeout, + num_messages=max_records or 10, + timeout=timeout, ) return tuple(x for x in map(check_msg_error, raw_messages) if x is not None) async def seek(self, topic: str, partition: int, offset: int) -> None: """Seeks to the specified offset in the specified topic and partition.""" - loop = asyncio.get_running_loop() topic_partition = TopicPartition( topic=topic, partition=partition, offset=offset, ) - await loop.run_in_executor( - self._thread_pool, self.consumer.seek, topic_partition.to_confluent() - ) + await run_in_executor(self._thread_pool, self.consumer.seek, topic_partition.to_confluent()) def check_msg_error(msg: Optional[Message]) -> Optional[Message]: