From c2adb6685a65f48604202682febf39e336a68a7b Mon Sep 17 00:00:00 2001 From: cutz Date: Tue, 4 Feb 2020 13:14:49 -0600 Subject: [PATCH] add a gevent+psycopg2 dialect for sqlalchemy --- setup.py | 2 +- src/nti/app/environments/__init__.py | 3 ++ src/nti/app/environments/_monkey.py | 39 ++++++++++++++ src/nti/app/environments/tests/test_monkey.py | 51 +++++++++++++++++++ 4 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 src/nti/app/environments/_monkey.py create mode 100644 src/nti/app/environments/tests/test_monkey.py diff --git a/setup.py b/setup.py index 9c408038..d074fde5 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ 'zope.container', 'zope.site', 'zope.generations', - 'RelStorage==3.0b3', + 'RelStorage', 'zc.zlibstorage', 'nti.i18n', 'z3c.schema', diff --git a/src/nti/app/environments/__init__.py b/src/nti/app/environments/__init__.py index e011be4e..b8cf60fb 100644 --- a/src/nti/app/environments/__init__.py +++ b/src/nti/app/environments/__init__.py @@ -3,6 +3,9 @@ import gevent.monkey gevent.monkey.patch_all() +from nti.app.environments._monkey import patch +patch() + from zope.component import getGlobalSiteManager import zope.i18nmessageid as zope_i18nmessageid diff --git a/src/nti/app/environments/_monkey.py b/src/nti/app/environments/_monkey.py new file mode 100644 index 00000000..90879a5f --- /dev/null +++ b/src/nti/app/environments/_monkey.py @@ -0,0 +1,39 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +When RelStorage is installed with psycopg2 and gevent a global wait callback is installed. +That wait callback expects the Relstorage custom connection which creates issues +when trying to use psycopg2 with sqlalchmey. + +Install a custom dialect that uses the RelStorage connection for sqlalchemy. +""" + +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import + +logger = __import__('logging').getLogger(__name__) + +try: + from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2 + from relstorage.adapters.postgresql.drivers.psycopg2 import GeventPsycopg2Driver + import psycopg2 +except ImportError: + pass +else: + class geventPostgresclient_dialect(PGDialect_psycopg2): + driver = "gevent+postgres" + + def __init__(self, *args, **kwargs): + super(geventPostgresclient_dialect, self).__init__(*args, **kwargs) + self._conn_class = GeventPsycopg2Driver().connect + + def connect(self, *args, **kwargs): + kwargs['connection_factory'] = self._conn_class + return psycopg2.connect(*args, **kwargs) + + from sqlalchemy.dialects import registry + registry.register("gevent.postgres", __name__, "geventPostgresclient_dialect") + +def patch(): + pass diff --git a/src/nti/app/environments/tests/test_monkey.py b/src/nti/app/environments/tests/test_monkey.py new file mode 100644 index 00000000..466aaabf --- /dev/null +++ b/src/nti/app/environments/tests/test_monkey.py @@ -0,0 +1,51 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import + +import fudge + +from hamcrest import not_none +from hamcrest import assert_that +from hamcrest import instance_of + +import unittest + +from sqlalchemy import create_engine + +from .._monkey import geventPostgresclient_dialect + + +class TestPatchSqlalchemy(unittest.TestCase): + + def test_postres_engine(self): + from .._monkey import patch + patch() + engine = create_engine('gevent+postgres:///testdb.db') + assert_that(engine, not_none()) + assert_that(engine.dialect, instance_of(geventPostgresclient_dialect)) + + @fudge.patch('psycopg2.connect') + def test_connect_uses_relstorage(self, pconn): + from .._monkey import patch + patch() + engine = create_engine('gevent+postgres:///testdb.db') + + class StopExecution(Exception): + pass + + # Make sure connect get gets called with the right connection_factory + # Make calling it raise to stop all the crazy first connection initialization that sqlalchemy does + # no luck trying to mock all the necessary things + pconn.expects_call().with_args(database='testdb.db', connection_factory=engine.dialect._conn_class).raises(StopExecution) + + try: + engine.connect() + except StopExecution: + pass + + + +