diff --git a/nomenklatura/index/duckdb_index.py b/nomenklatura/index/duckdb_index.py index e8aa60c9..e4dfb9df 100644 --- a/nomenklatura/index/duckdb_index.py +++ b/nomenklatura/index/duckdb_index.py @@ -4,7 +4,7 @@ import logging from itertools import combinations from tempfile import mkdtemp -from typing import Any, Dict, Generator, List, Set, Tuple +from typing import Any, Dict, Generator, Iterable, List, Set, Tuple from followthemoney.types import registry import duckdb @@ -19,6 +19,8 @@ log = logging.getLogger(__name__) +BATCH_SIZE = 1000 + class DuckDBIndex(BaseIndex[DS, CE]): """ @@ -57,6 +59,7 @@ def __init__(self, view: View[DS, CE], data_dir: Path): self.con.execute("SET max_memory = '2GB';") self.con.execute("SET threads = 1;") self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)") + self.con.execute("CREATE TABLE boosts (field TEXT, boost FLOAT)") def dump(self, writer, entity: CE) -> None: @@ -77,52 +80,17 @@ def build(self) -> None: self.dump(writer, entity) if idx % 10000 == 0: log.info("Dumped %s entities" % idx) + for field, boost in self.BOOSTS.items(): + self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost]) log.info("Loading data...") self.con.execute(f"COPY entries from '{csv_path}'") log.info("Index built.") - def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float]]: - csv_path = self.path / "cooccurrences.csv" - with open(csv_path, "w") as fh: - writer = csv.writer(fh) - writer.writerow(["left", "right", "score"]) - for pair, score in self.cooccurring_tokens(): - writer.writerow([pair[0], pair[1], score]) - log.info("Loading co-occurrences...") - self.con.execute('CREATE TABLE cooccurrences ("left" TEXT, "right" TEXT, score FLOAT)') - self.con.execute(f"COPY cooccurrences from '{csv_path}'") - pairs_query = """ - SELECT "left", "right", sum(score) as score - FROM cooccurrences - GROUP BY "left", "right" - ORDER BY score DESC - LIMIT ? - """ - pairs_rel = self.con.execute(pairs_query, [max_pairs]) - pairs: List[Tuple[Pair, float]] = [] - for left, right, score in pairs_rel.fetchall(): - pairs.append(((Identifier.get(left), Identifier.get(right)), score)) - return pairs - - def cooccurring_tokens(self): - logged = defaultdict(int) - for field_name, token, entities in self.frequencies(): - logged[field_name] += 1 - if logged[field_name] % 10000 == 0: - log.info("Pairwise xref [%s]: %d" % (field_name, logged[field_name])) - boost = self.BOOSTS.get(field_name, 1.0) - for (left, lw), (right, rw) in combinations(entities, 2): - if lw == 0.0 or rw == 0.0: - continue - pair = (max(left, right), min(left, right)) - score = (lw + rw) * boost - yield pair, score - - def frequencies( - self, - ) -> Generator[Tuple[str, str, List[Tuple[Identifier, float]]], None, None]: + def pairs( + self, max_pairs: int = BaseIndex.MAX_PAIRS + ) -> Iterable[Tuple[Pair, float]]: field_len_query = """ SELECT field, id, count(*) as field_len from entries GROUP BY field, id @@ -140,33 +108,32 @@ def frequencies( GROUP BY field, token """ token_freq = self.con.sql(token_freq_query) - query = """ - SELECT mentions.field, mentions.token, mentions.id, mentions/field_len + term_frequencies_query = """ + SELECT mentions.field, mentions.token, mentions.id, mentions/field_len as tf FROM field_len JOIN mentions ON field_len.field = mentions.field AND field_len.id = mentions.id JOIN token_freq ON token_freq.field = mentions.field AND token_freq.token = mentions.token where token_freq < 100 - ORDER BY mentions.field, mentions.token """ - rel = self.con.sql(query) - row = rel.fetchone() - entities = [] # the entities in this field, token group - field_name = None - token = None - while row is not None: - field_name, token, id, freq = row - entities.append((Identifier.get(id), freq)) - - row = rel.fetchone() - if row is None: - yield field_name, token, entities - break - new_field_name, new_token, _, _ = row - if new_field_name != field_name or new_token != token: - yield field_name, token, entities - entities = [] + term_frequencies = self.con.sql(term_frequencies_query) + pairs_query = """ + SELECT "left".id, "right".id, sum(("left".tf + "right".tf) * boost) as score + FROM term_frequencies as "left" + JOIN term_frequencies as "right" + ON "left".field = "right".field AND "left".token = "right".token + JOIN boosts + ON "left".field = boosts.field + WHERE "left".id > "right".id + GROUP BY "left".id, "right".id + ORDER BY score DESC + LIMIT ? + """ + results = self.con.execute(pairs_query, [max_pairs]) + while batch := results.fetchmany(BATCH_SIZE): + for left, right, score in batch: + yield (Identifier.get(left), Identifier.get(right)), score def __repr__(self) -> str: return "" % (