From 3c4f8621209875e4e87232ab182c03a502f62729 Mon Sep 17 00:00:00 2001 From: Simone Berni Date: Fri, 18 Aug 2023 10:48:22 +0200 Subject: [PATCH] Support for the new `sort` Atlas functionality (#116) * Support for the new `sort` function * Typo * Fix tests * Fix * Fix * Fix test * Fix test * Fix transformation --- README.md | 9 ++++++++- atlasq/queryset/queryset.py | 14 +++++++++----- tests/queryset/test_queryset.py | 22 ++++++++++++++++------ 3 files changed, 33 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index 368db8f..e1e0d03 100644 --- a/README.md +++ b/README.md @@ -126,4 +126,11 @@ MyDocument.atlas.upload_index(index, "user", "pwd", "group", "cluster") result = MyDocument.atlas.ensure_index("user", "pwd", "group", "cluster") assert result is True -``` \ No newline at end of file +``` + + +### Sort +On the [10th of July 2023](https://www.mongodb.com/docs/atlas/atlas-search/changelog/#10-july-2023-release), the `Sort` functionality was released for Atlas search. + +AtlasQ, from version 0.12.0, will support this feature inside the `order_by` function. +To have the old behaviour of the order_by (useful if you want to sort _after_ aggregations and not after the search stage), you can set the kwarg `as_aggregation` as `True`. \ No newline at end of file diff --git a/atlasq/queryset/queryset.py b/atlasq/queryset/queryset.py index e2ed1de..22c23f5 100644 --- a/atlasq/queryset/queryset.py +++ b/atlasq/queryset/queryset.py @@ -100,6 +100,8 @@ def _aggrs(self): if self._aggrs_query: if self._count: self._aggrs_query[0]["$search"]["count"] = {"type": "total"} + if self._ordering: + self._aggrs_query[0]["$search"]["sort"] = dict(self._ordering) self._aggrs_query += self._get_projections() self._aggrs_query += self._other_aggregations return self._aggrs_query @@ -113,13 +115,15 @@ def _cursor(self): cursor = super()._cursor return cursor - def order_by(self, *keys): + def order_by(self, *keys, as_aggregation: bool = False): if not keys: return self qs: AtlasQuerySet = self.clone() order_by: List[Tuple[str, int]] = qs._get_order_by(keys) # pylint: disable=protected-access - aggregation = {"$sort": dict(order_by)} - qs._other_aggregations.append(aggregation) # pylint: disable=protected-access + if not as_aggregation: + qs._ordering = order_by # pylint: disable=protected-access + else: + qs._other_aggregations.append({"$sort": dict(order_by)}) # pylint: disable=protected-access return qs def __getitem__(self, key): @@ -188,8 +192,8 @@ def _get_projections(self) -> List[Dict[str, Any]]: return [{"$project": loaded_fields}] return [] - def count(self, with_limit_and_skip=False): # pylint: disable=unused-argument - self._count = True # pylint: disable=protected-access + def count(self, with_limit_and_skip=False) -> int: # pylint: disable=unused-argument + self._count = True cursor = self.__collection_aggregate(self._aggrs) # pylint: disable=protected-access try: count = next(cursor) diff --git a/tests/queryset/test_queryset.py b/tests/queryset/test_queryset.py index 710ad84..f626a79 100644 --- a/tests/queryset/test_queryset.py +++ b/tests/queryset/test_queryset.py @@ -69,19 +69,29 @@ def test_count(self): self.assertEqual(r, 0) def test_order_by(self): - qs = self.base.order_by("-time") + qs = self.base.order_by("-time", as_aggregation=True) self.assertEqual(qs._aggrs[0], {"$sort": {"time": -1}}) - qs = self.base.order_by("+time") + qs = self.base.order_by("+time", as_aggregation=True) self.assertEqual(qs._aggrs[0], {"$sort": {"time": 1}}) - qs = self.base.order_by("time") + qs = self.base.order_by("time", as_aggregation=True) self.assertEqual(qs._aggrs[0], {"$sort": {"time": 1}}) - qs = self.base.order_by("-time").filter(name="123") + qs = self.base.order_by("-time", as_aggregation=True).filter(name="123") self.assertEqual(qs._aggrs[1], {"$sort": {"time": -1}}) + qs = self.base.filter(name="123").order_by("+time") + self.assertIn("sort", qs._aggrs[0]["$search"]) + self.assertEqual(qs._aggrs[0]["$search"]["sort"], {"time": 1}) + self.assertIn("sort", qs._aggrs[0]["$search"]) + qs = self.base.filter(name="123").order_by("time") + self.assertEqual(qs._aggrs[0]["$search"]["sort"], {"time": 1}) + qs = self.base.filter(name="123").order_by("-time") + self.assertIn("sort", qs._aggrs[0]["$search"]) + self.assertEqual(qs._aggrs[0]["$search"]["sort"], {"time": -1}) + def test_only(self): - qs = self.base.only("name").filter(name="123").order_by("-time") + qs = self.base.only("name").filter(name="123") self.assertEqual(qs._get_projections(), [{"$project": {"name": 1}}]) - self.assertEqual(3, len(qs._aggrs)) + self.assertEqual(2, len(qs._aggrs)) self.assertEqual(qs._aggrs[1], {"$project": {"name": 1}}) def test_exclude(self):