Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding missing GroupConcat features issue #1101 #1111

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
28 changes: 23 additions & 5 deletions src/django_mysql/models/aggregates.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

from typing import Any
from typing import List # noqa: F401
from typing import Union # noqa: F401

from django.db.backends.base.base import BaseDatabaseWrapper
from django.db.models import Aggregate
Expand Down Expand Up @@ -33,6 +35,7 @@ def __init__(
distinct: bool = False,
separator: str | None = None,
ordering: str | None = None,
column_order: list[str] | str | None = None,
**extra: Any,
) -> None:
if "output_field" not in extra:
Expand All @@ -46,7 +49,13 @@ def __init__(

if ordering not in ("asc", "desc", None):
raise ValueError("'ordering' must be one of 'asc', 'desc', or None")
if ordering is not None:
if column_order is not None and isinstance(column_order, list):
raise ValueError(
"When having a list in column_order, you can specify the ordering of each column inside the list. Example: ['column_a DESC',...]"
)
self.ordering = ordering
self.column_order = column_order

def as_sql(
self,
Expand All @@ -69,12 +78,21 @@ def as_sql(

sql.append(expr_sql)

if self.ordering is not None:
if self.ordering is not None or self.column_order is not None:
sql.append(" ORDER BY ")
sql.append(expr_sql)
params.extend(params[:])
sql.append(" ")
sql.append(self.ordering.upper())

if self.column_order is not None:
if isinstance(self.column_order, str):
sql.append(self.column_order)
if isinstance(self.column_order, list):
sql.append(", ".join(self.column_order))
else:
sql.append(expr_sql)
params.extend(params[:])

if self.ordering is not None and not isinstance(self.column_order, list):
sql.append(" ")
sql.append(self.ordering.upper())

if self.separator is not None:
sql.append(f" SEPARATOR '{self.separator}'") # noqa: B028
Expand Down
40 changes: 35 additions & 5 deletions tests/testapp/test_aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,16 +74,26 @@ class GroupConcatTests(TestCase):
shakes: Author
jk: Author
grisham: Author
agatha: Author
str_tutee_ids: list[str]

@classmethod
def setUpTestData(cls):
super().setUpTestData()
cls.shakes = Author.objects.create(name="William Shakespeare")
cls.jk = Author.objects.create(name="JK Rowling", tutor=cls.shakes)
cls.grisham = Author.objects.create(name="Grisham", tutor=cls.shakes)
cls.shakes = Author.objects.create(
name="William Shakespeare", bio="British author."
)
cls.jk = Author.objects.create(
name="JK Rowling", tutor=cls.shakes, bio="British author."
)
cls.grisham = Author.objects.create(
name="Grisham", tutor=cls.shakes, bio="American author."
)
cls.agatha = Author.objects.create(
name="Agatha Christie", tutor=cls.shakes, bio="British author."
)

cls.str_tutee_ids = [str(cls.jk.id), str(cls.grisham.id)]
cls.str_tutee_ids = [str(cls.jk.id), str(cls.grisham.id), str(cls.agatha.id)]

def test_basic_aggregate_ids(self):
out = self.shakes.tutees.aggregate(tids=GroupConcat("id"))
Expand Down Expand Up @@ -124,7 +134,9 @@ def test_separator_big(self):
def test_expression(self):
concat = GroupConcat(F("id") + 1)
out = self.shakes.tutees.aggregate(tids=concat)
concatted_ids = ",".join([str(self.jk.id + 1), str(self.grisham.id + 1)])
concatted_ids = ",".join(
[str(self.jk.id + 1), str(self.grisham.id + 1), str(self.agatha.id + 1)]
)
assert out == {"tids": concatted_ids}

def test_application_order(self):
Expand Down Expand Up @@ -158,3 +170,21 @@ def test_separator_ordering(self):
out = self.shakes.tutees.aggregate(tids=concat)
concatted_ids = ":".join(self.str_tutee_ids)
assert out == {"tids": concatted_ids}

def test_multiple_column_ordering(self):
concat = GroupConcat("id", column_order=["bio", "name desc"])
out = self.shakes.tutees.aggregate(tids=concat)
concatted_ids = ",".join(
[str(self.grisham.id), str(self.jk.id), str(self.agatha.id)]
)

assert out == {"tids": concatted_ids}

def test_multiple_column_ordering_2(self):
concat = GroupConcat("id", column_order="name")
out = self.shakes.tutees.aggregate(tids=concat)
concatted_ids = ",".join(
[str(self.agatha.id), str(self.grisham.id), str(self.jk.id)]
)

assert out == {"tids": concatted_ids}
Loading