Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Select proper default for output_field based on distinct value #647

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
10 changes: 8 additions & 2 deletions src/django_mysql/models/aggregates.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from django.db.models import Aggregate, CharField

from django_mysql.models.fields import ListCharField, SetCharField


class BitAnd(Aggregate):
function = "BIT_AND"
Expand All @@ -24,8 +26,12 @@ def __init__(
):

if "output_field" not in extra:
# This can/will be improved to SetTextField or ListTextField
extra["output_field"] = CharField()
if separator is not None:
extra["output_field"] = CharField()
elif distinct:
extra["output_field"] = SetCharField(CharField())
else:
extra["output_field"] = ListCharField(CharField())

super().__init__(expression, **extra)

Expand Down
15 changes: 9 additions & 6 deletions tests/testapp/test_aggregates.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,11 @@ def setUp(self):

def test_basic_aggregate_ids(self):
out = self.shakes.tutees.aggregate(tids=GroupConcat("id"))
concatted_ids = ",".join(self.str_tutee_ids)
assert out == {"tids": concatted_ids}
assert out == {"tids": self.str_tutee_ids}

def test_distinct_aggregate_ids(self):
out = self.shakes.tutees.aggregate(tids=GroupConcat("id", distinct=True))
assert out == {"tids": set(self.str_tutee_ids)}

def test_basic_annotate_ids(self):
concat = GroupConcat("tutees__id")
Expand Down Expand Up @@ -104,14 +107,14 @@ def test_separator_big(self):
def test_expression(self):
concat = GroupConcat(F("id") + 1)
out = self.shakes.tutees.aggregate(tids=concat)
concatted_ids = ",".join([str(self.jk.id + 1), str(self.grisham.id + 1)])
concatted_ids = [str(self.jk.id + 1), str(self.grisham.id + 1)]
assert out == {"tids": concatted_ids}

def test_application_order(self):
out = Author.objects.exclude(id=self.shakes.id).aggregate(
tids=GroupConcat("tutor_id", distinct=True)
)
assert out == {"tids": str(self.shakes.id)}
assert out == {"tids": {str(self.shakes.id)}}

@override_mysql_variables(SQL_MODE="ANSI")
def test_separator_ansi_mode(self):
Expand All @@ -127,11 +130,11 @@ def test_ordering_invalid(self):

def test_ordering_asc(self):
out = self.shakes.tutees.aggregate(tids=GroupConcat("id", ordering="asc"))
assert out == {"tids": ",".join(self.str_tutee_ids)}
assert out == {"tids": self.str_tutee_ids}

def test_ordering_desc(self):
out = self.shakes.tutees.aggregate(tids=GroupConcat("id", ordering="desc"))
assert out == {"tids": ",".join(reversed(self.str_tutee_ids))}
assert out == {"tids": list(reversed(self.str_tutee_ids))}

def test_separator_ordering(self):
concat = GroupConcat("id", separator=":", ordering="asc")
Expand Down