Skip to content

Commit

Permalink
Reduce memory consumption by
Browse files Browse the repository at this point in the history
By letting it materialise intermediate results more explicitly
instead of doing multiple joins concurrently
  • Loading branch information
jbothma committed Nov 15, 2024
1 parent 287d70a commit 178a941
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 50 deletions.
95 changes: 52 additions & 43 deletions nomenklatura/index/duckdb_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from followthemoney.types import registry
from pathlib import Path
from shutil import rmtree
from typing import Any, Dict, Generator, Iterable, List, Tuple
from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple
import csv
import duckdb
import logging
Expand Down Expand Up @@ -53,26 +53,31 @@ def __init__(
self, view: View[DS, CE], data_dir: Path, options: Dict[str, Any] = {}
):
self.view = view
# self.memory_budget = int(options.get("memory_budget", 500) * 1024 * 1024)
memory_budget = options.get("memory_budget", None)
self.memory_budget: Optional[int] = (
(int(memory_budget) * 1024) if memory_budget else None
)
"""Memory budget in megabytes"""
self.max_candidates = int(options.get("max_candidates", 50))
self.tokenizer = Tokenizer[DS, CE]()
self.data_dir = data_dir
if self.data_dir.exists():
rmtree(self.data_dir.as_posix())
rmtree(self.data_dir)
self.data_dir.mkdir(parents=True)
self.con = duckdb.connect((self.data_dir / "duckdb_index.db").as_posix())
self.matching_path = self.data_dir / "matching.csv"
self.matching_path.unlink(missing_ok=True)
self.matching_dump: TextIOWrapper | None = open(self.matching_path, "w")
writer = csv.writer(self.matching_dump)
writer.writerow(["id", "field", "token"])

# https://duckdb.org/docs/guides/performance/environment
# > For ideal performance, aggregation-heavy workloads require approx.
# > 5 GB memory per thread and join-heavy workloads require approximately
# > 10 GB memory per thread.
# > For ideal performance,
# > aggregation-heavy workloads require approx. 5 GB memory per thread and
# > join-heavy workloads require approximately 10 GB memory per thread.
# > Aim for 5-10 GB memory per thread.
self.con.execute("SET memory_limit = '3GB';")
self.con.execute("SET max_memory = '3GB';")
if self.memory_budget is not None:
self.con.execute("SET memory_limit = ?;", [f"{self.memory_budget}MB"])
# > If you have a limited amount of memory, try to limit the number of threads
self.con.execute("SET threads = 1;")

Expand All @@ -96,62 +101,72 @@ def dump_entity(entity: CE) -> None:
return
for field, token in self.tokenizer.entity(entity):
writer.writerow([entity.id, field, token])
writer.writerow(["id", "field", "token"])

writer.writerow(["id", "field", "token"])
for idx, entity in enumerate(self.view.entities()):
dump_entity(entity)
if idx % 50000 == 0:
log.info("Dumped %s entities" % idx)

log.info("Loading data...")
result = self.con.execute(f"COPY entries from '{csv_path}'").fetchall()
log.info("Loaded %r rows", len(result))

log.info("Calculating term frequencies...")
frequencies = self.frequencies_rel() # noqa
self.con.execute("CREATE TABLE term_frequencies as SELECT * FROM frequencies")

log.info("Calculating stopwords...")
token_freq = self.token_freq_rel() # noqa
self.con.execute(
"CREATE TABLE stopwords as SELECT * FROM token_freq where token_freq > 100"
)
self.con.execute(f"COPY entries from '{csv_path}'")
log.info("Done.")

self._build_frequencies()
log.info("Index built.")

def field_len_rel(self) -> DuckDBPyRelation:
def _build_field_len(self) -> None:
self._build_stopwords()
log.info("Calculating field lengths...")
field_len_query = """
SELECT field, id, count(*) as field_len from entries
GROUP BY field, id
CREATE TABLE IF NOT EXISTS field_len as
SELECT entries.field, entries.id, count(*) as field_len from entries
LEFT OUTER JOIN stopwords
ON stopwords.field = entries.field AND stopwords.token = entries.token
WHERE token_freq is NULL
GROUP BY entries.field, entries.id
"""
return self.con.sql(field_len_query)
self.con.execute(field_len_query)

def mentions_rel(self) -> DuckDBPyRelation:
def _build_mentions(self) -> None:
self._build_stopwords()
log.info("Calculating mention counts...")
mentions_query = """
SELECT field, id, token, count(*) as mentions
CREATE TABLE IF NOT EXISTS mentions as
SELECT entries.field, entries.id, entries.token, count(*) as mentions
FROM entries
GROUP BY field, id, token
LEFT OUTER JOIN stopwords
ON stopwords.field = entries.field AND stopwords.token = entries.token
WHERE token_freq is NULL
GROUP BY entries.field, entries.id, entries.token
"""
return self.con.sql(mentions_query)
self.con.execute(mentions_query)

