Skip to content

Commit

Permalink
Fix mongo custom query pagination (#3106)
Browse files Browse the repository at this point in the history
  • Loading branch information
dstepanov authored Sep 5, 2024
1 parent f1a6fdd commit 250a2e1
Show file tree
Hide file tree
Showing 6 changed files with 266 additions and 25 deletions.
33 changes: 16 additions & 17 deletions data-model/src/main/java/io/micronaut/data/model/Pageable.java
Original file line number Diff line number Diff line change
Expand Up @@ -115,37 +115,36 @@ default Sort getSort() {
/**
* @return The next pageable.
*/
default @NonNull Pageable next() {
int size = getSize();
if (size < 0) {
// unpaged
return Pageable.from(0, size, getSort());
}
int newNumber = getNumber() + 1;
// handle overflow
if (newNumber < 0) {
return Pageable.from(0, size, getSort());
} else {
return Pageable.from(newNumber, size, getSort());
}
@NonNull
default Pageable next() {
return getPageable(getNumber() + 1);
}

/**
* @return The previous pageable
*/
default @NonNull Pageable previous() {
@NonNull
default Pageable previous() {
return getPageable(getNumber() - 1);
}

private Pageable getPageable(int newNumber) {
int size = getSize();
if (size < 0) {
// unpaged
return Pageable.from(0, size, getSort());
}
int newNumber = getNumber() - 1;
Pageable newPageable;
// handle overflow
if (newNumber < 0) {
return Pageable.from(0, size, getSort());
newPageable = Pageable.from(0, size, getSort());
} else {
return Pageable.from(newNumber, size, getSort());
newPageable = Pageable.from(newNumber, size, getSort());
}
if (!requestTotal()) {
newPageable = newPageable.withoutTotal();
}
return newPageable;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -387,12 +387,12 @@ private <MR> FindIterable<MR> applyFindOptions(@Nullable MongoFindOptions findOp
findIterable = findIterable.collation(collation);
}
Integer skip = findOptions.getSkip();
if (skip != null) {
if (skip != null && skip > 0) {
findIterable = findIterable.skip(skip);
}
Integer limit = findOptions.getLimit();
if (limit != null) {
findIterable = findIterable.limit(Math.max(limit, 0));
if (limit != null && limit > 0) {
findIterable = findIterable.limit(limit);
}
Bson sort = findOptions.getSort();
if (sort != null) {
Expand All @@ -403,7 +403,7 @@ private <MR> FindIterable<MR> applyFindOptions(@Nullable MongoFindOptions findOp
findIterable = findIterable.projection(projection);
}
Integer batchSize = findOptions.getBatchSize();
if (batchSize != null) {
if (batchSize != null && batchSize > 0) {
findIterable = findIterable.batchSize(batchSize);
}
Boolean allowDiskUse = findOptions.getAllowDiskUse();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ final class DefaultMongoStoredQuery<E, R> extends DefaultBindableParametersStore
aggregateData = null;
findData = new FindData(BsonDocument.parse(query));
} else if (query.startsWith("[")) {
aggregateData = new AggregateData(BsonArray.parse(query).stream().map(BsonValue::asDocument).collect(Collectors.toList()));
aggregateData = new AggregateData(parseAggregation(query, storedQuery.isCount()));
findData = null;
} else {
aggregateData = null;
Expand Down Expand Up @@ -184,6 +184,17 @@ final class DefaultMongoStoredQuery<E, R> extends DefaultBindableParametersStore
}
}

private List<Bson> parseAggregation(String query, boolean isCount) {
List<Bson> pipeline = BsonArray.parse(query).stream().<Bson>map(BsonValue::asDocument).toList();
if (isCount && pipeline.stream().noneMatch(p -> p.toBsonDocument().containsKey("$count"))) {
// We can probably remove sorting projection etc. or allow a user to specify a custom count pipeline
List<Bson> countPipeline = new ArrayList<>(pipeline);
countPipeline.add(BsonDocument.parse("{ $count: \"totalCount\" }"));
return countPipeline;
}
return pipeline;
}

@Override
public boolean isCount() {
return isCount;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,229 @@ class MongoDocumentRepositorySpec extends AbstractDocumentRepositorySpec impleme
people[1].age == 0
}

void "test custom find paginated"() {
given:
savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"])
def peopleToUpdate = personRepository.findAll().toList()
peopleToUpdate.forEach {it.age = 100 }
personRepository.updateAll(peopleToUpdate)
when:
def peoplePage = personRepository.customFindPage("J.*", Pageable.from(0, 2))
def people = peoplePage.getContent()
then:
peoplePage.hasTotalSize()
peoplePage.getTotalPages() == 3
peoplePage.pageNumber == 0
people.size() == 2
people[0].name == "Jake"
people[0].age == 0 // Projection works
people[1].name == "James"
when:
peoplePage = personRepository.customFindPage("J.*", peoplePage.nextPageable())
people = peoplePage.getContent()
then:
peoplePage.hasTotalSize()
peoplePage.getTotalPages() == 3
peoplePage.pageNumber == 1
people.size() == 2
people[0].name == "Jeff"
people[0].age == 0 // Projection works
people[1].name == "Jim"
when:
peoplePage = personRepository.customFindPage("J.*", peoplePage.nextPageable())
people = peoplePage.getContent()
then:
peoplePage.hasTotalSize()
peoplePage.getTotalPages() == 3
peoplePage.pageNumber == 2
people.size() == 1
people[0].name == "Josh"
}

void "test custom find paginated without count"() {
given:
savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"])
def peopleToUpdate = personRepository.findAll().toList()
peopleToUpdate.forEach {it.age = 100 }
personRepository.updateAll(peopleToUpdate)
when:
def peoplePage = personRepository.customFindPage("J.*", Pageable.from(0, 2).withoutTotal())
def people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 0
people.size() == 2
people[0].name == "Jake"
people[0].age == 0 // Projection works
people[1].name == "James"
when:
peoplePage = personRepository.customFindPage("J.*", peoplePage.nextPageable())
people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 1
people.size() == 2
people[0].name == "Jeff"
people[0].age == 0 // Projection works
people[1].name == "Jim"
when:
peoplePage = personRepository.customFindPage("J.*", peoplePage.nextPageable())
people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 2
people.size() == 1
people[0].name == "Josh"
}

void "test custom find paginated without count backwards"() {
given:
savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"])
def peopleToUpdate = personRepository.findAll().toList()
peopleToUpdate.forEach {it.age = 100 }
personRepository.updateAll(peopleToUpdate)
when:
def peoplePage = personRepository.customFindPage("J.*", Pageable.from(2, 2).withoutTotal())
def people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 2
people.size() == 1
people[0].name == "Josh"
people[0].age == 0 // Projection works

when:
peoplePage = personRepository.customFindPage("J.*", peoplePage.previousPageable())
people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 1
people.size() == 2
people[0].name == "Jeff"
people[0].age == 0 // Projection works
people[1].name == "Jim"
when:
peoplePage = personRepository.customFindPage("J.*", peoplePage.previousPageable())
people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 0
people.size() == 2
people[0].name == "Jake"
people[1].name == "James"
}

void "test custom aggr paginated"() {
given:
savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"])
def peopleToUpdate = personRepository.findAll().toList()
peopleToUpdate.forEach {it.age = 100 }
personRepository.updateAll(peopleToUpdate)
when:
def peoplePage = personRepository.customAggrPage("J.*", Pageable.from(0, 2))
def people = peoplePage.getContent()
then:
peoplePage.hasTotalSize()
peoplePage.getTotalPages() == 3
peoplePage.pageNumber == 0
people.size() == 2
people[0].name == "Jake"
people[0].age == 0 // Projection works
people[1].name == "James"
when:
peoplePage = personRepository.customAggrPage("J.*", peoplePage.nextPageable())
people = peoplePage.getContent()
then:
peoplePage.hasTotalSize()
peoplePage.getTotalPages() == 3
peoplePage.pageNumber == 1
people.size() == 2
people[0].name == "Jeff"
people[0].age == 0 // Projection works
people[1].name == "Jim"
when:
peoplePage = personRepository.customAggrPage("J.*", peoplePage.nextPageable())
people = peoplePage.getContent()
then:
peoplePage.hasTotalSize()
peoplePage.getTotalPages() == 3
peoplePage.pageNumber == 2
people.size() == 1
people[0].name == "Josh"
}

void "test custom aggr paginated without count"() {
given:
savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"])
def peopleToUpdate = personRepository.findAll().toList()
peopleToUpdate.forEach {it.age = 100 }
personRepository.updateAll(peopleToUpdate)
when:
def peoplePage = personRepository.customAggrPage("J.*", Pageable.from(0, 2).withoutTotal())
def people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 0
people.size() == 2
people[0].name == "Jake"
people[0].age == 0 // Projection works
people[1].name == "James"
when:
peoplePage = personRepository.customAggrPage("J.*", peoplePage.nextPageable())
people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 1
people.size() == 2
people[0].name == "Jeff"
people[0].age == 0 // Projection works
people[1].name == "Jim"
when:
peoplePage = personRepository.customAggrPage("J.*", peoplePage.nextPageable())
people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 2
people.size() == 1
people[0].name == "Josh"
}

void "test custom aggr paginated without count backwards"() {
given:
savePersons(["Dennis", "Jeff", "James", "Dennis", "Josh", "Steven", "Jake", "Jim"])
def peopleToUpdate = personRepository.findAll().toList()
peopleToUpdate.forEach {it.age = 100 }
personRepository.updateAll(peopleToUpdate)
when:
def peoplePage = personRepository.customAggrPage("J.*", Pageable.from(2, 2).withoutTotal())
def people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 2
people.size() == 1
people[0].name == "Josh"
people[0].age == 0 // Projection works
when:
peoplePage = personRepository.customAggrPage("J.*", peoplePage.previousPageable())
people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 1
people.size() == 2
people[0].name == "Jeff"
people[0].age == 0 // Projection works
people[1].name == "Jim"
when:
peoplePage = personRepository.customAggrPage("J.*", peoplePage.previousPageable())
people = peoplePage.getContent()
then:
!peoplePage.hasTotalSize()
peoplePage.pageNumber == 0
people.size() == 2
people[0].name == "Jake"
people[1].name == "James"
}

void "test custom aggr"() {
given:
savePersons(["Dennis", "Jeff", "James", "Dennis"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import io.micronaut.data.document.tck.entities.Person;
import io.micronaut.data.document.tck.repositories.PersonRepository;
import io.micronaut.data.model.Page;
import io.micronaut.data.model.Pageable;
import io.micronaut.data.mongodb.annotation.MongoAggregateOptions;
import io.micronaut.data.mongodb.annotation.MongoAggregateQuery;
import io.micronaut.data.mongodb.annotation.MongoDeleteOptions;
Expand All @@ -27,9 +29,15 @@ public interface MongoPersonRepository extends PersonRepository {
@MongoFindQuery(filter = "{name:{$regex: :t}}", sort = "{ name : 1 }", project = "{ name: 1}")
List<Person> customFind(String t);

@MongoFindQuery(filter = "{name:{$regex: :t}}", sort = "{ name : 1 }", project = "{ name: 1}")
Page<Person> customFindPage(String t, Pageable pageable);

@MongoAggregateQuery("[{$match: {name:{$regex: :t}}}, {$sort: {name: 1}}, {$project: {name: 1}}]")
List<Person> customAgg(String t);

@MongoAggregateQuery("[{$match: {name:{$regex: :t}}}, {$sort: {name: 1}}, {$project: {name: 1}}]")
Page<Person> customAggrPage(String t, Pageable pageable);

@MongoUpdateQuery(update = "{$set:{name: :newName}}", filter = "{name:{$eq: :oldName}}")
long updateNamesCustom(String newName, String oldName);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ public <E, R> StoredQuery<E, R> resolveQuery(MethodInvocationContext<?, ?> conte

@Override
public <E, R> StoredQuery<E, R> resolveCountQuery(MethodInvocationContext<?, ?> context, Class<E> entityClass, Class<R> resultType) {
String query = context.stringValue(Query.class, DataMethod.META_MEMBER_COUNT_QUERY).orElseThrow(() ->
new IllegalStateException("No query present in method")
);
String query = context.stringValue(Query.class, DataMethod.META_MEMBER_COUNT_QUERY)
.orElseGet(() -> context.stringValue(Query.class)
.orElseThrow(() -> new IllegalStateException("No query present in method")));
return new DefaultStoredQuery<>(
context.getExecutableMethod(),
resultType,
Expand Down

0 comments on commit 250a2e1

Please sign in to comment.