From 250a2e198ca81dc4f8c7f64b2615f88db6422afa Mon Sep 17 00:00:00 2001 From: Denis Stepanov Date: Thu, 5 Sep 2024 13:16:21 +0200 Subject: [PATCH] Fix mongo custom query pagination (#3106) --- .../io/micronaut/data/model/Pageable.java | 33 ++- .../DefaultMongoRepositoryOperations.java | 8 +- .../operations/DefaultMongoStoredQuery.java | 13 +- .../MongoDocumentRepositorySpec.groovy | 223 ++++++++++++++++++ .../repositories/MongoPersonRepository.java | 8 + .../query/DefaultStoredQueryResolver.java | 6 +- 6 files changed, 266 insertions(+), 25 deletions(-) diff --git a/data-model/src/main/java/io/micronaut/data/model/Pageable.java b/data-model/src/main/java/io/micronaut/data/model/Pageable.java index 70ea4c294c6..603e6216a66 100644 --- a/data-model/src/main/java/io/micronaut/data/model/Pageable.java +++ b/data-model/src/main/java/io/micronaut/data/model/Pageable.java @@ -115,37 +115,36 @@ default Sort getSort() { /** * @return The next pageable. */ - default @NonNull Pageable next() { - int size = getSize(); - if (size < 0) { - // unpaged - return Pageable.from(0, size, getSort()); - } - int newNumber = getNumber() + 1; - // handle overflow - if (newNumber < 0) { - return Pageable.from(0, size, getSort()); - } else { - return Pageable.from(newNumber, size, getSort()); - } + @NonNull + default Pageable next() { + return getPageable(getNumber() + 1); } /** * @return The previous pageable */ - default @NonNull Pageable previous() { + @NonNull + default Pageable previous() { + return getPageable(getNumber() - 1); + } + + private Pageable getPageable(int newNumber) { int size = getSize(); if (size < 0) { // unpaged return Pageable.from(0, size, getSort()); } - int newNumber = getNumber() - 1; + Pageable newPageable; // handle overflow if (newNumber < 0) { - return Pageable.from(0, size, getSort()); + newPageable = Pageable.from(0, size, getSort()); } else { - return Pageable.from(newNumber, size, getSort()); + newPageable = Pageable.from(newNumber, size, getSort()); + } + if (!requestTotal()) { + newPageable = newPageable.withoutTotal(); } + return newPageable; } /** diff --git a/data-mongodb/src/main/java/io/micronaut/data/mongodb/operations/DefaultMongoRepositoryOperations.java b/data-mongodb/src/main/java/io/micronaut/data/mongodb/operations/DefaultMongoRepositoryOperations.java index b31f3bdf0e9..6c581812a33 100644 --- a/data-mongodb/src/main/java/io/micronaut/data/mongodb/operations/DefaultMongoRepositoryOperations.java +++ b/data-mongodb/src/main/java/io/micronaut/data/mongodb/operations/DefaultMongoRepositoryOperations.java @@ -387,12 +387,12 @@ private FindIterable applyFindOptions(@Nullable MongoFindOptions findOp findIterable = findIterable.collation(collation); } Integer skip = findOptions.getSkip(); - if (skip != null) { + if (skip != null && skip > 0) { findIterable = findIterable.skip(skip); } Integer limit = findOptions.getLimit(); - if (limit != null) { - findIterable = findIterable.limit(Math.max(limit, 0)); + if (limit != null && limit > 0) { + findIterable = findIterable.limit(limit); } Bson sort = findOptions.getSort(); if (sort != null) { @@ -403,7 +403,7 @@ private FindIterable applyFindOptions(@Nullable MongoFindOptions findOp findIterable = findIterable.projection(projection); } Integer batchSize = findOptions.getBatchSize(); - if (batchSize != null) { + if (batchSize != null && batchSize > 0) { findIterable = findIterable.batchSize(batchSize); } Boolean allowDiskUse = findOptions.getAllowDiskUse(); diff --git a/data-mongodb/src/main/java/io/micronaut/data/mongodb/operations/DefaultMongoStoredQuery.java b/data-mongodb/src/main/java/io/micronaut/data/mongodb/operations/DefaultMongoStoredQuery.java index a35681983e2..ddcab0b62ef 100644 --- a/data-mongodb/src/main/java/io/micronaut/data/mongodb/operations/DefaultMongoStoredQuery.java +++ b/data-mongodb/src/main/java/io/micronaut/data/mongodb/operations/DefaultMongoStoredQuery.java @@ -144,7 +144,7 @@ final class DefaultMongoStoredQuery extends DefaultBindableParametersStore aggregateData = null; findData = new FindData(BsonDocument.parse(query)); } else if (query.startsWith("[")) { - aggregateData = new AggregateData(BsonArray.parse(query).stream().map(BsonValue::asDocument).collect(Collectors.toList())); + aggregateData = new AggregateData(parseAggregation(query, storedQuery.isCount())); findData = null; } else { aggregateData = null; @@ -184,6 +184,17 @@ final class DefaultMongoStoredQuery extends DefaultBindableParametersStore } } + private List parseAggregation(String query, boolean isCount) { + List pipeline = BsonArray.parse(query).stream().map(BsonValue::asDocument).toList(); + if (isCount && pipeline.stream().noneMatch(p -> p.toBsonDocument().containsKey("$count"))) { + // We can probably remove sorting projection etc. or allow a user to specify a custom count pipeline + List countPipeline = new ArrayList<>(pipeline); + countPipeline.add(BsonDocument.parse("{ $count: \"totalCount\" }")); + return countPipeline; + } + return pipeline; + } + @Override public boolean isCount() { return isCount; diff --git a/data-mongodb/src/test/groovy/io/micronaut/data/document/mongodb/MongoDocumentRepositorySpec.groovy b/data-mongodb/src/test/groovy/io/micronaut/data/document/mongodb/MongoDocumentRepositorySpec.groovy index c7e142019e0..a0639c4c11d 100644 --- a/data-mongodb/src/test/groovy/io/micronaut/data/document/mongodb/MongoDocumentRepositorySpec.groovy +++ b/data-mongodb/src/test/groovy/io/micronaut/data/document/mongodb/MongoDocumentRepositorySpec.groovy @@ -116,6 +116,229 @@ class MongoDocumentRepositorySpec extends AbstractDocumentRepositorySpec impleme people[1].age == 0 } + void "test custom find paginated"() { + given: + savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"]) + def peopleToUpdate = personRepository.findAll().toList() + peopleToUpdate.forEach {it.age = 100 } + personRepository.updateAll(peopleToUpdate) + when: + def peoplePage = personRepository.customFindPage("J.*", Pageable.from(0, 2)) + def people = peoplePage.getContent() + then: + peoplePage.hasTotalSize() + peoplePage.getTotalPages() == 3 + peoplePage.pageNumber == 0 + people.size() == 2 + people[0].name == "Jake" + people[0].age == 0 // Projection works + people[1].name == "James" + when: + peoplePage = personRepository.customFindPage("J.*", peoplePage.nextPageable()) + people = peoplePage.getContent() + then: + peoplePage.hasTotalSize() + peoplePage.getTotalPages() == 3 + peoplePage.pageNumber == 1 + people.size() == 2 + people[0].name == "Jeff" + people[0].age == 0 // Projection works + people[1].name == "Jim" + when: + peoplePage = personRepository.customFindPage("J.*", peoplePage.nextPageable()) + people = peoplePage.getContent() + then: + peoplePage.hasTotalSize() + peoplePage.getTotalPages() == 3 + peoplePage.pageNumber == 2 + people.size() == 1 + people[0].name == "Josh" + } + + void "test custom find paginated without count"() { + given: + savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"]) + def peopleToUpdate = personRepository.findAll().toList() + peopleToUpdate.forEach {it.age = 100 } + personRepository.updateAll(peopleToUpdate) + when: + def peoplePage = personRepository.customFindPage("J.*", Pageable.from(0, 2).withoutTotal()) + def people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 0 + people.size() == 2 + people[0].name == "Jake" + people[0].age == 0 // Projection works + people[1].name == "James" + when: + peoplePage = personRepository.customFindPage("J.*", peoplePage.nextPageable()) + people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 1 + people.size() == 2 + people[0].name == "Jeff" + people[0].age == 0 // Projection works + people[1].name == "Jim" + when: + peoplePage = personRepository.customFindPage("J.*", peoplePage.nextPageable()) + people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 2 + people.size() == 1 + people[0].name == "Josh" + } + + void "test custom find paginated without count backwards"() { + given: + savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"]) + def peopleToUpdate = personRepository.findAll().toList() + peopleToUpdate.forEach {it.age = 100 } + personRepository.updateAll(peopleToUpdate) + when: + def peoplePage = personRepository.customFindPage("J.*", Pageable.from(2, 2).withoutTotal()) + def people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 2 + people.size() == 1 + people[0].name == "Josh" + people[0].age == 0 // Projection works + + when: + peoplePage = personRepository.customFindPage("J.*", peoplePage.previousPageable()) + people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 1 + people.size() == 2 + people[0].name == "Jeff" + people[0].age == 0 // Projection works + people[1].name == "Jim" + when: + peoplePage = personRepository.customFindPage("J.*", peoplePage.previousPageable()) + people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 0 + people.size() == 2 + people[0].name == "Jake" + people[1].name == "James" + } + + void "test custom aggr paginated"() { + given: + savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"]) + def peopleToUpdate = personRepository.findAll().toList() + peopleToUpdate.forEach {it.age = 100 } + personRepository.updateAll(peopleToUpdate) + when: + def peoplePage = personRepository.customAggrPage("J.*", Pageable.from(0, 2)) + def people = peoplePage.getContent() + then: + peoplePage.hasTotalSize() + peoplePage.getTotalPages() == 3 + peoplePage.pageNumber == 0 + people.size() == 2 + people[0].name == "Jake" + people[0].age == 0 // Projection works + people[1].name == "James" + when: + peoplePage = personRepository.customAggrPage("J.*", peoplePage.nextPageable()) + people = peoplePage.getContent() + then: + peoplePage.hasTotalSize() + peoplePage.getTotalPages() == 3 + peoplePage.pageNumber == 1 + people.size() == 2 + people[0].name == "Jeff" + people[0].age == 0 // Projection works + people[1].name == "Jim" + when: + peoplePage = personRepository.customAggrPage("J.*", peoplePage.nextPageable()) + people = peoplePage.getContent() + then: + peoplePage.hasTotalSize() + peoplePage.getTotalPages() == 3 + peoplePage.pageNumber == 2 + people.size() == 1 + people[0].name == "Josh" + } + + void "test custom aggr paginated without count"() { + given: + savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"]) + def peopleToUpdate = personRepository.findAll().toList() + peopleToUpdate.forEach {it.age = 100 } + personRepository.updateAll(peopleToUpdate) + when: + def peoplePage = personRepository.customAggrPage("J.*", Pageable.from(0, 2).withoutTotal()) + def people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 0 + people.size() == 2 + people[0].name == "Jake" + people[0].age == 0 // Projection works + people[1].name == "James" + when: + peoplePage = personRepository.customAggrPage("J.*", peoplePage.nextPageable()) + people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 1 + people.size() == 2 + people[0].name == "Jeff" + people[0].age == 0 // Projection works + people[1].name == "Jim" + when: + peoplePage = personRepository.customAggrPage("J.*", peoplePage.nextPageable()) + people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 2 + people.size() == 1 + people[0].name == "Josh" + } + + void "test custom aggr paginated without count backwards"() { + given: + savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"]) + def peopleToUpdate = personRepository.findAll().toList() + peopleToUpdate.forEach {it.age = 100 } + personRepository.updateAll(peopleToUpdate) + when: + def peoplePage = personRepository.customAggrPage("J.*", Pageable.from(2, 2).withoutTotal()) + def people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 2 + people.size() == 1 + people[0].name == "Josh" + people[0].age == 0 // Projection works + when: + peoplePage = personRepository.customAggrPage("J.*", peoplePage.previousPageable()) + people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 1 + people.size() == 2 + people[0].name == "Jeff" + people[0].age == 0 // Projection works + people[1].name == "Jim" + when: + peoplePage = personRepository.customAggrPage("J.*", peoplePage.previousPageable()) + people = peoplePage.getContent() + then: + !peoplePage.hasTotalSize() + peoplePage.pageNumber == 0 + people.size() == 2 + people[0].name == "Jake" + people[1].name == "James" + } + void "test custom aggr"() { given: savePersons(["Dennis", "Jeff", "James", "Dennis"]) diff --git a/data-mongodb/src/test/java/io/micronaut/data/document/mongodb/repositories/MongoPersonRepository.java b/data-mongodb/src/test/java/io/micronaut/data/document/mongodb/repositories/MongoPersonRepository.java index 73efe1623a7..dfd0ce62c28 100644 --- a/data-mongodb/src/test/java/io/micronaut/data/document/mongodb/repositories/MongoPersonRepository.java +++ b/data-mongodb/src/test/java/io/micronaut/data/document/mongodb/repositories/MongoPersonRepository.java @@ -2,6 +2,8 @@ import io.micronaut.data.document.tck.entities.Person; import io.micronaut.data.document.tck.repositories.PersonRepository; +import io.micronaut.data.model.Page; +import io.micronaut.data.model.Pageable; import io.micronaut.data.mongodb.annotation.MongoAggregateOptions; import io.micronaut.data.mongodb.annotation.MongoAggregateQuery; import io.micronaut.data.mongodb.annotation.MongoDeleteOptions; @@ -27,9 +29,15 @@ public interface MongoPersonRepository extends PersonRepository { @MongoFindQuery(filter = "{name:{$regex: :t}}", sort = "{ name : 1 }", project = "{ name: 1}") List customFind(String t); + @MongoFindQuery(filter = "{name:{$regex: :t}}", sort = "{ name : 1 }", project = "{ name: 1}") + Page customFindPage(String t, Pageable pageable); + @MongoAggregateQuery("[{$match: {name:{$regex: :t}}}, {$sort: {name: 1}}, {$project: {name: 1}}]") List customAgg(String t); + @MongoAggregateQuery("[{$match: {name:{$regex: :t}}}, {$sort: {name: 1}}, {$project: {name: 1}}]") + Page customAggrPage(String t, Pageable pageable); + @MongoUpdateQuery(update = "{$set:{name: :newName}}", filter = "{name:{$eq: :oldName}}") long updateNamesCustom(String newName, String oldName); diff --git a/data-runtime/src/main/java/io/micronaut/data/runtime/query/DefaultStoredQueryResolver.java b/data-runtime/src/main/java/io/micronaut/data/runtime/query/DefaultStoredQueryResolver.java index 7147589234d..47373dba398 100644 --- a/data-runtime/src/main/java/io/micronaut/data/runtime/query/DefaultStoredQueryResolver.java +++ b/data-runtime/src/main/java/io/micronaut/data/runtime/query/DefaultStoredQueryResolver.java @@ -66,9 +66,9 @@ public StoredQuery resolveQuery(MethodInvocationContext conte @Override public StoredQuery resolveCountQuery(MethodInvocationContext context, Class entityClass, Class resultType) { - String query = context.stringValue(Query.class, DataMethod.META_MEMBER_COUNT_QUERY).orElseThrow(() -> - new IllegalStateException("No query present in method") - ); + String query = context.stringValue(Query.class, DataMethod.META_MEMBER_COUNT_QUERY) + .orElseGet(() -> context.stringValue(Query.class) + .orElseThrow(() -> new IllegalStateException("No query present in method"))); return new DefaultStoredQuery<>( context.getExecutableMethod(), resultType,