Skip to content

Commit

Permalink
resolves #27 port field in creds object (#34)
Browse files Browse the repository at this point in the history
  • Loading branch information
wahani authored May 10, 2020
1 parent 39fa85a commit 1a31358
Show file tree
Hide file tree
Showing 6 changed files with 156 additions and 79 deletions.
9 changes: 9 additions & 0 deletions HISTORY.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,14 @@
- see #32 for bugfix in send_query for empty result sets

## Version 1.3.15
- dbrequests:
- see #27 for bugfix in Database class when specifiying a port in a
credentials object.
- the argument 'creds' in the init method of a database class is now
deprecated
- the argument 'db_url' can now handle str and dict type; str is a
sqlalchemy url; a dict a credentials object
- credential objects can now have additional fields which will be used as
elements in connect_args for sqlalchemies create_engine: see #12
- dbrequests.mysql
- see #36 for bugfix while sending an empty frame
112 changes: 72 additions & 40 deletions dbrequests/database.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,89 @@
import os
import warnings
from contextlib import contextmanager

from pandas import DataFrame
from sqlalchemy import create_engine, exc, inspect

from .connection import Connection as DefaultConnection
from .connection import Connection
from .query import Query


class Database(object):
"""A Database. Encapsulates a url and an SQLAlchemy engine with a pool of
connections.
The url to the database can be provided directly or via a credentials-
dictionary `creds` with keys:
- host
- db
- user
- password
- dialect (defaults to mysql)
- driver (defaults to pymysql)
"""
Provides useful methods to send and retrieve data as DataFrame. Manages
opening and closing connections.
- db_url: (str|None|dict):
- str: a sqlalchemy url
- None: using a sqlalchemy url from the environment variable
'DATABASE_URL'
- dict: a dict with credentials and connect_args
- user
- password
- host (defaults to 127.0.0.1)
- port (defaults to 3306)
- db
- dialect (defaults to mysql) -
- driver (defaults to pymysql)
- ...: further fields are added to 'connect_args'
- sql_dir: (str|None) directory where to look for sql queries. Defaults to
'.'.
- escape_percentage: (bool) escape percentages when reading queries from a
file.
- remove_comments: (bool) remove comments when reading queries from a file.
- kwargs:
- creds: (dict) deprecated, provide a dict as db_url
- ...: all arguments are passed to sqlalchemy.create_engine
"""

_connection_class = Connection

def __init__(self, db_url=None, creds=None, sql_dir=None, connection_class=DefaultConnection,
def __init__(self, db_url=None, sql_dir=None,
escape_percentage=False, remove_comments=False, **kwargs):
# If no db_url was provided, fallback to $DATABASE_URL or creds.
self.db_url = db_url or os.environ.get('DATABASE_URL')
self.sql_dir = sql_dir or os.getcwd()
if not self.db_url:
try:
user = creds['user']
password = creds['password']
host = creds['host']
db = creds['db']
dialect = creds.get('dialect', 'mysql')
driver = creds.get('driver', 'pymysql')
self.db_url = '{}+{}://{}:{}@{}/{}'.format(
dialect, driver, user, password, host, db)
except:
raise ValueError('You must provide a db_url or proper creds.')

self.sql_dir = sql_dir or os.getcwd()
self._escape_percentage = escape_percentage
self._remove_comments = remove_comments
self._engine = create_engine(self.db_url, **kwargs)
kwargs = self._init_db_url(db_url, **kwargs)
self._init_engine(**kwargs)
self._open = True
self.connection_class = connection_class

def _init_db_url(self, db_url, **kwargs):
if db_url is None:
db_url = os.environ.get('DATABASE_URL')
if db_url is None:
db_url = kwargs.pop('creds', None)
if db_url is not None:
warnings.warn(
"Parameter 'creds' is depreacated in favor of db_url.",
DeprecationWarning)
else:
raise ValueError('db_url is missing')
if isinstance(db_url, str):
self.db_url = db_url
elif isinstance(db_url, dict):
db_url = db_url.copy()
self.db_url = '{}+{}://{}:{}@{}:{}/{}'.format(
db_url.pop('dialect', 'mysql'),
db_url.pop('driver', 'pymysql'),
db_url.pop('user'),
db_url.pop('password'),
db_url.pop('host', '127.0.0.1'),
db_url.pop('port', 3306),
db_url.pop('db'))
connect_args = kwargs.pop('connect_args', {})
connect_args.update(db_url)
if len(connect_args):
kwargs['connect_args'] = connect_args
else:
raise ValueError('db_url has to be a str or dict')
return kwargs

def _init_engine(self, **kwargs):
# We have this method, so that subclasses may override the init
# process.
self._engine = create_engine(self.db_url, **kwargs)

def close(self):
"""Close the connection."""
Expand All @@ -62,18 +101,13 @@ def __repr__(self):

def get_table_names(self):
"""Returns a list of table names for the connected database."""

# Setup SQLAlchemy for Database inspection.
return inspect(self._engine).get_table_names()

def get_connection(self):
"""Get a connection to this Database. Connections are retrieved from a
pool.
"""
"""Get a connection from the sqlalchemy engine."""
if not self._open:
raise exc.ResourceClosedError('Database closed.')

return self.connection_class(self._engine.connect())
return self._connection_class(self._engine.connect())

def __get_query_text(self, query, escape_percentage, remove_comments, **params):
"""Private wrapper for accessing the text of the query."""
Expand Down Expand Up @@ -114,7 +148,7 @@ def send_bulk_query(self, query, escape_percentage=None, remove_comments=None, *
query, escape_percentage, remove_comments, **params)
return self.bulk_query(text, **params)

def send_data(self, df, table, mode='insert', **params):
def send_data(self, df: DataFrame, table, mode='insert', **params):
"""Sends data to table in database. If the table already exists, different modes of
insertion are provided.
Expand All @@ -127,8 +161,6 @@ def send_data(self, df, table, mode='insert', **params):
- 'replace': replaces duplicate primary keys
- 'update': updates duplicate primary keys
"""
if not isinstance(df, DataFrame):
raise TypeError('df has to be a pandas DataFrame.')
with self.transaction() as conn:
return conn.send_data(df, table, mode, **params)

Expand Down
56 changes: 23 additions & 33 deletions dbrequests/mysql/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,41 +26,37 @@ class Database(SuperDatabase):
the mysqldb driver, which can be 10x faster than the pymysql driver.
"""

