Skip to content

Commit

Permalink
Move even more into the db
Browse files Browse the repository at this point in the history
  • Loading branch information
jbothma committed Nov 6, 2024
1 parent 88e4398 commit 684c5c0
Showing 1 changed file with 28 additions and 61 deletions.
89 changes: 28 additions & 61 deletions nomenklatura/index/duckdb_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from itertools import combinations
from tempfile import mkdtemp
from typing import Any, Dict, Generator, List, Set, Tuple
from typing import Any, Dict, Generator, Iterable, List, Set, Tuple
from followthemoney.types import registry
import duckdb

Expand All @@ -19,6 +19,8 @@

log = logging.getLogger(__name__)

BATCH_SIZE = 1000


class DuckDBIndex(BaseIndex[DS, CE]):
"""
Expand Down Expand Up @@ -57,6 +59,7 @@ def __init__(self, view: View[DS, CE], data_dir: Path):
self.con.execute("SET max_memory = '2GB';")
self.con.execute("SET threads = 1;")
self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)")
self.con.execute("CREATE TABLE boosts (field TEXT, boost FLOAT)")

def dump(self, writer, entity: CE) -> None:

Expand All @@ -77,52 +80,17 @@ def build(self) -> None:
self.dump(writer, entity)
if idx % 10000 == 0:
log.info("Dumped %s entities" % idx)
for field, boost in self.BOOSTS.items():
self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost])

log.info("Loading data...")
self.con.execute(f"COPY entries from '{csv_path}'")

log.info("Index built.")

def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float]]:
    """Return the highest-scoring candidate pairs, best first.

    Every co-occurrence yielded by ``cooccurring_tokens()`` is spooled to
    ``cooccurrences.csv`` under ``self.path``, bulk-loaded into the DuckDB
    connection, and summed per (left, right) pair there.

    Args:
        max_pairs: Upper bound on the number of pairs returned.

    Returns:
        ``((left, right), score)`` tuples ordered by descending total score.
    """
    csv_path = self.path / "cooccurrences.csv"
    # The csv module emits its own line terminators and asks for files to be
    # opened with newline=""; the encoding is pinned so the dump does not
    # depend on the platform's locale default.
    with open(csv_path, "w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(["left", "right", "score"])
        writer.writerows(
            (pair[0], pair[1], score) for pair, score in self.cooccurring_tokens()
        )
    log.info("Loading co-occurrences...")
    # OR REPLACE: a plain CREATE TABLE raises once the table exists, which
    # made a second call to pairs() on the same index fail.
    self.con.execute(
        'CREATE OR REPLACE TABLE cooccurrences ("left" TEXT, "right" TEXT, score FLOAT)'
    )
    # The header row is always written above, so declare it instead of
    # relying on the CSV sniffer to detect it (e.g. when there are no rows).
    self.con.execute(f"COPY cooccurrences from '{csv_path}' (FORMAT CSV, HEADER)")
    pairs_query = """
    SELECT "left", "right", sum(score) as score
    FROM cooccurrences
    GROUP BY "left", "right"
    ORDER BY score DESC
    LIMIT ?
    """
    pairs_rel = self.con.execute(pairs_query, [max_pairs])
    return [
        ((Identifier.get(left), Identifier.get(right)), score)
        for left, right, score in pairs_rel.fetchall()
    ]

def cooccurring_tokens(self):
    """Yield a boosted score for every pair of entities sharing a token.

    Walks the ``(field, token, entities)`` groups produced by
    ``self.frequencies()``. For each two entities in a group whose weights
    are both non-zero, yields ``((larger_id, smaller_id), score)``, where
    ``score`` is the sum of the two weights multiplied by the field's entry
    in ``self.BOOSTS`` (1.0 when the field has none). A pair is yielded once
    per shared token; nothing is aggregated here.
    """
    groups_seen = defaultdict(int)  # per-field group counter, for progress logs
    for field, _token, members in self.frequencies():
        groups_seen[field] += 1
        if groups_seen[field] % 10000 == 0:
            log.info("Pairwise xref [%s]: %d" % (field, groups_seen[field]))
        multiplier = self.BOOSTS.get(field, 1.0)
        for first, second in combinations(members, 2):
            first_id, first_weight = first
            second_id, second_weight = second
            # A zero weight on either side means the pair contributes nothing.
            if not (first_weight != 0.0 and second_weight != 0.0):
                continue
            # Larger identifier first, so each pair has one canonical order.
            ordered = (max(first_id, second_id), min(first_id, second_id))
            yield ordered, (first_weight + second_weight) * multiplier

def frequencies(
self,
) -> Generator[Tuple[str, str, List[Tuple[Identifier, float]]], None, None]:
def pairs(
self, max_pairs: int = BaseIndex.MAX_PAIRS
) -> Iterable[Tuple[Pair, float]]:
field_len_query = """
SELECT field, id, count(*) as field_len from entries
GROUP BY field, id
Expand All @@ -140,33 +108,32 @@ def frequencies(
GROUP BY field, token
"""
token_freq = self.con.sql(token_freq_query)
query = """
SELECT mentions.field, mentions.token, mentions.id, mentions/field_len
term_frequencies_query = """
SELECT mentions.field, mentions.token, mentions.id, mentions/field_len as tf
FROM field_len
JOIN mentions
ON field_len.field = mentions.field AND field_len.id = mentions.id
JOIN token_freq
ON token_freq.field = mentions.field AND token_freq.token = mentions.token
where token_freq < 100
ORDER BY mentions.field, mentions.token
"""
rel = self.con.sql(query)
row = rel.fetchone()
entities = [] # the entities in this field, token group
field_name = None
token = None
while row is not None:
field_name, token, id, freq = row
entities.append((Identifier.get(id), freq))

row = rel.fetchone()
if row is None:
yield field_name, token, entities
break
new_field_name, new_token, _, _ = row
if new_field_name != field_name or new_token != token:
yield field_name, token, entities
entities = []
term_frequencies = self.con.sql(term_frequencies_query)
pairs_query = """
SELECT "left".id, "right".id, sum(("left".tf + "right".tf) * boost) as score
FROM term_frequencies as "left"
JOIN term_frequencies as "right"
ON "left".field = "right".field AND "left".token = "right".token
JOIN boosts
ON "left".field = boosts.field
WHERE "left".id > "right".id
GROUP BY "left".id, "right".id
ORDER BY score DESC
LIMIT ?
"""
results = self.con.execute(pairs_query, [max_pairs])
while batch := results.fetchmany(BATCH_SIZE):
for left, right, score in batch:
yield (Identifier.get(left), Identifier.get(right)), score

def __repr__(self) -> str:
return "<DuckDBIndex(%r, %r)>" % (
Expand Down

0 comments on commit 684c5c0

Please sign in to comment.