Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(python): Replacing nest-asyncio with greenlet in database dependencies #17811

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 5 additions & 19 deletions py-polars/polars/io/database/_cursor_proxies.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from typing import TYPE_CHECKING, Any, Iterable

from polars.io.database._utils import _run_async
from polars.io.database._utils import _read_surreal_query_sync

if TYPE_CHECKING:
import sys
from collections.abc import Coroutine

import pyarrow as pa

Expand Down Expand Up @@ -80,16 +79,6 @@ def __init__(self, client: Any) -> None:
self.execute_options: dict[str, Any] = {}
self.query: str = None # type: ignore[assignment]

@staticmethod
async def _unpack_result(
result: Coroutine[Any, Any, list[dict[str, Any]]],
) -> Coroutine[Any, Any, list[dict[str, Any]]]:
"""Unpack the async query result."""
response = (await result)[0]
if response["status"] != "OK":
raise RuntimeError(response["result"])
return response["result"]

def close(self) -> None:
"""Close the cursor."""
# no-op; never close a user's Surreal session
Expand All @@ -103,13 +92,10 @@ def execute(self, query: str, **execute_options: Any) -> Self:

def fetchall(self) -> list[dict[str, Any]]:
"""Fetch all results (as a list of dictionaries)."""
return _run_async(
self._unpack_result(
result=self.client.query(
sql=self.query,
vars=(self.execute_options or None),
),
)
return _read_surreal_query_sync(
client=self.client,
query=self.query,
vars=(self.execute_options or None),
)

def fetchmany(self, size: int) -> list[dict[str, Any]]:
Expand Down
103 changes: 96 additions & 7 deletions py-polars/polars/io/database/_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import re
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypeVar

from polars.convert import from_arrow
from polars.dependencies import import_optional
Expand All @@ -11,9 +11,17 @@
from collections.abc import Coroutine

if sys.version_info >= (3, 10):
from typing import TypeAlias
from typing import ParamSpec, TypeAlias
else:
from typing_extensions import TypeAlias
from typing_extensions import ParamSpec, TypeAlias

if sys.version_info >= (3, 9):
from collections.abc import Callable
else:
from typing import Callable
from concurrent.futures import Future

import greenlet

from polars import DataFrame
from polars._typing import SchemaDict
Expand All @@ -23,10 +31,57 @@
except ImportError:
Selectable: TypeAlias = Any # type: ignore[no-redef]

P = ParamSpec("P")
T_co = TypeVar("T_co", covariant=True)


def _check_is_sa_greenlet(green: greenlet.greenlet) -> bool:
return getattr(green, "__sqlalchemy_greenlet_provider__", False)


def _greenlet_wait(co: Coroutine[Any, Any, T_co]) -> T_co:
"""Compatible with sqlalchemy."""
from polars.dependencies import import_optional

if TYPE_CHECKING:
from sqlalchemy import util as sa_util
else:
sa_util = import_optional("sqlalchemy.util")

return sa_util.await_only(co)


def _run_asyncio_func(
func: Callable[P, Coroutine[Any, Any, T_co]], *args: P.args, **kwargs: P.kwargs
) -> T_co:
import asyncio

return asyncio.run(func(*args, **kwargs))


def _run_async_func(
func: Callable[P, Coroutine[Any, Any, T_co]], *args: P.args, **kwargs: P.kwargs
) -> T_co:
"""Run asynchronous func as if it was synchronous."""
from concurrent.futures import ThreadPoolExecutor, wait

from polars._utils.unstable import issue_unstable_warning

def _run_async(co: Coroutine[Any, Any, Any]) -> Any:
issue_unstable_warning(
"Use of asynchronous connections is currently considered unstable "
"and unexpected issues may arise; if this happens, please report them."
)

with ThreadPoolExecutor(1) as executor:
future = executor.submit(_run_asyncio_func, func, *args, **kwargs) # type: ignore[arg-type]
wait([future], return_when="ALL_COMPLETED")
return future.result()


def _run_async(co: Coroutine[Any, Any, T_co]) -> T_co:
"""Run asynchronous code as if it was synchronous."""
import asyncio
from concurrent.futures import ThreadPoolExecutor, wait

from polars._utils.unstable import issue_unstable_warning
from polars.dependencies import import_optional
Expand All @@ -35,9 +90,20 @@ def _run_async(co: Coroutine[Any, Any, Any]) -> Any:
"Use of asynchronous connections is currently considered unstable "
"and unexpected issues may arise; if this happens, please report them."
)
nest_asyncio = import_optional("nest_asyncio")
nest_asyncio.apply()
return asyncio.run(co)