def __init__(self, db_url=None, creds=None, sql_dir=None,
escape_percentage=False, remove_comments=False, **kwargs):
_connection_class = MysqlConnection

def _init_engine(self, **kwargs):
connect_args = kwargs.pop('connect_args', {})
# This option is needed for send data via csv: #20
connect_args['local_infile'] = connect_args.get('local_infile', 1)
# This option is needed for memory efficient send query: #22
# mysqldb can be difficult to install, so we also support
# pymysql. Depending on the driver we pick the apropriate cursorclass.
connect_args['cursorclass'] = connect_args.get(
'cursorclass', self._pick_cursorclass(db_url, creds))
super().__init__(db_url=db_url, creds=creds, sql_dir=sql_dir,
connection_class=MysqlConnection,
escape_percentage=escape_percentage,
remove_comments=remove_comments,
connect_args=connect_args, **kwargs)
'cursorclass', self._pick_cursorclass(self.db_url))
super()._init_engine(connect_args=connect_args, **kwargs)

def send_data(self, df, table, mode='insert', **params):
"""Sends df to table in database.
Args:
- df (DataFrame): internally we use datatable Frame. Any object
that can be converted to a Frame may be supplied.
- table_name (str): Name of the table.
- mode ({'insert', 'truncate', 'replace',
'update'}): Mode of Data Insertion. Defaults to 'insert'.
- 'insert': appends data. Duplicates in the
primary keys are not replaced.
- 'truncate': drop the table, recreate it, then insert. No
rollback on error.
- 'delete': delete all rows in the table, then insert. This
operation can be rolled back on error, but can be very
expensive.
- 'replace': replaces (delete, then insert) duplicate primary
keys.
- 'update': updates duplicate primary keys
- df (DataFrame): internally we use datatable Frame. Any object
that can be converted to a Frame may be supplied.
- table_name (str): Name of the table.
- mode ({'insert', 'truncate', 'replace',
'update'}): Mode of Data Insertion. Defaults to 'insert'.
- 'insert': appends data. Duplicates in the
primary keys are not replaced.
- 'truncate': drop the table, recreate it, then insert. No
rollback on error.
- 'delete': delete all rows in the table, then insert. This
operation can be rolled back on error, but can be very
expensive.
- 'replace': replaces (delete, then insert) duplicate primary
keys.
- 'update': insert but with update on duplicate primary keys
"""
if not isinstance(df, Frame):
df = Frame(df)
Expand All @@ -71,21 +67,15 @@ def send_data(self, df, table, mode='insert', **params):
return conn.send_data(df, table, mode, **params)

