diff --git a/pydin/models.py b/pydin/models.py index 684addc..f8404be 100644 --- a/pydin/models.py +++ b/pydin/models.py @@ -25,6 +25,7 @@ from .utils import to_sql from .utils import coalesce from .utils import read_file_or_string +from .utils import sql_formatter, sql_compiler, sql_converter from .core import Node @@ -112,13 +113,13 @@ def target_date(self): def date_from(self): """The beggining of the target date of this model.""" if hasattr(self, 'date_field'): - return self.converter(self.target_date.start) + return self.target_date.start @property def date_to(self): """The end of the target date of this model.""" if hasattr(self, 'date_field'): - return self.converter(self.target_date.end) + return self.target_date.end @property def value_field(self): @@ -210,13 +211,6 @@ def format_custom_query(self, value): else: return value - def converter(self, value): - """Convert the given value into the appropriate model format.""" - if value and hasattr(self, '_convert'): - return self._convert(value) - else: - return value - def explain(self, parameter_name=None): """Get model or chosen parameter description.""" if not parameter_name: @@ -986,26 +980,30 @@ def assemble(self): table = self.get_table() insert = table.insert() + select = self.select columns = [sa.column(column) for column in self.describe()] - select = sa.text(self.select).columns(*columns).alias('s') if self.key_field: columns = [*columns, self.key_field.column] - select = sa.select(columns).select_from(select) + expression = f'{self.key_field.value} as {self.key_field.label}' + select = sql_formatter(select, expand_select=expression) if self.date_field: - date_column = sa.column(self.date_field) - between = sa.between(date_column, self.date_from, self.date_to) - select = select.where(between) + date_from = sql_converter(self.date_from, self.db) + date_to = sql_converter(self.date_to, self.db) + between = f'{self.date_field} between {date_from} and {date_to}' + select = sql_formatter(select, expand_where=between) if self.value_field: last_value = self.get_last_value() - value_column = sa.column(self.value_field) - select = select.where(value_column > last_value) + if last_value: + comparison = f'{self.value_field} > {last_value}' + select = sql_formatter(select, make_subquery=True) + select = sql_formatter(select, expand_where=comparison) + select = sa.text(f'\n{select}').columns(*columns) query = insert.from_select(columns, select) - query = query.compile(compile_kwargs=self.db.ckwargs) - query = str(query) + query = sql_compiler(query, self.db) return query def parse(self): @@ -1072,17 +1070,6 @@ def endlog(self, output_rows=None, output_text=None, error_text=error_text) pass - def _convert(self, value): - if isinstance(value, dt.datetime): - if self.db.vendor == 'oracle': - string = f'{value:%Y-%m-%d %H:%M:%S}' - fmt = 'yyyy-mm-dd hh24:mi:ss' - return sa.func.to_date(string, fmt) - else: - return value - else: - return value - def _format(self, text): text = text.format(task=self.task) return text diff --git a/pydin/utils.py b/pydin/utils.py index 449a58d..0cf51ff 100644 --- a/pydin/utils.py +++ b/pydin/utils.py @@ -12,7 +12,8 @@ import subprocess as sp import importlib as imp -import sqlparse +import sqlalchemy as sa +import sqlparse as spa from .const import LINUX, MACOS, WINDOWS @@ -271,9 +272,9 @@ def to_upper(value): def to_sql(text): """Format given SQL text.""" - result = sqlparse.format(text, keyword_case='upper', - identifier_case='lower', - reindent_aligned=True) + result = spa.format(text, keyword_case='upper', + identifier_case='lower', + reindent_aligned=True) return result @@ -299,5 +300,103 @@ def read_file_or_string(value): """Read content from file or string.""" if os.path.isfile(value): return open(value, 'r').read() + else: + return value + + +def sql_formatter(query, expand_select=None, expand_where=None, + make_subquery=False): + """Modify the given SQL query.""" + import sqlparse.sql as s + import sqlparse.tokens as t + + def is_select(token): + if isinstance(token, s.IdentifierList): + return True + return False + + def is_where(token): + if isinstance(token, s.Where): + return True + return False + + def has_where(statement): + for token in statement: + if is_where(token): + return True + return False + + def get_length(statement): + if statement.tokens: + return len(statement.tokens) + return 0 + + parsed = spa.parse(query) + statement = parsed[0] if parsed else None + + ws = s.Token(t.Whitespace, ' ') + nl = s.Token(t.Newline, '\n') + cm = s.Token(t.Punctuation, ',') + where = s.Token(t.Keyword, 'where') + and_ = s.Token(t.Operator, 'and') + + if expand_select: + expansion = s.Token(t.Other, expand_select) + for token in statement.tokens: + if is_select(token): + tokens = s.TokenList([cm, expansion]) + last_position = get_length(token) + token.insert_after(last_position, tokens) + + if expand_where: + expansion = s.Token(t.Other, expand_where) + if has_where(statement): + tokens = s.TokenList([expansion, nl, and_, ws]) + for token in statement.tokens: + if is_where(token): + token.insert_after(1, tokens) + else: + tokens = s.TokenList([where, ws, expansion, nl]) + last_position = get_length(statement) + statement.insert_after(last_position, tokens) + + if make_subquery: + statement = f'select * from ({statement})' + + text = str(statement) + result = spa.format(text, keyword_case='upper', identifier_case='lower', + reindent_aligned=True) + return str(result) + + +def sql_compiler(obj, db): + """Compile the given SQL query.""" + import sqlalchemy.dialects as dialects + try: + database = getattr(dialects, db.vendor) + if database: + dialect = database.dialect() + result = obj.compile(dialect=dialect, compile_kwargs=db.ckwargs) + return str(result) + else: + raise NotImplementedError + except Exception: + result = obj.compile(compile_kwargs=db.ckwargs) + return str(result) + + +def sql_converter(value, db): + """Convert the given SQL value.""" + if isinstance(value, int): + return value + elif isinstance(value, str): + return f'\'{value}\'' + elif isinstance(value, dt.datetime): + if db.vendor == 'oracle': + string = f'{value:%Y-%m-%d %H:%M:%S}' + fmt = 'YYYY-MM-DD HH24:MI:SS' + return f'to_date(\'{string}\', \'{fmt}\')' + elif db.vendor == 'sqlite': + return f'\'{value:%Y-%m-%d %H:%M:%S}\'' else: return value \ No newline at end of file