Skip to content

Commit

Permalink
Remove index selection, add slim enricher base
Browse files Browse the repository at this point in the history
  • Loading branch information
jbothma committed Nov 22, 2024
1 parent 2a55e85 commit 084527b
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 59 deletions.
4 changes: 1 addition & 3 deletions nomenklatura/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from followthemoney.cli.aggregate import sorted_aggregate

from nomenklatura.cache import Cache
from nomenklatura.index import Index, INDEX_TYPES
from nomenklatura.index import Index
from nomenklatura.matching import train_v2_matcher, train_v1_matcher
from nomenklatura.store import load_entity_file_store
from nomenklatura.resolver import Resolver
Expand Down Expand Up @@ -64,7 +64,6 @@ def cli() -> None:
@click.option("-l", "--limit", type=click.INT, default=5000)
@click.option("--algorithm", default=DefaultAlgorithm.NAME)
@click.option("--scored/--unscored", is_flag=True, type=click.BOOL, default=True)
@click.option("-i", "--index", type=click.Choice(INDEX_TYPES))
@click.option(
"-c",
"--clear",
Expand Down Expand Up @@ -103,7 +102,6 @@ def xref_file(
algorithm=algorithm_type,
scored=scored,
limit=limit,
index_type=index,
)
resolver_.save()
log.info("Xref complete in: %s", resolver_.path)
Expand Down
41 changes: 23 additions & 18 deletions nomenklatura/enrich/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,14 @@ class EnrichmentAbort(Exception):
pass


class Enricher(Generic[DS], ABC):
class BaseEnricher(Generic[DS]):
def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig):
self.dataset = dataset
self.cache = cache
self.config = config
self.cache_days = int(config.pop("cache_days", 90))
self._filter_schemata = config.pop("schemata", [])
self._filter_topics = config.pop("topics", [])
self._session: Optional[Session] = None

def get_config_expand(
self, name: str, default: Optional[str] = None
Expand All @@ -55,6 +54,28 @@ def get_config_int(self, name: str, default: Union[int, str]) -> int:
def get_config_bool(self, name: str, default: Union[bool, str] = False) -> int:
# Look up `name` in the enricher config (falling back to `default`) and
# pass the raw value through as_bool.
# NOTE(review): the return annotation is `int`, but the value comes straight
# from as_bool — presumably a bool; confirm and tighten the annotation.
return as_bool(self.config.get(name, default))

def _filter_entity(self, entity: CompositeEntity) -> bool:
"""Check whether the given entity passes the configured filters.

Returns True if the entity should be processed and False if it is
excluded. Filters can be applied by schema or by topic; an empty filter
places no restriction on that dimension."""
# Schema filter: only entities whose schema name is listed are kept.
if len(self._filter_schemata):
if entity.schema.name not in self._filter_schemata:
return False
# Work on a fresh set so the configured topic filter is never mutated.
_filter_topics = set(self._filter_topics)
# The special value "all" extends the filter with every key of
# registry.topic.names.
if "all" in _filter_topics:
assert isinstance(registry.topic, TopicType)
_filter_topics.update(registry.topic.names.keys())
# Topic filter: the entity must share at least one value with the filter.
if len(_filter_topics):
topics = set(entity.get_type_values(registry.topic))
if not len(topics.intersection(_filter_topics)):
return False
return True


class Enricher(BaseEnricher[DS], ABC):
def __init__(self, dataset: DS, cache: Cache, config: EnricherConfig):
super().__init__(dataset, cache, config)
self._session: Optional[Session] = None

@property
def session(self) -> Session:
if self._session is None:
Expand Down Expand Up @@ -167,22 +188,6 @@ def make_entity(self, entity: CE, schema: str) -> CE:
"""Create a new entity of the given schema."""
return self._make_data_entity(entity, {"schema": schema})

def _filter_entity(self, entity: CompositeEntity) -> bool:
"""Check whether the given entity passes the configured filters.

Returns True if the entity should be processed and False if it is
excluded. Filters can be applied by schema or by topic; an empty filter
places no restriction on that dimension."""
# Schema filter: only entities whose schema name is listed are kept.
if len(self._filter_schemata):
if entity.schema.name not in self._filter_schemata:
return False
# Work on a fresh set so the configured topic filter is never mutated.
_filter_topics = set(self._filter_topics)
# The special value "all" extends the filter with every key of
# registry.topic.names.
if "all" in _filter_topics:
assert isinstance(registry.topic, TopicType)
_filter_topics.update(registry.topic.names.keys())
# Topic filter: the entity must share at least one value with the filter.
if len(_filter_topics):
topics = set(entity.get_type_values(registry.topic))
if not len(topics.intersection(_filter_topics)):
return False
return True

def match_wrapped(self, entity: CE) -> Generator[CE, None, None]:
if not self._filter_entity(entity):
return
Expand Down
30 changes: 1 addition & 29 deletions nomenklatura/index/__init__.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,5 @@
import logging
from pathlib import Path
from typing import Type, Optional

from nomenklatura.index.index import Index
from nomenklatura.index.common import BaseIndex
from nomenklatura.store import View
from nomenklatura.dataset import DS
from nomenklatura.entity import CE

log = logging.getLogger(__name__)
INDEX_TYPES = ["tantivy", Index.name]


def get_index(
view: View[DS, CE], path: Path, type_: Optional[str]
) -> BaseIndex[DS, CE]:
"""Build and return an index over the given view.

The in-memory `Index` is used unless `type_` is "tantivy" and
`TantivyIndex` can be imported; if that import fails, a warning is logged
and the in-memory index is used instead. The chosen class is constructed
with `view` and `path`, and `build()` is called before it is returned."""
# Default to the in-memory implementation.
clazz: Type[BaseIndex[DS, CE]] = Index[DS, CE]
if type_ == "tantivy":
try:
# Imported here so a missing `tantivy` dependency can be caught below.
from nomenklatura.index.tantivy_index import TantivyIndex

clazz = TantivyIndex[DS, CE]
except ImportError:
log.warning("`tantivy` is not available, falling back to in-memory index.")

index = clazz(view, path)
index.build()
return index


__all__ = ["BaseIndex", "Index", "TantivyIndex", "get_index"]
__all__ = ["BaseIndex", "Index"]
4 changes: 2 additions & 2 deletions nomenklatura/index/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Generic, List, Tuple
from typing import Generic, Iterable, List, Tuple
from nomenklatura.resolver import Identifier
from nomenklatura.dataset import DS
from nomenklatura.entity import CE
Expand All @@ -18,7 +18,7 @@ def build(self) -> None:

def pairs(
self, max_pairs: int = MAX_PAIRS
) -> List[Tuple[Tuple[Identifier, Identifier], float]]:
) -> Iterable[Tuple[Tuple[Identifier, Identifier], float]]:
raise NotImplementedError

