From 084527b0d94221751445356fc8e68c024f7d501f Mon Sep 17 00:00:00 2001 From: JD Bothma Date: Thu, 21 Nov 2024 13:20:01 +0000 Subject: [PATCH] Remove index selection, add slim enricher base --- nomenklatura/cli.py | 4 +-- nomenklatura/enrich/common.py | 41 +++++++++++++++++------------- nomenklatura/index/__init__.py | 30 +--------------------- nomenklatura/index/common.py | 4 +-- nomenklatura/xref.py | 8 +++--- tests/index/test_index.py | 46 +++++++++++++++++++++++++++++++--- 6 files changed, 74 insertions(+), 59 deletions(-) diff --git a/nomenklatura/cli.py b/nomenklatura/cli.py index 51b2b3e8..6d320b16 100644 --- a/nomenklatura/cli.py +++ b/nomenklatura/cli.py @@ -10,7 +10,7 @@ from followthemoney.cli.aggregate import sorted_aggregate from nomenklatura.cache import Cache -from nomenklatura.index import Index, INDEX_TYPES +from nomenklatura.index import Index from nomenklatura.matching import train_v2_matcher, train_v1_matcher from nomenklatura.store import load_entity_file_store from nomenklatura.resolver import Resolver @@ -64,7 +64,6 @@ def cli() -> None: @click.option("-l", "--limit", type=click.INT, default=5000) @click.option("--algorithm", default=DefaultAlgorithm.NAME) @click.option("--scored/--unscored", is_flag=True, type=click.BOOL, default=True) -@click.option("-i", "--index", type=click.Choice(INDEX_TYPES)) @click.option( "-c", "--clear", @@ -103,7 +102,6 @@ def xref_file( algorithm=algorithm_type, scored=scored, limit=limit, - index_type=index, ) resolver_.save() log.info("Xref complete in: %s", resolver_.path) diff --git a/nomenklatura/enrich/common.py b/nomenklatura/enrich/common.py index c4523227..be20c42c 100644 --- a/nomenklatura/enrich/common.py +++ b/nomenklatura/enrich/common.py @@ -31,7 +31,7 @@ class EnrichmentAbort(Exception): pass -class Enricher(Generic[DS], ABC): +class BaseEnricher(Generic[DS]): def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): self.dataset = dataset self.cache = cache @@ -39,7 +39,6 @@ def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): self.cache_days = int(config.pop("cache_days", 90)) self._filter_schemata = config.pop("schemata", []) self._filter_topics = config.pop("topics", []) - self._session: Optional[Session] = None def get_config_expand( self, name: str, default: Optional[str] = None @@ -55,6 +54,28 @@ def get_config_int(self, name: str, default: Union[int, str]) -> int: def get_config_bool(self, name: str, default: Union[bool, str] = False) -> int: return as_bool(self.config.get(name, default)) + def _filter_entity(self, entity: CompositeEntity) -> bool: + """Check if the given entity should be filtered out. Filters + can be applied by schema or by topic.""" + if len(self._filter_schemata): + if entity.schema.name not in self._filter_schemata: + return False + _filter_topics = set(self._filter_topics) + if "all" in _filter_topics: + assert isinstance(registry.topic, TopicType) + _filter_topics.update(registry.topic.names.keys()) + if len(_filter_topics): + topics = set(entity.get_type_values(registry.topic)) + if not len(topics.intersection(_filter_topics)): + return False + return True + + +class Enricher(BaseEnricher[DS], ABC): + def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig): + super().__init__(dataset, cache, config) + self._session: Optional[Session] = None + @property def session(self) -> Session: if self._session is None: @@ -167,22 +188,6 @@ def make_entity(self, entity: CE, schema: str) -> CE: """Create a new entity of the given schema.""" return self._make_data_entity(entity, {"schema": schema}) - def _filter_entity(self, entity: CompositeEntity) -> bool: - """Check if the given entity should be filtered out. Filters - can be applied by schema or by topic.""" - if len(self._filter_schemata): - if entity.schema.name not in self._filter_schemata: - return False - _filter_topics = set(self._filter_topics) - if "all" in _filter_topics: - assert isinstance(registry.topic, TopicType) - _filter_topics.update(registry.topic.names.keys()) - if len(_filter_topics): - topics = set(entity.get_type_values(registry.topic)) - if not len(topics.intersection(_filter_topics)): - return False - return True - def match_wrapped(self, entity: CE) -> Generator[CE, None, None]: if not self._filter_entity(entity): return diff --git a/nomenklatura/index/__init__.py b/nomenklatura/index/__init__.py index bf33cfdb..0b0cf721 100644 --- a/nomenklatura/index/__init__.py +++ b/nomenklatura/index/__init__.py @@ -1,33 +1,5 @@ -import logging -from pathlib import Path -from typing import Type, Optional - from nomenklatura.index.index import Index from nomenklatura.index.common import BaseIndex -from nomenklatura.store import View -from nomenklatura.dataset import DS -from nomenklatura.entity import CE - -log = logging.getLogger(__name__) -INDEX_TYPES = ["tantivy", Index.name] - - -def get_index( - view: View[DS, CE], path: Path, type_: Optional[str] -) -> BaseIndex[DS, CE]: - """Get the best available index class to use.""" - clazz: Type[BaseIndex[DS, CE]] = Index[DS, CE] - if type_ == "tantivy": - try: - from nomenklatura.index.tantivy_index import TantivyIndex - - clazz = TantivyIndex[DS, CE] - except ImportError: - log.warning("`tantivy` is not available, falling back to in-memory index.") - - index = clazz(view, path) - index.build() - return index -__all__ = ["BaseIndex", "Index", "TantivyIndex", "get_index"] +__all__ = ["BaseIndex", "Index"] diff --git a/nomenklatura/index/common.py b/nomenklatura/index/common.py index 8cc75d49..b630a708 100644 --- a/nomenklatura/index/common.py +++ b/nomenklatura/index/common.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Generic, List, Tuple +from typing import Generic, Iterable, List, Tuple from nomenklatura.resolver import Identifier from nomenklatura.dataset import DS from nomenklatura.entity import CE @@ -18,7 +18,7 @@ def build(self) -> None: def pairs( self, max_pairs: int = MAX_PAIRS - ) -> List[Tuple[Tuple[Identifier, Identifier], float]]: + ) -> Iterable[Tuple[Tuple[Identifier, Identifier], float]]: raise NotImplementedError def match(self, entity: CE) -> List[Tuple[Identifier, float]]: diff --git a/nomenklatura/xref.py b/nomenklatura/xref.py index 06fdfb13..bc22ed5d 100644 --- a/nomenklatura/xref.py +++ b/nomenklatura/xref.py @@ -3,12 +3,13 @@ from followthemoney.schema import Schema from pathlib import Path +from nomenklatura import Index from nomenklatura.dataset import DS from nomenklatura.entity import CE from nomenklatura.store import Store from nomenklatura.judgement import Judgement from nomenklatura.resolver import Resolver -from nomenklatura.index import get_index +from nomenklatura.index import BaseIndex from nomenklatura.matching import DefaultAlgorithm, ScoringAlgorithm from nomenklatura.conflicting_match import ConflictingMatchReporter @@ -31,6 +32,7 @@ def xref( resolver: Resolver[CE], store: Store[DS, CE], index_dir: Path, + index_type: Type[BaseIndex[DS, CE]] = Index, limit: int = 5000, limit_factor: int = 10, scored: bool = True, @@ -41,12 +43,12 @@ def xref( conflicting_match_threshold: Optional[float] = None, focus_dataset: Optional[str] = None, algorithm: Type[ScoringAlgorithm] = DefaultAlgorithm, - index_type: Optional[str] = None, user: Optional[str] = None, ) -> None: log.info("Begin xref: %r, resolver: %s", store, resolver) view = store.default_view(external=external) - index = get_index(view, index_dir, index_type) + index = index_type(view, index_dir) + index.build() conflict_reporter = None if conflicting_match_threshold is not None: conflict_reporter = ConflictingMatchReporter( diff --git a/tests/index/test_index.py b/tests/index/test_index.py index dc5d6913..0f218106 100644 --- a/tests/index/test_index.py +++ b/tests/index/test_index.py @@ -46,7 +46,11 @@ def test_index_persist(dstore: SimpleMemoryStore, dindex): def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index): view = dstore.default_view() pairs = dindex.pairs() - assert len(pairs) > 0, pairs + + # At least one pair is found + assert len(pairs) > 0, len(pairs) + + # A pair has tokens which overlap tokenizer = dindex.tokenizer pair, score = pairs[0] entity0 = view.get_entity(str(pair[0])) @@ -55,10 +59,44 @@ def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index): tokens1 = set(tokenizer.entity(entity1)) overlap = tokens0.intersection(tokens1) assert len(overlap) > 0, overlap - # assert "Schnabel" in (overlap, tokens0, tokens1) - # assert "Schnabel" in (entity0.caption, entity1.caption) + + # A pair has non-zero score assert score > 0 - # assert False + # pairs are in descending score order + last_score = pairs[0][1] + for pair in pairs[1:]: + assert pair[1] <= last_score + last_score = pair[1] + + # Johanna Quandt <> Frau Johanna Quandt + jq = ( + Identifier.get("9add84cbb7bb48c7552f8ec7ae54de54eed1e361"), + Identifier.get("2d3e50433e36ebe16f3d906b684c9d5124c46d76"), + ) + jq_score = [score for pair, score in pairs if jq == pair][0] + + # Bayerische Motorenwerke AG <> Bayerische Motorenwerke (BMW) AG + bmw = ( + Identifier.get("21cc81bf3b960d2847b66c6c862e7aa9b5e4f487"), + Identifier.get("12570ee94b8dc23bcc080e887539d3742b2a5237"), + ) + bmw_score = [score for pair, score in pairs if bmw == pair][0] + + # More tokens in BMW means lower TF, reducing the score + assert jq_score > bmw_score, (jq_score, bmw_score) + assert jq_score == 19.0, jq_score + assert 3.3 < bmw_score < 3.4, bmw_score + + # FERRING Arzneimittel GmbH <> Clou Container Leasing GmbH + false_pos = ( + Identifier.get("f8867c433ba247cfab74096c73f6ff5e36db3ffe"), + Identifier.get("a061e760dfcf0d5c774fc37c74937193704807b5"), + ) + false_pos_score = [score for pair, score in pairs if false_pos == pair][0] + assert 1.1 < false_pos_score < 1.2, false_pos_score + assert bmw_score > false_pos_score, (bmw_score, false_pos_score) + + assert len(pairs) == 428, len(pairs) def test_match_score(dstore: SimpleMemoryStore, dindex: Index):