Commit

Move more into the db
jbothma committed Nov 6, 2024
1 parent c2fbc29 commit 88e4398
Showing 1 changed file with 50 additions and 92 deletions.
142 changes: 50 additions & 92 deletions nomenklatura/index/duckdb_index.py
@@ -1,3 +1,4 @@
from collections import defaultdict
import csv
from pathlib import Path
import logging
@@ -52,6 +53,9 @@ def __init__(self, view: View[DS, CE], data_dir: Path):
self.tokenizer = Tokenizer[DS, CE]()
self.path = Path(mkdtemp())
self.con = duckdb.connect((self.path / "duckdb_index.db").as_posix())
self.con.execute("SET memory_limit = '2GB';")
self.con.execute("SET max_memory = '2GB';")
self.con.execute("SET threads = 1;")
self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)")

def dump(self, writer, entity: CE) -> None:
@@ -77,122 +81,76 @@ def build(self) -> None:
log.info("Loading data...")
self.con.execute(f"COPY entries from '{csv_path}'")

self.calculate_frequencies()
log.info("Index built.")

def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS):
pairs: Dict[Pair, float] = {}
def pairs(self, max_pairs: int = BaseIndex.MAX_PAIRS) -> List[Tuple[Pair, float]]:
csv_path = self.path / "cooccurrences.csv"
with open(csv_path, "w") as fh:
writer = csv.writer(fh)
writer.writerow(["left", "right", "score"])
for pair, score in self.cooccurring_tokens():
writer.writerow([pair[0], pair[1], score])
log.info("Loading co-occurrences...")
self.con.execute('CREATE TABLE cooccurrences ("left" TEXT, "right" TEXT, score FLOAT)')
self.con.execute(f"COPY cooccurrences from '{csv_path}'")
pairs_query = """
SELECT "left", "right", sum(score) as score
FROM cooccurrences
GROUP BY "left", "right"
ORDER BY score DESC
LIMIT ?
"""
pairs_rel = self.con.execute(pairs_query, [max_pairs])
pairs: List[Tuple[Pair, float]] = []
for left, right, score in pairs_rel.fetchall():
pairs.append(((Identifier.get(left), Identifier.get(right)), score))
return pairs
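
Assuming the class keeps the DuckDBIndex name suggested by the module, a hedged usage sketch; the view and data directory are stand-ins for the caller's own nomenklatura setup:

# hypothetical caller code
index = DuckDBIndex(view, Path("/tmp/index-data"))
index.build()  # tokenize entities and bulk-load them into DuckDB
for (left, right), score in index.pairs(max_pairs=10):
    print(left, right, score)  # Identifier, Identifier, float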

def cooccurring_tokens(self):
logged = defaultdict(int)
for field_name, token, entities in self.frequencies():
logged[field_name] += 1
if logged[field_name] % 10000 == 0:
log.info("Pairwise xref [%s]: %d" % (field_name, logged[field_name]))
boost = self.BOOSTS.get(field_name, 1.0)
for (left, lw), (right, rw) in combinations(entities, 2):
if lw == 0.0 or rw == 0.0:
continue
pair = (max(left, right), min(left, right))
if pair not in pairs:
pairs[pair] = 0
score = (lw + rw) * boost
pairs[pair] += score
return sorted(pairs.items(), key=lambda p: p[1], reverse=True)[:max_pairs]
yield pair, score
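
The (max(left, right), min(left, right)) construction canonicalises each pair so both orderings of the same two entities aggregate under one key. A toy illustration:

from itertools import combinations

entities = [("Q1", 0.5), ("Q2", 0.3), ("Q3", 0.0)]
for (left, lw), (right, rw) in combinations(entities, 2):
    if lw == 0.0 or rw == 0.0:
        continue  # zero-weight tokens contribute nothing
    pair = (max(left, right), min(left, right))
    print(pair, lw + rw)  # ('Q2', 'Q1') 0.8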

def field_lengths(self):
def frequencies(
self,
) -> Generator[Tuple[str, str, List[Tuple[Identifier, float]]], None, None]:
field_len_query = """
SELECT field, id, count(*) as field_len from entries
GROUP BY field, id
ORDER BY field, id
"""
field_len_rel = self.con.sql(field_len_query)
row = field_len_rel.fetchone()
while row is not None:
yield row
row = field_len_rel.fetchone()
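
The fetchone() loop streams results one row at a time instead of materialising everything with fetchall(); the same pattern recurs in the methods below. A self-contained sketch:

import duckdb

con = duckdb.connect()                   # in-memory database
rel = con.sql("SELECT * FROM range(5)")  # lazily evaluated relation
row = rel.fetchone()
while row is not None:
    print(row)                           # one tuple at a time, e.g. (0,)
    row = rel.fetchone()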

def mentions(self) -> Generator[Tuple[str, str, str, int], None, None]:
"""Yields tuples of (field_name, entity_id, token, mention_count)"""

field_len = self.con.sql(field_len_query)
mentions_query = """
SELECT field, id, token, count(*) as mentions
FROM entries
GROUP BY field, id, token
ORDER BY field, id, token
"""
mentions_rel = self.con.sql(mentions_query)
row = mentions_rel.fetchone()
while row is not None:
yield row
row = mentions_rel.fetchone()

def common_tokens(self) -> Set[Tuple[str, str]]:
"""Yields tuples of (field_name, token)"""
query = """
SELECT field, token, count(*) as frequency
mentions = self.con.sql(mentions_query)
token_freq_query = """
SELECT field, token, count(*) as token_freq
FROM entries
GROUP BY field, token
"""
token_counts_rel = self.con.sql(query)
filter_query = "SELECT * FROM token_counts_rel WHERE frequency > 100"
common_tokens_rel = self.con.sql(filter_query)
tokens: Set[Tuple[str, str]] = set()
for field_name, token, freq in common_tokens_rel.fetchall():
tokens.add((field_name, token))
return tokens

def id_grouped_mentions(
self,
) -> Generator[Tuple[str, str, int, List[Tuple[str, int]]], None, None]:
"""
Yields tuples of (field_name, entity_id, field_len, [(token, mention_count)])
"""
common_tokens = self.common_tokens()
mentions_gen = self.mentions()
mention_row = None
# Read all field lengths into memory because the concurrent iteration
# seems to be exiting the outer loop early and giving partial results.
for field_name, id, field_len in list(self.field_lengths()):
mentions = []
try:
if mention_row is None: # first iteration
mention_row = next(mentions_gen)
mention_field_name, mention_id, token, mention_count = mention_row

while mention_field_name == field_name and mention_id == id:
if (mention_field_name, token) not in common_tokens:
mentions.append((token, mention_count))

mention_row = next(mentions_gen)
mention_field_name, mention_id, token, mention_count = mention_row

yield field_name, id, field_len, mentions

except StopIteration:
yield field_name, id, field_len, mentions
break

def calculate_frequencies(self) -> None:
csv_path = self.path / "frequencies.csv"
with open(csv_path, "w") as fh:
writer = csv.writer(fh)
writer.writerow(["field", "id", "token", "frequency"])

for field_name, id, field_len, mentions in self.id_grouped_mentions():
for token, freq in mentions:
writer.writerow([field_name, id, token, freq / field_len])

log.info(f"Loading frequencies data... ({csv_path})")
self.con.execute(
"CREATE TABLE frequencies (field TEXT, id TEXT, token TEXT, frequency FLOAT)"
)
self.con.execute(f"COPY frequencies from '{csv_path}'")
log.info("Frequencies are loaded")

def frequencies(
self,
) -> Generator[Tuple[str, str, List[Tuple[Identifier, float]]], None, None]:
token_freq = self.con.sql(token_freq_query)
query = """
SELECT field, token, id, frequency
FROM frequencies
ORDER BY field, token
SELECT mentions.field, mentions.token, mentions.id, mentions.mentions / field_len.field_len AS frequency
FROM field_len
JOIN mentions
ON field_len.field = mentions.field AND field_len.id = mentions.id
JOIN token_freq
ON token_freq.field = mentions.field AND token_freq.token = mentions.token
WHERE token_freq.token_freq < 100
ORDER BY mentions.field, mentions.token
"""
rel = self.con.sql(query, alias="mentions")
rel = self.con.sql(query)
row = rel.fetchone()
entities = []  # the entities in the current (field, token) group
field_name = None
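
The final query works because DuckDB's Python API performs replacement scans: a relation bound to a local variable (field_len, mentions and token_freq above) can be referenced by that name directly from SQL. A minimal sketch of the mechanism:

import duckdb

con = duckdb.connect()
small = con.sql("SELECT range AS n FROM range(10)")    # relation bound to a Python variable
# the query below resolves `small` via DuckDB's replacement scan
print(con.sql("SELECT sum(n) FROM small").fetchone())  # (45,)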