Skip to content

Commit

Permalink
collections: added collection id to record search
Browse files Browse the repository at this point in the history
* collections: remove query from serialization.
  • Loading branch information
alejandromumo authored and slint committed Oct 16, 2024
1 parent 8ef05a1 commit 3e74953
Show file tree
Hide file tree
Showing 11 changed files with 88 additions and 28 deletions.
16 changes: 8 additions & 8 deletions invenio_rdm_records/collections/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ def community(self):
@property
def query(self):
"""Get the collection query."""
q = ""
for _a in self.ancestors:
q += f"({_a.model.search_query}) AND "
q += f"({self.model.search_query})"

# Query must be validated because it is not being built using dsl
Collection.validate_query(q)
return q
import operator
from functools import reduce

from invenio_search.engine import dsl

queries = [dsl.Q("query_string", query=a.search_query) for a in self.ancestors]
queries.append(dsl.Q("query_string", query=self.search_query))
return reduce(operator.and_, queries)

@cached_property
def ancestors(self):
Expand Down
1 change: 1 addition & 0 deletions invenio_rdm_records/collections/links.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ class CollectionLinkstemplate(LinksTemplate):
"""Templates for generating links for a collection object."""

def __init__(self, links=None, context=None):
"""Initialize the links template."""
super().__init__(links, context)


Expand Down
5 changes: 5 additions & 0 deletions invenio_rdm_records/collections/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ def breadcrumbs(self):
)
return res

@property
def query(self):
"""Get the collection query."""
return self._collection.query