def token_freq_rel(self) -> DuckDBPyRelation:
def _build_stopwords(self) -> None:
token_freq_query = """
SELECT field, token, count(*) as token_freq
FROM entries
GROUP BY field, token
"""
return self.con.sql(token_freq_query)
token_freq = self.con.sql(token_freq_query) # noqa
self.con.execute(
"""
CREATE TABLE IF NOT EXISTS stopwords as
SELECT * FROM token_freq where token_freq > 100
"""
)

def frequencies_rel(self) -> DuckDBPyRelation:
field_len = self.field_len_rel() # noqa
mentions = self.mentions_rel() # noqa
def _build_frequencies(self) -> None:
self._build_field_len()
self._build_mentions()
log.info("Calculating term frequencies...")
term_frequencies_query = """
CREATE TABLE IF NOT EXISTS term_frequencies as
SELECT mentions.field, mentions.token, mentions.id, mentions/field_len as tf
FROM field_len
JOIN mentions
ON field_len.field = mentions.field AND field_len.id = mentions.id
"""
return self.con.sql(term_frequencies_query)
self.con.execute(term_frequencies_query)

def pairs(
self, max_pairs: int = BaseIndex.MAX_PAIRS
Expand All @@ -163,9 +178,6 @@ def pairs(
ON "left".field = "right".field AND "left".token = "right".token
LEFT OUTER JOIN boosts
ON "left".field = boosts.field
LEFT OUTER JOIN stopwords
ON stopwords.field = "left".field AND stopwords.token = "left".token
WHERE token_freq is NULL
AND "left".id > "right".id
GROUP BY "left".id, "right".id
ORDER BY score DESC
Expand Down Expand Up @@ -194,13 +206,10 @@ def matches(
match_query = """
SELECT matching.id, matches.id, sum(matches.tf * ifnull(boost, 1)) as score
FROM term_frequencies as matches
LEFT OUTER JOIN stopwords
ON stopwords.field = matches.field AND stopwords.token = matches.token
JOIN matching
ON matches.field = matching.field AND matches.token = matching.token
LEFT OUTER JOIN boosts
ON matches.field = boosts.field
WHERE token_freq is NULL
GROUP BY matches.id, matching.id
ORDER BY matching.id, score DESC
"""
Expand All @@ -217,7 +226,7 @@ def matches(
matches = []
previous_id = matching_id
matches.append((Identifier.get(match_id), score))
yield Identifier.get(previous_id), matches
yield Identifier.get(previous_id), matches[: self.max_candidates]

def __repr__(self) -> str:
return "<DuckDBIndex(%r, %r)>" % (
Expand Down
16 changes: 9 additions & 7 deletions tests/index/test_duckdb_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@ def test_import(dstore: SimpleMemoryStore, index_path: Path):
def test_field_lengths(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
field_names = set()
ids = set()
for field_name, id, field_len in duckdb_index.field_len_rel().fetchall():

field_len_rel = duckdb_index.con.sql("SELECT * FROM field_len")
for field_name, id, field_len in field_len_rel.fetchall():
field_names.add(field_name)
ids.add(id)

Expand All @@ -56,7 +58,8 @@ def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
ids = set()
field_tokens = defaultdict(set)

for field_name, id, token, count in duckdb_index.mentions_rel().fetchall():
mentions_rel = duckdb_index.con.sql("SELECT * FROM mentions")
for field_name, id, token, count in mentions_rel.fetchall():
ids.add(id)
field_tokens[field_name].add(token)

Expand Down Expand Up @@ -121,11 +124,10 @@ def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
assert 1.1 < false_pos_score < 1.2, false_pos_score
assert bmw_score > false_pos_score, (bmw_score, false_pos_score)

assert len(pairs) == 428, len(pairs)
assert len(pairs) >= 428, len(pairs)


def test_match_score(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
print(duckdb_index.data_dir)
"""Match an entity that isn't itself in the index"""
dx = Dataset.make({"name": "test", "title": "Test"})
entity = CompositeEntity.from_data(dx, VERBAND_BADEN_DATA)
Expand All @@ -147,12 +149,12 @@ def test_match_score(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
assert next_result[0] == Identifier(VERBAND_ID), next_result
assert 1.66 < next_result[1] < 1.67, next_result

match_identifiers = set(str(m[0]) for m in matches)
#match_identifiers = set(str(m[0]) for m in matches)


#def test_top_match_matches_strong_pairs(
# def test_top_match_matches_strong_pairs(
# dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex
#):
# ):
# """Pairs with high scores are each others' top matches"""
#
# view = dstore.default_view()
Expand Down

0 comments on commit 178a941

Please sign in to comment.