@staticmethod
def _pick_cursorclass(url, creds):
def _pick_cursorclass(url):
"""
Pick the SSCursor for the defined driver in url or creds.
Pick the SSCursor for the defined driver in url.
We can easily extract the driver from the sqlalchemy.engine. BUT: we
want to pass the cursorclass to the create_engine function and hence
need to extract it beforhand.
"""
if creds:
driver = creds.get('driver', 'pymysql')
elif url:
driver = re.findall(r'mysqldb|pymysql', url)[0]
else:
raise ValueError(
'Please provide either a valid db_url or creds object.')
driver = re.findall(r'mysqldb|pymysql', url)[0]
if driver == 'mysqldb':
from MySQLdb.cursors import SSCursor
else:
Expand Down
27 changes: 22 additions & 5 deletions dbrequests/mysql/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
CREDS = {
'user': 'root',
'password': 'root',
'host': '0.0.0.0:3307',
'host': '0.0.0.0',
'db': 'test',
'port': 3307
}
Expand All @@ -20,9 +20,28 @@
@pytest.yield_fixture(scope='module', params=['pymysql', 'mysqldb'])
def db(request):
"""Create instances of database connections."""
creds = CREDS
creds = CREDS.copy()
creds['driver'] = request.param
db = Database(creds=creds)
db = Database(creds)
try:
yield db
except BaseException as error:
raise error
finally:
db.close()


@pytest.fixture(scope="module")
def db_connect_args(request):
"""Create instance with connect args."""
creds = CREDS.copy()
creds['driver'] = 'pymysql'
# switch local_infile off so we can see that:
# - override of defaults work
# - connect_args is appended from creds object
# - we expect working send_query and failing send_data
creds['local_infile'] = 0
db = Database(creds)
try:
yield db
except BaseException as error:
Expand All @@ -33,7 +52,6 @@ def db(request):

def run_docker_container():
"""Run mariadb-docker container and return proper url for access."""

client = from_env()
try:
container = client.containers.run(
Expand All @@ -49,7 +67,6 @@ def run_docker_container():
time.sleep(60)
except APIError:
container = client.containers.get('test-mariadb-database')

return container


Expand Down
29 changes: 29 additions & 0 deletions dbrequests/mysql/tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""Testing connection bahaviours."""

import pandas as pd
import pytest
from dbrequests.mysql.tests.conftest import set_up_cats as reset
from sqlalchemy.exc import InternalError


@pytest.mark.usefixtures('db_connect_args')
class TestConnectionWithConnectArgs:
"""Test that passing on connect_args works for credentials."""

def test_send_query(self, db_connect_args):
"""Test that we have working connection."""
reset(db_connect_args)
res = db_connect_args.send_query('select 1 as x')
assert res.shape == (1, 1)

def test_send_data(self, db_connect_args):
"""The connection has local_infile set to 0, so we expect an error."""
df_add = pd.DataFrame({
'name': ['Chill'],
'owner': ['Alex'],
'birth': ['2018-03-03']
})

reset(db_connect_args)
with pytest.raises(InternalError):
db_connect_args.send_data(df_add, 'cats', mode='insert')
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def run(self):

requires = ['SQLAlchemy;python_version>="3.0"',
'pandas']
version = '1.3.14'
version = '1.3.15'


def read(f):
Expand Down

0 comments on commit 1a31358

Please sign in to comment.