
feat(python): support read_database options passthrough to the underlying connection's `execute` method (enables parameterised SQL queries, etc) (#11562)
alexander-beedie authored Oct 7, 2023
1 parent 4d0e54a commit c01d599
Showing 2 changed files with 200 additions and 85 deletions.
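In practical terms, the new `execute_options` parameter is forwarded as keyword arguments to the underlying connection's `execute` method, which is what makes parameterised queries possible. A minimal sketch of the new call pattern (not part of the diff below; the SQLite file `test.db` and its `test_data` table are illustrative, mirroring the docstring examples added in this commit):

import polars as pl
from sqlalchemy import create_engine

# Illustrative setup: a local SQLite database containing a `test_data` table.
engine = create_engine("sqlite:///test.db")

with engine.connect() as conn:
    df = pl.read_database(
        query="SELECT * FROM test_data WHERE metric > :value",
        connection=conn,
        # Forwarded verbatim as kwargs to the connection's `execute` method;
        # here it binds the SQLAlchemy-style query parameter `:value`.
        execute_options={"parameters": {"value": 0}},
    )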
58 changes: 46 additions & 12 deletions py-polars/polars/io/database.py
@@ -8,7 +8,6 @@
from polars.convert import from_arrow
from polars.exceptions import UnsuitableSQLError
from polars.utils.deprecation import (
deprecate_nonkeyword_arguments,
deprecate_renamed_parameter,
issue_deprecation_warning,
)
@@ -94,13 +93,15 @@ class ODBCCursorProxy:

def __init__(self, connection_string: str) -> None:
self.connection_string = connection_string
self.execute_options: dict[str, Any] = {}
self.query: str | None = None

def close(self) -> None:
"""Close the cursor (n/a: nothing to close)."""

def execute(self, query: str) -> None:
def execute(self, query: str, **execute_options: Any) -> None:
"""Execute a query (n/a: just store query for the fetch* methods)."""
self.execute_options = execute_options
self.query = query

def fetchmany(
@@ -113,8 +114,10 @@ def fetchmany(
query=self.query,
batch_size=batch_size,
connection_string=self.connection_string,
**self.execute_options,
)

# internally arrow-odbc always reads batches
fetchall = fetchmany


@@ -257,9 +260,12 @@ def _from_rows(
)
return None

@deprecate_nonkeyword_arguments(allowed_args=["self", "query"], version="0.19.3")
def execute(
self, query: str | Selectable, select_queries_only: bool = True # noqa: FBT001
self,
query: str | Selectable,
*,
options: dict[str, Any] | None = None,
select_queries_only: bool = True,
) -> Self:
"""Execute a query and reference the result set."""
if select_queries_only and isinstance(query, str):
@@ -274,7 +280,7 @@ def execute(

query = text(query) # type: ignore[assignment]

if (result := self.cursor.execute(query)) is None:
if (result := self.cursor.execute(query, **(options or {}))) is None:
result = self.cursor # some cursors execute in-place

self.result = result
@@ -306,12 +312,13 @@ def to_frame(


@deprecate_renamed_parameter("connection_uri", "connection", version="0.18.9")
def read_database(  # noqa: D417
query: str | Selectable,
connection: ConnectionOrCursor | str,
*,
batch_size: int | None = None,
schema_overrides: SchemaDict | None = None,
execute_options: dict[str, Any] | None = None,
**kwargs: Any,
) -> DataFrame:
"""
@@ -324,8 +331,9 @@ def read_database( # noqa: D417
be a suitable "Selectable", otherwise it is expected to be a string).
connection
An instantiated connection (or cursor/client object) that the query can be
executed against. Can also pass a valid ODBC connection string here, if you
have installed the ``arrow-odbc`` driver/package.
executed against. Can also pass a valid ODBC connection string, starting with
"Driver=", in which case the ``arrow-odbc`` package will be used to establish
the connection and return Arrow-native data to Polars.
batch_size
Enable batched data fetching (internally) instead of collecting all rows at
once; this can be helpful for minimising the peak memory used for very large
@@ -341,6 +349,11 @@ def read_database( # noqa: D417
on driver/backend). This can be useful if the given types can be more precisely
defined (for example, if you know that a given column can be declared as `u32`
instead of `i64`).
execute_options
These options will be passed through into the underlying query execution method
as kwargs. In the case of connections made using an ODBC string (which use
`arrow-odbc`) these options are passed to the ``read_arrow_batches_from_odbc``
method.
Notes
-----
@@ -356,7 +369,7 @@ def read_database( # noqa: D417
more details about using this driver (notable databases implementing Flight SQL
include Dremio and InfluxDB).
* The ``read_connection_uri`` function is likely to be noticeably faster than
* The ``read_database_uri`` function is likely to be noticeably faster than
``read_database`` if you are using a SQLAlchemy or DBAPI2 connection, as
``connectorx`` will optimise translation of the result set into Arrow format
in Rust, whereas these libraries will return row-wise data to Python *before*
@@ -381,15 +394,26 @@ def read_database( # noqa: D417
... schema_overrides={"normalised_score": pl.UInt8},
... ) # doctest: +SKIP
Instantiate a DataFrame using an ODBC connection string (requires ``arrow-odbc``):
Use a parameterised SQLAlchemy query, passing values via ``execute_options``:
>>> df = pl.read_database(
... query="SELECT * FROM test_data WHERE metric > :value",
... connection=alchemy_conn,
... execute_options={"parameters": {"value": 0}},
... ) # doctest: +SKIP
Instantiate a DataFrame using an ODBC connection string (requires ``arrow-odbc``)
and set upper limits on the buffer size of variadic text/binary columns:
>>> df = pl.read_database(
... query="SELECT * FROM test_data",
... connection="Driver={PostgreSQL};Server=localhost;Port=5432;Database=test;Uid=usr;Pwd=",
... execute_options={"max_text_size": 512, "max_binary_size": 1024},
... ) # doctest: +SKIP
""" # noqa: W505
if isinstance(connection, str):
# check for odbc connection string
if re.sub(r"\s", "", connection[:20]).lower().startswith("driver="):
try:
import arrow_odbc # noqa: F401
@@ -401,6 +425,7 @@ def read_database( # noqa: D417

connection = ODBCCursorProxy(connection)
else:
# otherwise looks like a call to read_database_uri
issue_deprecation_warning(
message="Use of a string URI with 'read_database' is deprecated; use 'read_database_uri' instead",
version="0.19.0",
@@ -410,16 +435,25 @@ def read_database( # noqa: D417
f"`read_database_uri` expects one or more string queries; found {type(query)}"
)
return read_database_uri(
query, uri=connection, schema_overrides=schema_overrides, **kwargs
query,
uri=connection,
schema_overrides=schema_overrides,
**kwargs,
)

# note: can remove this check (and **kwargs) once we drop the
# pass-through deprecation support for read_database_uri
if kwargs:
raise ValueError(
f"`read_database` **kwargs only exist for passthrough to `read_database_uri`: found {kwargs!r}"
)

# return frame from arbitrary connections using the executor abstraction
with ConnectionExecutor(connection) as cx:
return cx.execute(query).to_frame(
return cx.execute(
query=query,
options=execute_options,
).to_frame(
batch_size=batch_size,
schema_overrides=schema_overrides,
)

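When the connection is given as an ODBC connection string, the same `execute_options` dict is not passed to a DBAPI cursor; the `ODBCCursorProxy` in the diff above stores it and forwards it to `arrow_odbc.read_arrow_batches_from_odbc` at fetch time. A hedged sketch mirroring the new docstring example (requires the `arrow-odbc` package and a working PostgreSQL ODBC driver; the connection string and table name are placeholders):

import polars as pl

# The options below cap the buffer sizes that arrow-odbc allocates for
# variadic text/binary columns, as shown in the docstring example above.
df = pl.read_database(
    query="SELECT * FROM test_data",
    connection="Driver={PostgreSQL};Server=localhost;Port=5432;Database=test;Uid=usr;Pwd=",
    execute_options={"max_text_size": 512, "max_binary_size": 1024},
)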
