Skip to content

Commit

Permalink
table alias bugfix from adamchainz#286
Browse files Browse the repository at this point in the history
  • Loading branch information
victor.lee committed Jul 28, 2018
1 parent f3aee5a commit 24116a9
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 42 deletions.
106 changes: 69 additions & 37 deletions django_mysql/models/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from subprocess import PIPE, Popen

from django.db import connections, models
from django.db.models.sql.where import ExtraWhere
from django.db.models.sql.where import AND, ExtraWhere
from django.db.transaction import atomic
from django.test.utils import CaptureQueriesContext
from django.utils import six
Expand Down Expand Up @@ -155,44 +155,15 @@ def ignore_index(self, *index_names, **kwargs):
return self._index_hint(*index_names, **kwargs)

def _index_hint(self, *index_names, **kwargs):
hint = kwargs.pop('hint')
table_name = kwargs.pop('table_name', None)
for_ = kwargs.pop('for_', None)
if kwargs:
kwargs.setdefault('table_name', self.model._meta.db_table)
if set(kwargs.keys()) - {'table_name', 'for_', 'hint'}:
raise ValueError(
"{}_index accepts only 'for_' and 'table_name' as keyword "
"arguments"
.format(hint.lower())
"{}_index accepts only 'for_' and 'table_name' as "
"keyword arguments"
.format(kwargs['hint'].lower())
)

if hint != 'USE' and not len(index_names):
raise ValueError(
"{}_index requires at least one index name"
.format(hint.lower())
)

if table_name is None:
table_name = self.model._meta.db_table

if for_ in ('JOIN', 'ORDER BY', 'GROUP BY'):
for_bit = 'FOR {} '.format(for_)
elif for_ is None:
for_bit = ''
else:
raise ValueError("for_ must be one of: None, 'JOIN', 'ORDER BY', "
"'GROUP BY'")

if len(index_names) == 0:
indexes = "NONE"
else:
indexes = "`" + "`,`".join(index_names) + "`"

hint = (
"/*QueryRewrite':index=`{table_name}` {hint} {for_bit}{indexes}*/1"
.format(table_name=table_name, hint=hint, for_bit=for_bit,
indexes=indexes)
)
return self.extra(where=[hint])
hint = IndexHint(index_names, **kwargs)
return self.extra(hints=(hint, ))

# Features handled by extra classes/functions

Expand All @@ -217,6 +188,17 @@ def pt_visual_explain(self, display=True):
def handler(self):
return Handler(self)

def extra(self, select=None, where=None, params=None, tables=None,
order_by=None, select_params=None, hints=None):
clone = super(QuerySetMixin, self).extra(
select, where, params, tables, order_by, select_params)
if hints:
for hint in hints:
if not isinstance(hint, IndexHint):
raise ValueError("hint should be instance of IndexHint")
clone.query.where.add(hint, AND)
return clone


class QuerySet(QuerySetMixin, models.QuerySet):
pass
Expand Down Expand Up @@ -608,3 +590,53 @@ def pt_visual_explain(queryset, display=True):
print(explanation)
else:
return explanation


class IndexHint(object):
contains_aggregate = False

def __init__(self, index_names, table_name, hint, for_=None, alias=None):
if hint != 'USE' and not len(index_names):
raise ValueError(
"{}_index requires at least one index name"
.format(hint.lower()))
self.hint = hint

self.for_ = for_
if self.for_ not in ('JOIN', 'ORDER BY', 'GROUP BY') \
and self.for_ is not None:
raise ValueError("for_ must be one of: None, 'JOIN', 'ORDER BY', "
"'GROUP BY'")

self.table_name = table_name
self.indexes = index_names
self.alias = alias or self.table_name

def clone(self):
return type(self)(
self.indexes, self.table_name, self.hint, self.for_, self.alias)

def __str__(self):
alias_str = ""
if self.alias and self.alias != self.table_name:
alias_str = " AS `{}`".format(self.alias)
indexes = "NONE"
if len(self.indexes) > 0:
indexes = "`{}`".format("`,`".join(self.indexes))
for_bit = ""
if self.for_ is not None:
for_bit = "FOR {} " .format(self.for_)
return (
"/*QueryRewrite':index=`{table_name}`{alias_str} {hint} "
"{for_bit}{indexes}*/1"
.format(table_name=self.table_name, hint=self.hint,
for_bit=for_bit, indexes=indexes, alias_str=alias_str)
)

def as_sql(self, compiler, connection):
return str(self), []

def relabeled_clone(self, change_map):
clone = self.clone()
clone.alias = change_map.get(self.alias, self.alias)
return clone
18 changes: 13 additions & 5 deletions django_mysql/rewrite_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
r"""
index=
(?P<table_name>`[^`]+`)
(?:\ AS\ `(?P<alias>[^`]+)`)?
\ # space
(?P<rule>USE|IGNORE|FORCE)
\ # space
Expand Down Expand Up @@ -58,6 +59,7 @@ def rewrite_query(sql):
index_match.group('rule'),
index_match.group('index_names'),
index_match.group('for_what'),
index_match.group('alias'),
))

# Silently fail on unrecognized rewrite requests
Expand Down Expand Up @@ -167,18 +169,23 @@ def modify_sql(sql, add_comments, add_hints, add_index_hints):
table_spec_re_template = r'''
\b(?P<operator>FROM|JOIN)
\s+
{table_name}
{table_name}{optional_alias}
\s+
'''

replacement_template = (
r'\g<operator> {table_name} '
r'\g<operator> {table_name}{optional_alias} '
r'{rule} INDEX {for_section}({index_names}) '
)


def modify_sql_index_hints(sql, table_name, rule, index_names, for_what):
table_spec_re = table_spec_re_template.format(table_name=table_name)
def modify_sql_index_hints(
sql, table_name, rule, index_names, for_what, alias):
alias_re = ''
if alias:
alias_re = '\s+{}'.format(re.escape(alias))
table_spec_re = table_spec_re_template.format(
table_name=re.escape(table_name), optional_alias=alias_re)
if for_what:
for_section = 'FOR {} '.format(for_what)
else:
Expand All @@ -187,6 +194,7 @@ def modify_sql_index_hints(sql, table_name, rule, index_names, for_what):
table_name=table_name,
rule=rule,
for_section=for_section,
index_names=('' if index_names == 'NONE' else index_names)
index_names=('' if index_names == 'NONE' else index_names),
optional_alias=' {}'.format(alias) if alias else ''
)
return re.sub(table_spec_re, replacement, sql, count=1, flags=re.VERBOSE)

0 comments on commit 24116a9

Please sign in to comment.