class CollectionList(ServiceListResult):
"""Collection list item."""
Expand Down
1 change: 0 additions & 1 deletion invenio_rdm_records/collections/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,4 @@ class CollectionSchema(Schema):
depth = fields.Int()
order = fields.Int()
id = fields.Int()
query = fields.Str()
num_records = fields.Int()
2 changes: 1 addition & 1 deletion invenio_rdm_records/collections/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def read(
To resolve by slug, the collection tree ID and community ID must be provided.
"""
if id_:
collection = self.collection_cls.resolve(id_, depth=depth)
collection = self.collection_cls.resolve(id_=id_, depth=depth)
elif slug and tree_slug and community_id:
ctree = CollectionTree.resolve(slug=tree_slug, community_id=community_id)
collection = self.collection_cls.resolve(
Expand Down
6 changes: 6 additions & 0 deletions invenio_rdm_records/resources/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,9 @@ class RDMSearchRequestArgsSchema(SearchRequestArgsSchema):
locale = fields.Str()
status = fields.Str()
include_deleted = fields.Bool()


class CommunityRecordsSearchRequestArgsSchema(SearchRequestArgsSchema):
"""Extend schema with collection_id field."""

collection_id = fields.Integer()
4 changes: 3 additions & 1 deletion invenio_rdm_records/resources/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
ReviewStateError,
ValidationErrorWithMessageAsList,
)
from .args import RDMSearchRequestArgsSchema
from .args import CommunityRecordsSearchRequestArgsSchema, RDMSearchRequestArgsSchema
from .deserializers import ROCrateJSONDeserializer
from .deserializers.errors import DeserializerError
from .errors import HTTPJSONException, HTTPJSONValidationWithMessageAsListException
Expand Down Expand Up @@ -558,6 +558,8 @@ class RDMCommunityRecordsResourceConfig(RecordResourceConfig, ConfiguratorMixin)
default=record_serializers,
)

request_search_args = CommunityRecordsSearchRequestArgsSchema


class RDMRecordCommunitiesResourceConfig(CommunityResourceConfig, ConfiguratorMixin):
"""Record communities resource config."""
Expand Down
10 changes: 9 additions & 1 deletion invenio_rdm_records/services/community_records/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from invenio_records_resources.services.uow import unit_of_work
from invenio_search.engine import dsl

from ...proxies import current_record_communities_service
from ...proxies import current_rdm_records, current_record_communities_service
from ...records.systemfields.deletion_status import RecordDeletionStatusEnum


Expand Down Expand Up @@ -67,6 +67,14 @@ def search(
if extra_filter is not None:
community_filter = community_filter & extra_filter

# Search in a specific collection
if "collection_id" in params:
collections_service = current_rdm_records.collections_service
collection = collections_service.read(
identity=identity, id_=params["collection_id"]
)
community_filter &= collection.query

search = self._search(
"search",
identity,
Expand Down
10 changes: 6 additions & 4 deletions tests/collections/test_collections_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# it under the terms of the MIT License; see LICENSE file for more details.
"""Test suite for the collections programmatic API."""

from invenio_search.engine import dsl

from invenio_rdm_records.collections.api import Collection, CollectionTree


Expand All @@ -25,7 +27,7 @@ def test_create(running_app, db, community, community_owner):
ctree=tree,
)

read_c = Collection.resolve(collection.id)
read_c = Collection.resolve(id_=collection.id)
assert read_c.id == collection.id
assert read_c.title == "My Collection"
assert read_c.collection_tree.id == tree.id
Expand All @@ -38,7 +40,7 @@ def test_create(running_app, db, community, community_owner):
ctree=tree.id,
)

read_c = Collection.resolve(collection.id)
read_c = Collection.resolve(id_=collection.id)
assert read_c.id == collection.id
assert collection.title == "My Collection 2"
assert collection.collection_tree.id == tree.id
Expand All @@ -61,7 +63,7 @@ def test_resolve(running_app, db, community):
)

# Read by ID
read_by_id = Collection.resolve(collection.id)
read_by_id = Collection.resolve(id_=collection.id)
assert read_by_id.id == collection.id

# Read by slug
Expand All @@ -88,7 +90,7 @@ def test_query_build(running_app, db):
slug="my-collection-2",
parent=c1,
)
assert c2.query == "(metadata.title:hello) AND (metadata.creators.name:john)"
assert c2.query == c1.query & dsl.Q("query_string", query=c2.search_query)


def test_children(running_app, db):
Expand Down
23 changes: 11 additions & 12 deletions tests/services/test_collections_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,12 @@ def test_collections_read(
c1 = collections[1]

# Read by id
res = collections_service.read(community_owner.identity, id_=c0.id)
res = collections_service.read(identity=community_owner.identity, id_=c0.id)
assert res._collection.id == c0.id

# Read by slug
res = collections_service.read(
community_owner.identity,
identity=community_owner.identity,
community_id=community.id,
tree_slug=c0.collection_tree.slug,
slug=c0.slug,
Expand Down Expand Up @@ -97,7 +97,7 @@ def test_collections_create(
assert collection.collection_tree.id == tree.id

read_collection = collections_service.read(
community_owner.identity, id_=collection.id
identity=community_owner.identity, id_=collection.id
)
assert read_collection._collection.id == collection.id
assert read_collection._collection.title == "My Collection"
Expand All @@ -123,12 +123,12 @@ def test_collections_add(
c3 = c3._collection

# Read the collection
res = collections_service.read(community_owner.identity, id_=c3.id)
res = collections_service.read(identity=community_owner.identity, id_=c3.id)
assert res._collection.id == c3.id
assert res._collection.title == "Collection 3"

# Read the parent collection
res = collections_service.read(community_owner.identity, id_=c2.id)
res = collections_service.read(identity=community_owner.identity, id_=c2.id)
assert res.to_dict()[c2.id]["children"] == [c3.id]


Expand All @@ -150,7 +150,9 @@ def test_collections_results(
query="metadata.title:baz",
)
# Read the collection tree up to depth 2
res = collections_service.read(community_owner.identity, id_=c0.id, depth=2)
res = collections_service.read(
identity=community_owner.identity, id_=c0.id, depth=2
)
r_dict = res.to_dict()

expected = {
Expand All @@ -171,7 +173,6 @@ def test_collections_results(
},
"num_records": 0,
"order": c0.order,
"query": "(metadata.title:foo)",
"slug": "collection-1",
"title": "Collection 1",
},
Expand All @@ -185,15 +186,16 @@ def test_collections_results(
},
"num_records": 0,
"order": c1.order,
"query": "(metadata.title:foo) AND (metadata.title:bar)",
"slug": "collection-2",
"title": "Collection 2",
},
}
assert not list(dictdiffer.diff(expected, r_dict))

# Read the collection tree up to depth 3
res = collections_service.read(community_owner.identity, id_=c0.id, depth=3)
res = collections_service.read(
identity=community_owner.identity, id_=c0.id, depth=3
)
r_dict = res.to_dict()

# Get the API object, just for the sake of testing
Expand All @@ -216,7 +218,6 @@ def test_collections_results(
},
"num_records": 0,
"order": c0.order,
"query": "(metadata.title:foo)",
"slug": "collection-1",
"title": "Collection 1",
},
Expand All @@ -230,7 +231,6 @@ def test_collections_results(
},
"num_records": 0,
"order": c1.order,
"query": "(metadata.title:foo) AND (metadata.title:bar)",
"slug": "collection-2",
"title": "Collection 2",
},
Expand All @@ -244,7 +244,6 @@ def test_collections_results(
},
"num_records": 0,
"order": c3.order,
"query": "(metadata.title:foo) AND (metadata.title:bar) AND (metadata.title:baz)",
"slug": "collection-3",
"title": "Collection 3",
},
Expand Down
38 changes: 38 additions & 0 deletions tests/services/test_service_community_records.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,13 @@

"""Test community records service."""

from copy import deepcopy

import pytest
from invenio_records_resources.services.errors import PermissionDeniedError
from marshmallow import ValidationError

from invenio_rdm_records.collections.api import Collection, CollectionTree
from invenio_rdm_records.proxies import (
current_community_records_service,
current_rdm_records_service,
Expand Down Expand Up @@ -156,3 +159,38 @@ def test_search_community_records(
community_id=str(community.id),
)
assert results.to_dict()["hits"]["total"] == 3


def test_search_community_records_in_collections(
community, record_community, service, uploader, minimal_record
):
"""Test search for records in a community collection."""
rec1 = deepcopy(minimal_record)
rec1["metadata"]["title"] = "Another record"
record_community.create_record(record_dict=rec1)
record_community.create_record(record_dict=minimal_record)
ctree = CollectionTree.create(
title="Tree 1",
order=10,
community_id=community.id,
slug="tree-1",
)
collection = Collection.create(
title="My Collection",
query='metadata.title:"Another record"',
slug="my-collection",
ctree=ctree,
)
all_results = service.search(
uploader.identity,
community_id=str(community.id),
)
assert all_results.total == 2

results = service.search(
uploader.identity,
community_id=str(community.id),
params={"collection_id": str(collection.id)},
)
assert results.total == 1
assert results.to_dict()["hits"]["hits"][0]["metadata"]["title"] == "Another record"

0 comments on commit 3e74953

Please sign in to comment.