Skip to content

Commit

Permalink
Add SQL utils
Browse files Browse the repository at this point in the history
  • Loading branch information
t3eHawk committed Dec 20, 2022
1 parent 323c123 commit d73d1a1
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 33 deletions.
45 changes: 16 additions & 29 deletions pydin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .utils import to_sql
from .utils import coalesce
from .utils import read_file_or_string
from .utils import sql_formatter, sql_compiler, sql_converter

from .core import Node

Expand Down Expand Up @@ -112,13 +113,13 @@ def target_date(self):
def date_from(self):
"""The beggining of the target date of this model."""
if hasattr(self, 'date_field'):
return self.converter(self.target_date.start)
return self.target_date.start

@property
def date_to(self):
"""The end of the target date of this model."""
if hasattr(self, 'date_field'):
return self.converter(self.target_date.end)
return self.target_date.end

@property
def value_field(self):
Expand Down Expand Up @@ -210,13 +211,6 @@ def format_custom_query(self, value):
else:
return value

def converter(self, value):
"""Convert the given value into the appropriate model format."""
if value and hasattr(self, '_convert'):
return self._convert(value)
else:
return value

def explain(self, parameter_name=None):
"""Get model or chosen parameter description."""
if not parameter_name:
Expand Down Expand Up @@ -986,26 +980,30 @@ def assemble(self):
table = self.get_table()
insert = table.insert()

select = self.select
columns = [sa.column(column) for column in self.describe()]
select = sa.text(self.select).columns(*columns).alias('s')

if self.key_field:
columns = [*columns, self.key_field.column]
select = sa.select(columns).select_from(select)
expression = f'{self.key_field.value} as {self.key_field.label}'
select = sql_formatter(select, expand_select=expression)

if self.date_field:
date_column = sa.column(self.date_field)
between = sa.between(date_column, self.date_from, self.date_to)
select = select.where(between)
date_from = sql_converter(self.date_from, self.db)
date_to = sql_converter(self.date_to, self.db)
between = f'{self.date_field} between {date_from} and {date_to}'
select = sql_formatter(select, expand_where=between)

if self.value_field:
last_value = self.get_last_value()
value_column = sa.column(self.value_field)
select = select.where(value_column > last_value)
if last_value:
comparison = f'{self.value_field} > {last_value}'
select = sql_formatter(select, make_subquery=True)
select = sql_formatter(select, expand_where=comparison)

select = sa.text(f'\n{select}').columns(*columns)
query = insert.from_select(columns, select)
query = query.compile(compile_kwargs=self.db.ckwargs)
query = str(query)
query = sql_compiler(query, self.db)
return query

def parse(self):
Expand Down Expand Up @@ -1072,17 +1070,6 @@ def endlog(self, output_rows=None, output_text=None,
error_text=error_text)
pass

def _convert(self, value):
if isinstance(value, dt.datetime):
if self.db.vendor == 'oracle':
string = f'{value:%Y-%m-%d %H:%M:%S}'
fmt = 'yyyy-mm-dd hh24:mi:ss'
return sa.func.to_date(string, fmt)
else:
return value
else:
return value

def _format(self, text):
text = text.format(task=self.task)
return text
Expand Down
107 changes: 103 additions & 4 deletions pydin/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import subprocess as sp
import importlib as imp

import sqlparse
import sqlalchemy as sa
import sqlparse as spa

from .const import LINUX, MACOS, WINDOWS

Expand Down Expand Up @@ -271,9 +272,9 @@ def to_upper(value):

def to_sql(text):
"""Format given SQL text."""
result = sqlparse.format(text, keyword_case='upper',
identifier_case='lower',
reindent_aligned=True)
result = spa.format(text, keyword_case='upper',
identifier_case='lower',
reindent_aligned=True)
return result


Expand All @@ -299,5 +300,103 @@ def read_file_or_string(value):
"""Read content from file or string."""
if os.path.isfile(value):
return open(value, 'r').read()
else:
return value


def sql_formatter(query, expand_select=None, expand_where=None,
make_subquery=False):
"""Modify the given SQL query."""
import sqlparse.sql as s
import sqlparse.tokens as t

def is_select(token):
if isinstance(token, s.IdentifierList):
return True
return False

def is_where(token):
if isinstance(token, s.Where):
return True
return False

def has_where(statement):
for token in statement:
if is_where(token):
return True
return False

def get_length(statement):
if statement.tokens:
return len(statement.tokens)
return 0

parsed = spa.parse(query)
statement = parsed[0] if parsed else None

ws = s.Token(t.Whitespace, ' ')
nl = s.Token(t.Newline, '\n')
cm = s.Token(t.Punctuation, ',')
where = s.Token(t.Keyword, 'where')
and_ = s.Token(t.Operator, 'and')

if expand_select:
expansion = s.Token(t.Other, expand_select)
for token in statement.tokens:
if is_select(token):
tokens = s.TokenList([cm, expansion])
last_position = get_length(token)
token.insert_after(last_position, tokens)

if expand_where:
expansion = s.Token(t.Other, expand_where)
if has_where(statement):
tokens = s.TokenList([expansion, nl, and_, ws])
for token in statement.tokens:
if is_where(token):
token.insert_after(1, tokens)
else:
tokens = s.TokenList([where, ws, expansion, nl])
last_position = get_length(statement)
statement.insert_after(last_position, tokens)

if make_subquery:
statement = f'select * from ({statement})'

text = str(statement)
result = spa.format(text, keyword_case='upper', identifier_case='lower',
reindent_aligned=True)
return str(result)


def sql_compiler(obj, db):
"""Compile the given SQL query."""
import sqlalchemy.dialects as dialects
try:
database = getattr(dialects, db.vendor)
if database:
dialect = database.dialect()
result = obj.compile(dialect=dialect, compile_kwargs=db.ckwargs)
return str(result)
else:
raise NotImplementedError
except Exception:
result = obj.compile(compile_kwargs=db.ckwargs)
return str(result)


def sql_converter(value, db):
"""Convert the given SQL value."""
if isinstance(value, int):
return value
elif isinstance(value, str):
return f'\'{value}\''
elif isinstance(value, dt.datetime):
if db.vendor == 'oracle':
string = f'{value:%Y-%m-%d %H:%M:%S}'
fmt = 'YYYY-MM-DD HH24:MI:SS'
return f'to_date(\'{string}\', \'{fmt}\')'
elif db.vendor == 'sqlite':
return f'\'{value:%Y-%m-%d %H:%M:%S}\''
else:
return value

0 comments on commit d73d1a1

Please sign in to comment.