if TYPE_CHECKING:
import greenlet
else:
greenlet = import_optional("greenlet")

current = greenlet.getcurrent()
if _check_is_sa_greenlet(current):
return _greenlet_wait(co)

with ThreadPoolExecutor(1) as executor:
future: Future[T_co] = executor.submit(asyncio.run, co)
wait([future], return_when="ALL_COMPLETED")
return future.result()


def _read_sql_connectorx(
Expand Down Expand Up @@ -80,6 +146,29 @@ def _read_sql_adbc(
return from_arrow(tbl, schema_overrides=schema_overrides) # type: ignore[return-value]


def _unpack_surreal_result(
result: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Unpack the async query result."""
response = result[0]
if response["status"] != "OK":
raise RuntimeError(response["result"])
return response["result"]


async def _read_surreal_query_async(
client: Any, query: str, vars: dict[str, Any] | None
) -> list[dict[str, Any]]:
fetch = await client.query(sql=query, vars=vars)
return _unpack_surreal_result(fetch)


def _read_surreal_query_sync(
client: Any, query: str, vars: dict[str, Any] | None
) -> list[dict[str, Any]]:
return _run_async_func(_read_surreal_query_async, client, query, vars)


def _open_adbc_connection(connection_uri: str) -> Any:
driver_name = connection_uri.split(":", 1)[0].lower()

Expand Down
4 changes: 2 additions & 2 deletions py-polars/polars/meta/versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def show_versions() -> None:
gevent: 24.2.1
hvplot: 0.9.2
matplotlib: 3.8.4
nest_asyncio: 1.6.0
greenlet: 3.0.3
numpy: 1.26.4
openpyxl: 3.1.2
pandas: 2.2.2
Expand Down Expand Up @@ -72,7 +72,7 @@ def _get_dependency_info() -> dict[str, str]:
"great_tables",
"hvplot",
"matplotlib",
"nest_asyncio",
"greenlet",
"numpy",
"openpyxl",
"pandas",
Expand Down
5 changes: 2 additions & 3 deletions py-polars/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ excel = ["polars[calamine,openpyxl,xlsx2csv,xlsxwriter]"]
# Database
adbc = ["adbc-driver-manager[dbapi]", "adbc-driver-sqlite[dbapi]"]
connectorx = ["connectorx >= 0.3.2"]
sqlalchemy = ["sqlalchemy", "polars[pandas]"]
database = ["polars[adbc,connectorx,sqlalchemy]", "nest-asyncio"]
sqlalchemy = ["sqlalchemy[asyncio]", "polars[pandas]"]
database = ["polars[adbc,connectorx,sqlalchemy]", "greenlet"]

# Cloud
fsspec = ["fsspec"]
Expand Down Expand Up @@ -115,7 +115,6 @@ module = [
"kuzu",
"matplotlib.*",
"moto.server",
"nest_asyncio",
"openpyxl",
"polars.polars",
"pyarrow.*",
Expand Down
5 changes: 3 additions & 2 deletions py-polars/requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ numba
backports.zoneinfo; python_version < '3.9'
tzdata; platform_system == 'Windows'
# Database
sqlalchemy
sqlalchemy[asyncio]
adbc-driver-manager; python_version >= '3.9' and platform_system != 'Windows'
adbc-driver-sqlite; python_version >= '3.9' and platform_system != 'Windows'
aiosqlite
connectorx
kuzu
nest-asyncio
greenlet
# Cloud
cloudpickle
fsspec
Expand Down Expand Up @@ -74,3 +74,4 @@ flask-cors
# Stub files
pandas-stubs
boto3-stubs
types-greenlet