diff --git a/nomenklatura/cli.py b/nomenklatura/cli.py
index 8b6f622d..6d320b16 100644
--- a/nomenklatura/cli.py
+++ b/nomenklatura/cli.py
@@ -10,7 +10,7 @@
 from followthemoney.cli.aggregate import sorted_aggregate
 
 from nomenklatura.cache import Cache
-from nomenklatura.index import MEMORY_INDEX_PATH
+from nomenklatura.index import Index
 from nomenklatura.matching import train_v2_matcher, train_v1_matcher
 from nomenklatura.store import load_entity_file_store
 from nomenklatura.resolver import Resolver
@@ -64,7 +64,6 @@ def cli() -> None:
 @click.option("-l", "--limit", type=click.INT, default=5000)
 @click.option("--algorithm", default=DefaultAlgorithm.NAME)
 @click.option("--scored/--unscored", is_flag=True, type=click.BOOL, default=True)
-@click.option("-i", "--index", type=click.STRING)
 @click.option(
     "-c",
     "--clear",
@@ -79,7 +78,7 @@ def xref_file(
     algorithm: str = DefaultAlgorithm.NAME,
     limit: int = 5000,
     scored: bool = True,
-    index: str = MEMORY_INDEX_PATH,
+    index: str = Index.name,
     clear: bool = False,
 ) -> None:
     resolver_ = _get_resolver(path, resolver)
@@ -103,7 +102,6 @@ def xref_file(
         algorithm=algorithm_type,
         scored=scored,
         limit=limit,
-        index_path=index,
     )
     resolver_.save()
     log.info("Xref complete in: %s", resolver_.path)
diff --git a/nomenklatura/index/__init__.py b/nomenklatura/index/__init__.py
index 0973ef6d..0b0cf721 100644
--- a/nomenklatura/index/__init__.py
+++ b/nomenklatura/index/__init__.py
@@ -1,40 +1,5 @@
-from importlib import import_module
-import logging
-from pathlib import Path
-from typing import Type, Optional
-
 from nomenklatura.index.index import Index
 from nomenklatura.index.common import BaseIndex
-from nomenklatura.store import View
-from nomenklatura.dataset import DS
-from nomenklatura.entity import CE
-
-log = logging.getLogger(__name__)
-
-
-MEMORY_INDEX_PATH = "nomenklatura.index.index.Index"
-
-
-def get_index(
-    view: View[DS, CE], path: Path, class_path: Optional[str]
-) -> BaseIndex[DS, CE]:
-    """Get the best available index class to use."""
-    clazz: Type[BaseIndex[DS, CE]] = Index[DS, CE]
-    if class_path is not None:
-        try:
-            module_path, class_name = class_path.rsplit(".", 1)
-            module = import_module(module_path)
-            clazz_ref = getattr(module, class_name)
-
-            clazz = clazz_ref[DS, CE]
-        except ImportError:
-            log.warning(
-                f"`{class_path}` is not available, falling back to in-memory index."
-            )
-
-    index = clazz(view, path)
-    index.build()
-    return index
 
 
-__all__ = ["BaseIndex", "Index", "MEMORY_INDEX_PATH", "get_index"]
+__all__ = ["BaseIndex", "Index"]
diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py
deleted file mode 100644
index 3487e002..00000000
--- a/nomenklatura/index/duckdb_index.py
+++ /dev/null
@@ -1,240 +0,0 @@
-from io import TextIOWrapper
-from followthemoney.types import registry
-from pathlib import Path
-from shutil import rmtree
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
-import csv
-import duckdb
-import logging
-
-from nomenklatura.dataset import DS
-from nomenklatura.entity import CE
-from nomenklatura.index.common import BaseIndex
-from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer
-from nomenklatura.resolver import Pair, Identifier
-from nomenklatura.store import View
-
-BlockingMatches = List[Tuple[Identifier, float]]
-
-log = logging.getLogger(__name__)
-
-BATCH_SIZE = 1000
-
-
-class DuckDBIndex(BaseIndex[DS, CE]):
-    """
-    An index using DuckDB for token matching and scoring, keeping data in memory
-    until it needs to spill to disk as it approaches the configured memory limit.
-
-    Pairs match if they share one or more tokens. A basic similarity score is calculated
-    cumulatively based on each token's Term Frequency (TF) and the field's boost factor.
-    """
-
-    BOOSTS = {
-        NAME_PART_FIELD: 2.0,
-        WORD_FIELD: 0.5,
-        registry.name.name: 10.0,
-        # registry.country.name: 1.5,
-        # registry.date.name: 1.5,
-        # registry.language: 0.7,
-        # registry.iban.name: 3.0,
-        registry.phone.name: 3.0,
-        registry.email.name: 3.0,
-        # registry.entity: 0.0,
-        # registry.topic: 2.1,
-        registry.address.name: 2.5,
-        registry.identifier.name: 3.0,
-    }
-
-    __slots__ = "view", "fields", "tokenizer", "entities"
-
-    def __init__(
-        self, view: View[DS, CE], data_dir: Path, options: Dict[str, Any] = {}
-    ):
-        self.view = view
-        memory_budget = options.get("memory_budget", None)
-        self.memory_budget: Optional[int] = (
-            int(memory_budget) if memory_budget else None
-        )
-        """Memory budget in megabytes"""
-        self.max_candidates = int(options.get("max_candidates", 50))
-        self.tokenizer = Tokenizer[DS, CE]()
-        self.data_dir = data_dir
-        if self.data_dir.exists():
-            rmtree(self.data_dir)
-        self.data_dir.mkdir(parents=True)
-        self.con = duckdb.connect((self.data_dir / "duckdb_index.db").as_posix())
-        self.matching_path = self.data_dir / "matching.csv"
-        self.matching_path.unlink(missing_ok=True)
-        self.matching_dump: TextIOWrapper | None = open(self.matching_path, "w")
-        writer = csv.writer(self.matching_dump)
-        writer.writerow(["id", "field", "token"])
-
-        # https://duckdb.org/docs/guides/performance/environment
-        # > For ideal performance,
-        # > aggregation-heavy workloads require approx. 5 GB memory per thread and
-        # > join-heavy workloads require approximately 10 GB memory per thread.
-        # > Aim for 5-10 GB memory per thread.
-        if self.memory_budget is not None:
-            self.con.execute("SET memory_limit = ?;", [f"{self.memory_budget}MB"])
-            # > If you have a limited amount of memory, try to limit the number of threads
-            self.con.execute("SET threads = 1;")
-
-    def build(self) -> None:
-        """Index all entities in the dataset."""
-        log.info("Building index from: %r...", self.view)
-        self.con.execute("CREATE TABLE boosts (field TEXT, boost FLOAT)")
-        for field, boost in self.BOOSTS.items():
-            self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost])
-
-        self.con.execute("CREATE TABLE matching (id TEXT, field TEXT, token TEXT)")
-        self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)")
-        csv_path = self.data_dir / "mentions.csv"
-        log.info("Dumping entity tokens to CSV for bulk load into the database...")
-        with open(csv_path, "w") as fh:
-            writer = csv.writer(fh)
-
-            # csv.writer type gymnastics
-            def dump_entity(entity: CE) -> None:
-                if not entity.schema.matchable or entity.id is None:
-                    return
-                for field, token in self.tokenizer.entity(entity):
-                    writer.writerow([entity.id, field, token])
-            writer.writerow(["id", "field", "token"])
-
-            for idx, entity in enumerate(self.view.entities()):
-                dump_entity(entity)
-                if idx % 50000 == 0:
-                    log.info("Dumped %s entities" % idx)
-        log.info("Loading data...")
-        self.con.execute(f"COPY entries from '{csv_path}'")
-        log.info("Done.")
-
-        self._build_frequencies()
-        log.info("Index built.")
-
-    def _build_field_len(self) -> None:
-        self._build_stopwords()
-        log.info("Calculating field lengths...")
-        field_len_query = """
-            CREATE TABLE IF NOT EXISTS field_len as
-            SELECT entries.field, entries.id, count(*) as field_len from entries
-            LEFT OUTER JOIN stopwords
-            ON stopwords.field = entries.field AND stopwords.token = entries.token
-            WHERE token_freq is NULL
-            GROUP BY entries.field, entries.id
-        """
-        self.con.execute(field_len_query)
-
-    def _build_mentions(self) -> None:
-        self._build_stopwords()
-        log.info("Calculating mention counts...")
-        mentions_query = """
-            CREATE TABLE IF NOT EXISTS mentions as
-            SELECT entries.field, entries.id, entries.token, count(*) as mentions
-            FROM entries
-            LEFT OUTER JOIN stopwords
-            ON stopwords.field = entries.field AND stopwords.token = entries.token
-            WHERE token_freq is NULL
-            GROUP BY entries.field, entries.id, entries.token
-        """
-        self.con.execute(mentions_query)
-
-    def _build_stopwords(self) -> None:
-        token_freq_query = """
-            SELECT field, token, count(*) as token_freq
-            FROM entries
-            GROUP BY field, token
-        """
-        token_freq = self.con.sql(token_freq_query)  # noqa
-        self.con.execute(
-            """
-            CREATE TABLE IF NOT EXISTS stopwords as
-            SELECT * FROM token_freq where token_freq > 100
-            """
-        )
-
-    def _build_frequencies(self) -> None:
-        self._build_field_len()
-        self._build_mentions()
-        log.info("Calculating term frequencies...")
-        term_frequencies_query = """
-            CREATE TABLE IF NOT EXISTS term_frequencies as
-            SELECT mentions.field, mentions.token, mentions.id, mentions/field_len as tf
-            FROM field_len
-            JOIN mentions
-            ON field_len.field = mentions.field AND field_len.id = mentions.id
-        """
-        self.con.execute(term_frequencies_query)
-
-    def pairs(
-        self, max_pairs: int = BaseIndex.MAX_PAIRS
-    ) -> Iterable[Tuple[Pair, float]]:
-        pairs_query = """
-            SELECT "left".id, "right".id, sum(("left".tf + "right".tf) * ifnull(boost, 1)) as score
-            FROM term_frequencies as "left"
-            JOIN term_frequencies as "right"
-            ON "left".field = "right".field AND "left".token = "right".token
-            LEFT OUTER JOIN boosts
-            ON "left".field = boosts.field
-            AND "left".id > "right".id
-            GROUP BY "left".id, "right".id
-            ORDER BY score DESC
-            LIMIT ?
-        """
-        results = self.con.execute(pairs_query, [max_pairs])
-        while batch := results.fetchmany(BATCH_SIZE):
-            for left, right, score in batch:
-                yield (Identifier.get(left), Identifier.get(right)), score
-
-    def add_matching_subject(self, entity: CE) -> None:
-        if self.matching_dump is None:
-            raise Exception("Cannot add matching subject after getting candidates.")
-        writer = csv.writer(self.matching_dump)
-        for field, token in self.tokenizer.entity(entity):
-            writer.writerow([entity.id, field, token])
-
-    def matches(
-        self,
-    ) -> Generator[Tuple[Identifier, BlockingMatches], None, None]:
-        if self.matching_dump is not None:
-            self.matching_dump.close()
-            self.matching_dump = None
-            log.info("Loading matching subjects...")
-            self.con.execute(f"COPY matching from '{self.matching_path}'")
-            log.info("Finished loading matching subjects.")
-
-        match_query = """
-            SELECT matching.id, matches.id, sum(matches.tf * ifnull(boost, 1)) as score
-            FROM term_frequencies as matches
-            JOIN matching
-            ON matches.field = matching.field AND matches.token = matching.token
-            LEFT OUTER JOIN boosts
-            ON matches.field = boosts.field
-            GROUP BY matches.id, matching.id
-            ORDER BY matching.id, score DESC
-        """
-        results = self.con.execute(match_query)
-        previous_id = None
-        matches: BlockingMatches = []
-        while batch := results.fetchmany(BATCH_SIZE):
-            for matching_id, match_id, score in batch:
-                # first row
-                if previous_id is None:
-                    previous_id = matching_id
-                # Next pair of subject and candidates
-                if matching_id != previous_id:
-                    if matches:
-                        yield Identifier.get(previous_id), matches
-                    matches = []
-                    previous_id = matching_id
-                matches.append((Identifier.get(match_id), score))
-        # Last pair or subject and candidates
-        if matches and previous_id is not None:
-            yield Identifier.get(previous_id), matches[: self.max_candidates]
-
-    def __repr__(self) -> str:
-        return "<DuckDBIndex(%r, %r)>" % (
-            self.view.scope.name,
-            self.con,
-        )
diff --git a/nomenklatura/xref.py b/nomenklatura/xref.py
index f8901e23..bc22ed5d 100644
--- a/nomenklatura/xref.py
+++ b/nomenklatura/xref.py
@@ -3,12 +3,13 @@
 from followthemoney.schema import Schema
 from pathlib import Path
 
+from nomenklatura import Index
 from nomenklatura.dataset import DS
 from nomenklatura.entity import CE
 from nomenklatura.store import Store
 from nomenklatura.judgement import Judgement
 from nomenklatura.resolver import Resolver
-from nomenklatura.index import get_index
+from nomenklatura.index import BaseIndex
 from nomenklatura.matching import DefaultAlgorithm, ScoringAlgorithm
 from nomenklatura.conflicting_match import ConflictingMatchReporter
 
@@ -31,6 +32,7 @@ def xref(
     resolver: Resolver[CE],
     store: Store[DS, CE],
     index_dir: Path,
+    index_type: Type[BaseIndex[DS, CE]] = Index,
     limit: int = 5000,
     limit_factor: int = 10,
     scored: bool = True,
@@ -41,12 +43,12 @@ def xref(
     conflicting_match_threshold: Optional[float] = None,
     focus_dataset: Optional[str] = None,
     algorithm: Type[ScoringAlgorithm] = DefaultAlgorithm,
-    index_path: Optional[str] = None,
     user: Optional[str] = None,
 ) -> None:
     log.info("Begin xref: %r, resolver: %s", store, resolver)
     view = store.default_view(external=external)
-    index = get_index(view, index_dir, index_path)
+    index = index_type(view, index_dir)
+    index.build()
     conflict_reporter = None
     if conflicting_match_threshold is not None:
         conflict_reporter = ConflictingMatchReporter(
diff --git a/tests/conftest.py b/tests/conftest.py
index c620315c..ee28af17 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -6,7 +6,6 @@
 from tempfile import mkdtemp
 
 from nomenklatura import settings
-from nomenklatura.index.duckdb_index import DuckDBIndex
 from nomenklatura.index.tantivy_index import TantivyIndex
 from nomenklatura.store import load_entity_file_store, SimpleMemoryStore
 from nomenklatura.kv import get_redis
@@ -82,13 +81,6 @@ def tantivy_index(index_path: Path, dstore: SimpleMemoryStore):
     yield index
 
 
-@pytest.fixture(scope="function")
-def duckdb_index(index_path: Path, dstore: SimpleMemoryStore):
-    index = DuckDBIndex(dstore.default_view(), index_path)
-    index.build()
-    yield index
-
-
 @pytest.fixture(scope="function")
 def index_path():
     index_path = Path(mkdtemp()) / "index-dir"
diff --git a/tests/index/test_duckdb_index.py b/tests/index/test_duckdb_index.py
deleted file mode 100644
index ed4bf9cb..00000000
--- a/tests/index/test_duckdb_index.py
+++ /dev/null
@@ -1,150 +0,0 @@
-from collections import defaultdict
-from pathlib import Path
-
-from nomenklatura.dataset import Dataset
-from nomenklatura.entity import CompositeEntity
-from nomenklatura.index import get_index
-from nomenklatura.index.duckdb_index import DuckDBIndex
-from nomenklatura.resolver.identifier import Identifier
-from nomenklatura.store import SimpleMemoryStore
-
-DAIMLER = "66ce9f62af8c7d329506da41cb7c36ba058b3d28"
-VERBAND_ID = "62ad0fe6f56dbbf6fee57ce3da76e88c437024d5"
-VERBAND_BADEN_ID = "69401823a9f0a97cfdc37afa7c3158374e007669"
-VERBAND_BADEN_DATA = {
-    "id": "bla",
-    "schema": "Company",
-    "properties": {
-        "name": ["VERBAND DER METALL UND ELEKTROINDUSTRIE BADEN WURTTEMBERG"]
-    },
-}
-
-
-def test_import(dstore: SimpleMemoryStore, index_path: Path):
-    view = dstore.default_view()
-    index = get_index(view, index_path, "nomenklatura.index.duckdb_index.DuckDBIndex")
-    assert isinstance(index, DuckDBIndex), type(index)
-
-
-def test_field_lengths(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
-    field_names = set()
-    ids = set()
-
-    field_len_rel = duckdb_index.con.sql("SELECT * FROM field_len")
-    for field_name, id, field_len in field_len_rel.fetchall():
-        field_names.add(field_name)
-        ids.add(id)
-
-    # Expect to see all matchable entities
-    # jq .schema tests/fixtures/donations.ijson | sort | uniq -c
-    #   Organizations 17
-    #   Companies 56
-    #   Persons 22
-    #   Addresses 89
-    assert len(ids) == 184, len(ids)
-
-    # Expect to see all index fields for the matchable prop types and any applicable synthetic fields
-    # jq '.properties | keys | .[]' tests/fixtures/donations.ijson --raw-output|sort -u
-    expected_fields = {
-        "namepart",
-        "name",
-        "country",
-        "word",
-    }
-    assert field_names == expected_fields, field_names
-
-
-def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
-    ids = set()
-    field_tokens = defaultdict(set)
-
-    mentions_rel = duckdb_index.con.sql("SELECT * FROM mentions")
-    for field_name, id, token, count in mentions_rel.fetchall():
-        ids.add(id)
-        field_tokens[field_name].add(token)
-
-    assert len(ids) == 184, len(ids)
-    assert "verband" in field_tokens["namepart"], field_tokens["namepart"]
-    assert "gb" in field_tokens["country"], field_tokens["country"]
-    assert "adolf wurth gmbh" in field_tokens["name"], field_tokens["name"]
-    assert "dortmund" in field_tokens["word"], field_tokens["word"]
-
-
-def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
-    view = dstore.default_view()
-    pairs = list(duckdb_index.pairs())
-
-    # At least one pair is found
-    assert len(pairs) > 0, len(pairs)
-
-    # A pair has tokens which overlap
-    tokenizer = duckdb_index.tokenizer
-    pair, score = pairs[0]
-    entity0 = view.get_entity(str(pair[0]))
-    tokens0 = set(tokenizer.entity(entity0))
-    entity1 = view.get_entity(str(pair[1]))
-    tokens1 = set(tokenizer.entity(entity1))
-    overlap = tokens0.intersection(tokens1)
-    assert len(overlap) > 0, overlap
-
-    # A pair has non-zero score
-    assert score > 0
-
-    # pairs are in descending score order
-    last_score = pairs[0][1]
-    for pair in pairs[1:]:
-        assert pair[1] <= last_score
-        last_score = pair[1]
-
-    # Johanna Quandt <> Frau Johanna Quandt
-    jq = (
-        Identifier.get("9add84cbb7bb48c7552f8ec7ae54de54eed1e361"),
-        Identifier.get("2d3e50433e36ebe16f3d906b684c9d5124c46d76"),
-    )
-    jq_score = [score for pair, score in pairs if jq == pair][0]
-
-    # Bayerische Motorenwerke AG <> Bayerische Motorenwerke (BMW) AG
-    bmw = (
-        Identifier.get("21cc81bf3b960d2847b66c6c862e7aa9b5e4f487"),
-        Identifier.get("12570ee94b8dc23bcc080e887539d3742b2a5237"),
-    )
-    bmw_score = [score for pair, score in pairs if bmw == pair][0]
-
-    # More tokens in BMW means lower TF, reducing the score
-    assert jq_score > bmw_score, (jq_score, bmw_score)
-    assert jq_score == 19.0, jq_score
-    assert 3.3 < bmw_score < 3.4, bmw_score
-
-    # FERRING Arzneimittel GmbH <> Clou Container Leasing GmbH
-    false_pos = (
-        Identifier.get("f8867c433ba247cfab74096c73f6ff5e36db3ffe"),
-        Identifier.get("a061e760dfcf0d5c774fc37c74937193704807b5"),
-    )
-    false_pos_score = [score for pair, score in pairs if false_pos == pair][0]
-    assert 1.1 < false_pos_score < 1.2, false_pos_score
-    assert bmw_score > false_pos_score, (bmw_score, false_pos_score)
-
-    assert len(pairs) >= 428, len(pairs)
-
-
-def test_match_score(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
-    """Match an entity that isn't itself in the index"""
-    dx = Dataset.make({"name": "test", "title": "Test"})
-    entity = CompositeEntity.from_data(dx, VERBAND_BADEN_DATA)
-    duckdb_index.add_matching_subject(entity)
-    match_sets = list(duckdb_index.matches())
-    assert len(match_sets) == 1, match_sets
-    subject_id, matches = match_sets[0]
-    assert subject_id == Identifier("bla"), subject_id
-
-    # 9 entities in the index where some token in the query entity matches some
-    # token in the index.
-    assert len(matches) == 9, matches
-
-    top_result = matches[0]
-    assert top_result[0] == Identifier(VERBAND_BADEN_ID), top_result
-    assert 1.99 < top_result[1] < 2, top_result
-
-    next_result = matches[1]
-    assert next_result[0] == Identifier(VERBAND_ID), next_result
-    assert 1.66 < next_result[1] < 1.67, next_result
diff --git a/tests/index/test_index.py b/tests/index/test_index.py
index 2e590e21..0f218106 100644
--- a/tests/index/test_index.py
+++ b/tests/index/test_index.py
@@ -62,7 +62,7 @@ def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index):
     # A pair has non-zero score
     assert score > 0
 
-# pairs are in descending score order
+    # pairs are in descending score order
     last_score = pairs[0][1]
     for pair in pairs[1:]:
         assert pair[1] <= last_score
@@ -98,7 +98,7 @@
 
     assert len(pairs) == 428, len(pairs)
 
-
+
 def test_match_score(dstore: SimpleMemoryStore, dindex: Index):
     """Match an entity that isn't itself in the index"""
     dx = Dataset.make({"name": "test", "title": "Test"})
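
Usage note: with `get_index` and the `index_path` dotted-path plumbing removed, callers now select a blocking index by passing a class to `xref()` via the new `index_type` parameter, and `xref()` builds it. A minimal sketch of the updated call site, assuming a `resolver` and `store` that already exist (their construction is not shown here and is not part of this patch):

    from pathlib import Path

    from nomenklatura.index import Index
    from nomenklatura.xref import xref

    # `resolver` and `store` are assumed to be set up elsewhere, e.g. via
    # nomenklatura.resolver.Resolver and nomenklatura.store.load_entity_file_store,
    # as in cli.py above. index_type defaults to the in-memory Index; another
    # BaseIndex subclass such as TantivyIndex could be passed instead.
    xref(resolver, store, Path("index-dir"), index_type=Index, limit=5000)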