def match(self, entity: CE) -> List[Tuple[Identifier, float]]:
Expand Down
8 changes: 5 additions & 3 deletions nomenklatura/xref.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from followthemoney.schema import Schema
from pathlib import Path

from nomenklatura import Index
from nomenklatura.dataset import DS
from nomenklatura.entity import CE
from nomenklatura.store import Store
from nomenklatura.judgement import Judgement
from nomenklatura.resolver import Resolver
from nomenklatura.index import get_index
from nomenklatura.index import BaseIndex
from nomenklatura.matching import DefaultAlgorithm, ScoringAlgorithm
from nomenklatura.conflicting_match import ConflictingMatchReporter

Expand All @@ -31,6 +32,7 @@ def xref(
resolver: Resolver[CE],
store: Store[DS, CE],
index_dir: Path,
index_type: Type[BaseIndex[DS, CE]] = Index,
limit: int = 5000,
limit_factor: int = 10,
scored: bool = True,
Expand All @@ -41,12 +43,12 @@ def xref(
conflicting_match_threshold: Optional[float] = None,
focus_dataset: Optional[str] = None,
algorithm: Type[ScoringAlgorithm] = DefaultAlgorithm,
index_type: Optional[str] = None,
user: Optional[str] = None,
) -> None:
log.info("Begin xref: %r, resolver: %s", store, resolver)
view = store.default_view(external=external)
index = get_index(view, index_dir, index_type)
index = index_type(view, index_dir)
index.build()
conflict_reporter = None
if conflicting_match_threshold is not None:
conflict_reporter = ConflictingMatchReporter(
Expand Down
46 changes: 42 additions & 4 deletions tests/index/test_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,11 @@ def test_index_persist(dstore: SimpleMemoryStore, dindex):
def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index):
view = dstore.default_view()
pairs = dindex.pairs()
assert len(pairs) > 0, pairs

# At least one pair is found
assert len(pairs) > 0, len(pairs)

# A pair has tokens which overlap
tokenizer = dindex.tokenizer
pair, score = pairs[0]
entity0 = view.get_entity(str(pair[0]))
Expand All @@ -55,10 +59,44 @@ def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index):
tokens1 = set(tokenizer.entity(entity1))
overlap = tokens0.intersection(tokens1)
assert len(overlap) > 0, overlap
# assert "Schnabel" in (overlap, tokens0, tokens1)
# assert "Schnabel" in (entity0.caption, entity1.caption)

# A pair has non-zero score
assert score > 0
# assert False
# pairs are in descending score order
last_score = pairs[0][1]
for pair in pairs[1:]:
assert pair[1] <= last_score
last_score = pair[1]

# Johanna Quandt <> Frau Johanna Quandt
jq = (
Identifier.get("9add84cbb7bb48c7552f8ec7ae54de54eed1e361"),
Identifier.get("2d3e50433e36ebe16f3d906b684c9d5124c46d76"),
)
jq_score = [score for pair, score in pairs if jq == pair][0]

# Bayerische Motorenwerke AG <> Bayerische Motorenwerke (BMW) AG
bmw = (
Identifier.get("21cc81bf3b960d2847b66c6c862e7aa9b5e4f487"),
Identifier.get("12570ee94b8dc23bcc080e887539d3742b2a5237"),
)
bmw_score = [score for pair, score in pairs if bmw == pair][0]

# More tokens in BMW means lower TF, reducing the score
assert jq_score > bmw_score, (jq_score, bmw_score)
assert jq_score == 19.0, jq_score
assert 3.3 < bmw_score < 3.4, bmw_score

# FERRING Arzneimittel GmbH <> Clou Container Leasing GmbH
false_pos = (
Identifier.get("f8867c433ba247cfab74096c73f6ff5e36db3ffe"),
Identifier.get("a061e760dfcf0d5c774fc37c74937193704807b5"),
)
false_pos_score = [score for pair, score in pairs if false_pos == pair][0]
assert 1.1 < false_pos_score < 1.2, false_pos_score
assert bmw_score > false_pos_score, (bmw_score, false_pos_score)

assert len(pairs) == 428, len(pairs)


def test_match_score(dstore: SimpleMemoryStore, dindex: Index):
Expand Down

0 comments on commit 084527b

Please sign in to comment.