Skip to content

Commit

Permalink
Support for the new sort Atlas functionality (#116)
Browse files Browse the repository at this point in the history
* Support for the new `sort` function

* Typo

* Fix tests

* Fix

* Fix

* Fix test

* Fix test

* Fix transformation
  • Loading branch information
0ssigeno authored Aug 18, 2023
1 parent 463b47e commit 3c4f862
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 12 deletions.
9 changes: 8 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,11 @@ MyDocument.atlas.upload_index(index, "user", "pwd", "group", "cluster")
result = MyDocument.atlas.ensure_index("user", "pwd", "group", "cluster")
assert result is True

```
```


### Sort
On the [10th of July 2023](https://www.mongodb.com/docs/atlas/atlas-search/changelog/#10-july-2023-release), the `Sort` functionality was released for Atlas search.

AtlasQ, from version 0.12.0, will support this feature inside the `order_by` function.
To have the old behaviour of the order_by (useful if you want to sort _after_ aggregations and not after the search stage), you can set the kwarg `as_aggregation` as `True`.
14 changes: 9 additions & 5 deletions atlasq/queryset/queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def _aggrs(self):
if self._aggrs_query:
if self._count:
self._aggrs_query[0]["$search"]["count"] = {"type": "total"}
if self._ordering:
self._aggrs_query[0]["$search"]["sort"] = dict(self._ordering)
self._aggrs_query += self._get_projections()
self._aggrs_query += self._other_aggregations
return self._aggrs_query
Expand All @@ -113,13 +115,15 @@ def _cursor(self):
cursor = super()._cursor
return cursor

def order_by(self, *keys):
def order_by(self, *keys, as_aggregation: bool = False):
if not keys:
return self
qs: AtlasQuerySet = self.clone()
order_by: List[Tuple[str, int]] = qs._get_order_by(keys) # pylint: disable=protected-access
aggregation = {"$sort": dict(order_by)}
qs._other_aggregations.append(aggregation) # pylint: disable=protected-access
if not as_aggregation:
qs._ordering = order_by # pylint: disable=protected-access
else:
qs._other_aggregations.append({"$sort": dict(order_by)}) # pylint: disable=protected-access
return qs

def __getitem__(self, key):
Expand Down Expand Up @@ -188,8 +192,8 @@ def _get_projections(self) -> List[Dict[str, Any]]:
return [{"$project": loaded_fields}]
return []

def count(self, with_limit_and_skip=False): # pylint: disable=unused-argument
self._count = True # pylint: disable=protected-access
def count(self, with_limit_and_skip=False) -> int: # pylint: disable=unused-argument
self._count = True
cursor = self.__collection_aggregate(self._aggrs) # pylint: disable=protected-access
try:
count = next(cursor)
Expand Down
22 changes: 16 additions & 6 deletions tests/queryset/test_queryset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,19 +69,29 @@ def test_count(self):
self.assertEqual(r, 0)

def test_order_by(self):
qs = self.base.order_by("-time")
qs = self.base.order_by("-time", as_aggregation=True)
self.assertEqual(qs._aggrs[0], {"$sort": {"time": -1}})
qs = self.base.order_by("+time")
qs = self.base.order_by("+time", as_aggregation=True)
self.assertEqual(qs._aggrs[0], {"$sort": {"time": 1}})
qs = self.base.order_by("time")
qs = self.base.order_by("time", as_aggregation=True)
self.assertEqual(qs._aggrs[0], {"$sort": {"time": 1}})
qs = self.base.order_by("-time").filter(name="123")
qs = self.base.order_by("-time", as_aggregation=True).filter(name="123")
self.assertEqual(qs._aggrs[1], {"$sort": {"time": -1}})

qs = self.base.filter(name="123").order_by("+time")
self.assertIn("sort", qs._aggrs[0]["$search"])
self.assertEqual(qs._aggrs[0]["$search"]["sort"], {"time": 1})
self.assertIn("sort", qs._aggrs[0]["$search"])
qs = self.base.filter(name="123").order_by("time")
self.assertEqual(qs._aggrs[0]["$search"]["sort"], {"time": 1})
qs = self.base.filter(name="123").order_by("-time")
self.assertIn("sort", qs._aggrs[0]["$search"])
self.assertEqual(qs._aggrs[0]["$search"]["sort"], {"time": -1})

def test_only(self):
qs = self.base.only("name").filter(name="123").order_by("-time")
qs = self.base.only("name").filter(name="123")
self.assertEqual(qs._get_projections(), [{"$project": {"name": 1}}])
self.assertEqual(3, len(qs._aggrs))
self.assertEqual(2, len(qs._aggrs))
self.assertEqual(qs._aggrs[1], {"$project": {"name": 1}})

def test_exclude(self):
Expand Down

0 comments on commit 3c4f862

Please sign in to comment.