Skip to content

Commit

Permalink
Add support for psycopg and asyncpg drivers
Browse files Browse the repository at this point in the history
This introduces the `crate+psycopg://`, `crate+asyncpg://`, and
`crate+urllib3://` dialect identifiers. The asynchronous variant of
`psycopg` is also supported.
  • Loading branch information
amotl committed Nov 4, 2024
1 parent af10a1c commit dd73fd7
Show file tree
Hide file tree
Showing 6 changed files with 273 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Changelog

## Unreleased
- Added support for `psycopg` and `asyncpg` drivers, by introducing the
`crate+psycopg://`, `crate+asyncpg://`, and `crate+urllib3://` dialect
identifiers. The asynchronous variant of `psycopg` is also supported.

## 2024/11/04 0.40.1
- CI: Verified support on Python 3.13
Expand Down
12 changes: 10 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ dependencies = [
"verlib2==0.2",
]
optional-dependencies.all = [
"sqlalchemy-cratedb[vector]",
"sqlalchemy-cratedb[postgresql,vector]",
]
optional-dependencies.develop = [
"mypy<1.14",
Expand All @@ -102,6 +102,9 @@ optional-dependencies.doc = [
"crate-docs-theme>=0.26.5",
"sphinx>=3.5,<9",
]
optional-dependencies.postgresql = [
"sqlalchemy-postgresql-relaxed",
]
optional-dependencies.release = [
"build<2",
"twine<6",
Expand All @@ -112,6 +115,7 @@ optional-dependencies.test = [
"pandas<2.3",
"pueblo>=0.0.7",
"pytest<9",
"pytest-asyncio<0.24",
"pytest-cov<7",
"pytest-mock<4",
]
Expand All @@ -122,7 +126,11 @@ urls.changelog = "https://github.com/crate/sqlalchemy-cratedb/blob/main/CHANGES.
urls.documentation = "https://cratedb.com/docs/sqlalchemy-cratedb/"
urls.homepage = "https://cratedb.com/docs/sqlalchemy-cratedb/"
urls.repository = "https://github.com/crate/sqlalchemy-cratedb"
entry-points."sqlalchemy.dialects".crate = "sqlalchemy_cratedb:dialect"
entry-points."sqlalchemy.dialects"."crate" = "sqlalchemy_cratedb:dialect"
entry-points."sqlalchemy.dialects"."crate.asyncpg" = "sqlalchemy_cratedb.dialect_more:dialect_asyncpg"
entry-points."sqlalchemy.dialects"."crate.psycopg" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg"
entry-points."sqlalchemy.dialects"."crate.psycopg_async" = "sqlalchemy_cratedb.dialect_more:dialect_psycopg_async"
entry-points."sqlalchemy.dialects"."crate.urllib3" = "sqlalchemy_cratedb.dialect_more:dialect_urllib3"

[tool.black]
line-length = 100
Expand Down
50 changes: 43 additions & 7 deletions src/sqlalchemy_cratedb/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import logging
from datetime import date, datetime
from types import ModuleType

from sqlalchemy import types as sqltypes
from sqlalchemy.engine import default, reflection
Expand Down Expand Up @@ -212,6 +213,12 @@ def initialize(self, connection):
# get default schema name
self.default_schema_name = self._get_default_schema_name(connection)

def set_isolation_level(self, dbapi_connection, level):
"""
For CrateDB, this is implemented as a noop.
"""
pass

def do_rollback(self, connection):
# if any exception is raised by the dbapi, sqlalchemy by default
# attempts to do a rollback crate doesn't support rollbacks.
Expand All @@ -230,7 +237,21 @@ def connect(self, host=None, port=None, *args, **kwargs):
use_ssl = asbool(kwargs.pop("ssl", False))
if use_ssl:
servers = ["https://" + server for server in servers]
return self.dbapi.connect(servers=servers, **kwargs)

is_module = isinstance(self.dbapi, ModuleType)
if is_module:
driver_name = self.dbapi.__name__
else:
driver_name = self.dbapi.__class__.__name__
if driver_name == "crate.client":
if "database" in kwargs:
del kwargs["database"]

Check warning on line 248 in src/sqlalchemy_cratedb/dialect.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect.py#L248

Added line #L248 was not covered by tests
return self.dbapi.connect(servers=servers, **kwargs)
elif driver_name in ["psycopg", "PsycopgAdaptDBAPI", "AsyncAdapt_asyncpg_dbapi"]:
return self.dbapi.connect(host=host, port=port, **kwargs)
else:
raise ValueError(f"Unknown driver variant: {driver_name}")

Check warning on line 253 in src/sqlalchemy_cratedb/dialect.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect.py#L253

Added line #L253 was not covered by tests

return self.dbapi.connect(**kwargs)

def do_execute(self, cursor, statement, parameters, context=None):
Expand Down Expand Up @@ -300,10 +321,12 @@ def get_table_names(self, connection, schema=None, **kw):
if schema is None:
schema = self._get_effective_schema_name(connection)
cursor = connection.exec_driver_sql(
"SELECT table_name FROM information_schema.tables "
"WHERE {0} = ? "
"AND table_type = 'BASE TABLE' "
"ORDER BY table_name ASC, {0} ASC".format(self.schema_column),
self._format_query(
"SELECT table_name FROM information_schema.tables "
"WHERE {0} = ? "
"AND table_type = 'BASE TABLE' "
"ORDER BY table_name ASC, {0} ASC"
).format(self.schema_column),
(schema or self.default_schema_name,),
)
return [row[0] for row in cursor.fetchall()]
Expand All @@ -326,7 +349,7 @@ def get_columns(self, connection, table_name, schema=None, **kw):
"AND column_name !~ ?".format(self.schema_column)
)
cursor = connection.exec_driver_sql(
query,
self._format_query(query),
(
table_name,
schema or self.default_schema_name,
Expand Down Expand Up @@ -366,7 +389,9 @@ def result_fun(result):
rows = result.fetchone()
return set(rows[0] if rows else [])

pk_result = engine.exec_driver_sql(query, (table_name, schema or self.default_schema_name))
pk_result = engine.exec_driver_sql(
self._format_query(query), (table_name, schema or self.default_schema_name)
)
pks = result_fun(pk_result)
return {"constrained_columns": sorted(pks), "name": "PRIMARY KEY"}

Expand Down Expand Up @@ -405,6 +430,17 @@ def has_ilike_operator(self):
server_version_info = self.server_version_info
return server_version_info is not None and server_version_info >= (4, 1, 0)

def _format_query(self, query):
"""
When using the PostgreSQL protocol with drivers `psycopg` or `asyncpg`,
the paramstyle is not `qmark`, but `pyformat`.
TODO: Review: Is it legit and sane? Are there alternatives?
"""
if self.paramstyle == "pyformat":
query = query.replace("= ?", "= %s").replace("!~ ?", "!~ %s")

Check warning on line 441 in src/sqlalchemy_cratedb/dialect.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect.py#L441

Added line #L441 was not covered by tests
return query


class DateTrunc(functions.GenericFunction):
name = "date_trunc"
Expand Down
106 changes: 106 additions & 0 deletions src/sqlalchemy_cratedb/dialect_more.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# -*- coding: utf-8; -*-
#
# Licensed to CRATE Technology GmbH ("Crate") under one or more contributor
# license agreements. See the NOTICE file distributed with this work for
# additional information regarding copyright ownership. Crate licenses
# this file to you under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. You may
# obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
#
# However, if you have executed another commercial license agreement
# with Crate these terms will supersede the license and you may use the
# software solely pursuant to the terms of the relevant commercial agreement.
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy_postgresql_relaxed.asyncpg import PGDialect_asyncpg_relaxed
from sqlalchemy_postgresql_relaxed.base import PGDialect_relaxed
from sqlalchemy_postgresql_relaxed.psycopg import (
PGDialect_psycopg_relaxed,
PGDialectAsync_psycopg_relaxed,
)

from sqlalchemy_cratedb import dialect


class CrateDialectPostgresAdapter(PGDialect_relaxed, dialect):
"""
Provide a dialect on top of the relaxed PostgreSQL dialect.
"""

inspector = Inspector

# Need to manually override some methods because of polymorphic inheritance woes.
# TODO: Investigate if this can be solved using metaprogramming or other techniques.
has_schema = dialect.has_schema
has_table = dialect.has_table
get_schema_names = dialect.get_schema_names
get_table_names = dialect.get_table_names
get_view_names = dialect.get_view_names
get_columns = dialect.get_columns
get_pk_constraint = dialect.get_pk_constraint
get_foreign_keys = dialect.get_foreign_keys
get_indexes = dialect.get_indexes

get_multi_columns = dialect.get_multi_columns
get_multi_pk_constraint = dialect.get_multi_pk_constraint
get_multi_foreign_keys = dialect.get_multi_foreign_keys

# TODO: Those may want to go to dialect instead?
def get_multi_indexes(self, *args, **kwargs):
return []

Check warning on line 57 in src/sqlalchemy_cratedb/dialect_more.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect_more.py#L57

Added line #L57 was not covered by tests

def get_multi_unique_constraints(self, *args, **kwargs):
return []

Check warning on line 60 in src/sqlalchemy_cratedb/dialect_more.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect_more.py#L60

Added line #L60 was not covered by tests

def get_multi_check_constraints(self, *args, **kwargs):
return []

Check warning on line 63 in src/sqlalchemy_cratedb/dialect_more.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect_more.py#L63

Added line #L63 was not covered by tests

def get_multi_table_comment(self, *args, **kwargs):
return []

Check warning on line 66 in src/sqlalchemy_cratedb/dialect_more.py

View check run for this annotation

Codecov / codecov/patch

src/sqlalchemy_cratedb/dialect_more.py#L66

Added line #L66 was not covered by tests


class CrateDialect_psycopg(PGDialect_psycopg_relaxed, CrateDialectPostgresAdapter):
driver = "psycopg"

@classmethod
def get_async_dialect_cls(cls, url):
return CrateDialectAsync_psycopg

@classmethod
def import_dbapi(cls):
import psycopg

return psycopg


class CrateDialectAsync_psycopg(PGDialectAsync_psycopg_relaxed, CrateDialectPostgresAdapter):
driver = "psycopg_async"
is_async = True


class CrateDialect_asyncpg(PGDialect_asyncpg_relaxed, CrateDialectPostgresAdapter):
driver = "asyncpg"

# TODO: asyncpg may have `paramstyle="numeric_dollar"`. Review this!

# TODO: AttributeError: module 'asyncpg' has no attribute 'paramstyle'
"""
@classmethod
def import_dbapi(cls):
import asyncpg
return asyncpg
"""


dialect_urllib3 = dialect
dialect_psycopg = CrateDialect_psycopg
dialect_psycopg_async = CrateDialectAsync_psycopg
dialect_asyncpg = CrateDialect_asyncpg
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ def cratedb_service():
Provide a CrateDB service instance to the test suite.
"""
db = CrateDBTestAdapter()
db.start()
db.start(ports={4200: None, 5432: None})
yield db
db.stop()
110 changes: 110 additions & 0 deletions tests/engine_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pytest
import sqlalchemy as sa
from sqlalchemy.dialects import registry as dialect_registry

from sqlalchemy_cratedb.sa_version import SA_2_0, SA_VERSION

if SA_VERSION < SA_2_0:
raise pytest.skip("Only supported on SQLAlchemy 2.0 and higher", allow_module_level=True)

from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

# Registering the additional dialects manually seems to be needed when running
# under tests. Apparently, manual registration is not needed under regular
# circumstances, as this is wired through the `sqlalchemy.dialects` entrypoint
# registrations in `pyproject.toml`. It is definitively weird, but c'est la vie.
dialect_registry.register("crate.urllib3", "sqlalchemy_cratedb.dialect_more", "dialect_urllib3")
dialect_registry.register("crate.asyncpg", "sqlalchemy_cratedb.dialect_more", "dialect_asyncpg")
dialect_registry.register("crate.psycopg", "sqlalchemy_cratedb.dialect_more", "dialect_psycopg")


QUERY = sa.text("SELECT mountain, coordinates FROM sys.summits ORDER BY mountain LIMIT 3;")


def test_engine_sync_vanilla(cratedb_service):
"""
crate:// -- Verify connectivity and data transport with vanilla HTTP-based driver.
"""
port4200 = cratedb_service.cratedb.get_exposed_port(4200)
engine = sa.create_engine(f"crate://crate@localhost:{port4200}/", echo=True)
assert isinstance(engine, sa.engine.Engine)
with engine.connect() as connection:
result = connection.execute(QUERY)
assert result.mappings().fetchone() == {
"mountain": "Acherkogel",
"coordinates": [10.95667, 47.18917],
}


def test_engine_sync_urllib3(cratedb_service):
"""
crate+urllib3:// -- Verify connectivity and data transport *explicitly* selecting the HTTP driver.
""" # noqa: E501
port4200 = cratedb_service.cratedb.get_exposed_port(4200)
engine = sa.create_engine(
f"crate+urllib3://crate@localhost:{port4200}/", isolation_level="AUTOCOMMIT", echo=True
)
assert isinstance(engine, sa.engine.Engine)
with engine.connect() as connection:
result = connection.execute(QUERY)
assert result.mappings().fetchone() == {
"mountain": "Acherkogel",
"coordinates": [10.95667, 47.18917],
}


def test_engine_sync_psycopg(cratedb_service):
"""
crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3).
"""
port5432 = cratedb_service.cratedb.get_exposed_port(5432)
engine = sa.create_engine(
f"crate+psycopg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True
)
assert isinstance(engine, sa.engine.Engine)
with engine.connect() as connection:
result = connection.execute(QUERY)
assert result.mappings().fetchone() == {
"mountain": "Acherkogel",
"coordinates": "(10.95667,47.18917)",
}


@pytest.mark.asyncio
async def test_engine_async_psycopg(cratedb_service):
"""
crate+psycopg:// -- Verify connectivity and data transport using the psycopg driver (version 3).
This time, in asynchronous mode.
"""
port5432 = cratedb_service.cratedb.get_exposed_port(5432)
engine = create_async_engine(
f"crate+psycopg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True
)
assert isinstance(engine, AsyncEngine)
async with engine.begin() as conn:
result = await conn.execute(QUERY)
assert result.mappings().fetchone() == {
"mountain": "Acherkogel",
"coordinates": "(10.95667,47.18917)",
}


@pytest.mark.asyncio
async def test_engine_async_asyncpg(cratedb_service):
"""
crate+asyncpg:// -- Verify connectivity and data transport using the asyncpg driver.
This exclusively uses asynchronous mode.
"""
port5432 = cratedb_service.cratedb.get_exposed_port(5432)
from asyncpg.pgproto.types import Point

engine = create_async_engine(
f"crate+asyncpg://crate@localhost:{port5432}/", isolation_level="AUTOCOMMIT", echo=True
)
assert isinstance(engine, AsyncEngine)
async with engine.begin() as conn:
result = await conn.execute(QUERY)
assert result.mappings().fetchone() == {
"mountain": "Acherkogel",
"coordinates": Point(10.95667, 47.18917),
}

0 comments on commit dd73fd7

Please sign in to comment.