Unit test subqueries, typecheck
jbothma committed Nov 7, 2024
1 parent 684c5c0 commit 84d47e3
Showing 4 changed files with 137 additions and 79 deletions.
nomenklatura/index/common.py (8 changes: 3 additions & 5 deletions)
@@ -1,6 +1,6 @@
 from pathlib import Path
-from typing import Generic, List, Tuple
-from nomenklatura.resolver import Identifier
+from typing import Generic, Iterable, List, Tuple
+from nomenklatura.resolver import Pair, Identifier
 from nomenklatura.dataset import DS
 from nomenklatura.entity import CE
 from nomenklatura.store import View
@@ -16,9 +16,7 @@ def __init__(self, view: View[DS, CE], data_dir: Path) -> None:
     def build(self) -> None:
         raise NotImplementedError

-    def pairs(
-        self, max_pairs: int = MAX_PAIRS
-    ) -> List[Tuple[Tuple[Identifier, Identifier], float]]:
+    def pairs(self, max_pairs: int = MAX_PAIRS) -> Iterable[Tuple[Pair, float]]:
         raise NotImplementedError

     def match(self, entity: CE) -> List[Tuple[Identifier, float]]:
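Note (reviewer's gloss, not part of the commit): the tightened signature also changes the contract, since implementations may now return a lazy iterable instead of a materialized list. A minimal sketch of a conforming implementation, assuming Pair behaves like a tuple of two Identifier objects as the replaced annotation suggests:

    from typing import Iterable, List, Tuple

    from nomenklatura.resolver import Identifier, Pair


    class TinyIndex:
        """Hypothetical index holding two hard-coded scored pairs."""

        SCORES: List[Tuple[Pair, float]] = [
            ((Identifier.get("a"), Identifier.get("b")), 2.0),
            ((Identifier.get("b"), Identifier.get("c")), 1.0),
        ]

        def pairs(self, max_pairs: int = 100) -> Iterable[Tuple[Pair, float]]:
            # A generator satisfies Iterable and keeps memory flat, which
            # matters once pairs are streamed out of a DuckDB result.
            for pair, score in self.SCORES[:max_pairs]:
                yield pair, score

Callers that still need len() or indexing have to materialize the iterable themselves, which is what the updated tests below do for DuckDBIndex.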
nomenklatura/index/duckdb_index.py (81 changes: 49 additions & 32 deletions)
@@ -1,21 +1,18 @@
 from collections import defaultdict
-import csv
+from duckdb import DuckDBPyRelation
+from followthemoney.types import registry
 from pathlib import Path
-import logging
-from itertools import combinations
 from tempfile import mkdtemp
-from typing import Any, Dict, Generator, Iterable, List, Set, Tuple
-from followthemoney.types import registry
+from typing import Iterable, Tuple
+import csv
 import duckdb
+import logging

-from nomenklatura.util import PathLike
-from nomenklatura.resolver import Pair, Identifier
 from nomenklatura.dataset import DS
 from nomenklatura.entity import CE
-from nomenklatura.store import View
-from nomenklatura.index.entry import Field
-from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer
 from nomenklatura.index.common import BaseIndex
+from nomenklatura.index.tokenizer import NAME_PART_FIELD, WORD_FIELD, Tokenizer
+from nomenklatura.resolver import Pair, Identifier
+from nomenklatura.store import View

 log = logging.getLogger(__name__)

Expand Down Expand Up @@ -55,59 +52,74 @@ def __init__(self, view: View[DS, CE], data_dir: Path):
self.tokenizer = Tokenizer[DS, CE]()
self.path = Path(mkdtemp())
self.con = duckdb.connect((self.path / "duckdb_index.db").as_posix())

# https://duckdb.org/docs/guides/performance/environment
# > For ideal performance, aggregation-heavy workloads require approx.
# > 5 GB memory per thread and join-heavy workloads require approximately
# > 10 GB memory per thread.
# > Aim for 5-10 GB memory per thread.
self.con.execute("SET memory_limit = '2GB';")
self.con.execute("SET max_memory = '2GB';")
# > If you have a limited amount of memory, try to limit the number of threads
self.con.execute("SET threads = 1;")
self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)")
self.con.execute("CREATE TABLE boosts (field TEXT, boost FLOAT)")

def dump(self, writer, entity: CE) -> None:

if not entity.schema.matchable or entity.id is None:
return

for field, token in self.tokenizer.entity(entity):
writer.writerow([entity.id, field, token])

def build(self) -> None:
"""Index all entities in the dataset."""
log.info("Building index from: %r...", self.view)
self.con.execute("CREATE TABLE boosts (field TEXT, boost FLOAT)")
for field, boost in self.BOOSTS.items():
self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost])

self.con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)")
csv_path = self.path / "mentions.csv"
log.info("Dumping entity tokens to CSV for bulk load into the database...")
with open(csv_path, "w") as fh:
writer = csv.writer(fh)

# csv.writer type gymnastics
def dump_entity(entity: CE) -> None:
if not entity.schema.matchable or entity.id is None:
return
for field, token in self.tokenizer.entity(entity):
writer.writerow([entity.id, field, token])

writer.writerow(["id", "field", "token"])
for idx, entity in enumerate(self.view.entities()):
self.dump(writer, entity)
dump_entity(entity)
if idx % 10000 == 0:
log.info("Dumped %s entities" % idx)
for field, boost in self.BOOSTS.items():
self.con.execute("INSERT INTO boosts VALUES (?, ?)", [field, boost])

log.info("Loading data...")
self.con.execute(f"COPY entries from '{csv_path}'")

log.info("Index built.")

def pairs(
self, max_pairs: int = BaseIndex.MAX_PAIRS
) -> Iterable[Tuple[Pair, float]]:
def field_len_rel(self) -> DuckDBPyRelation:
field_len_query = """
SELECT field, id, count(*) as field_len from entries
GROUP BY field, id
"""
field_len = self.con.sql(field_len_query)
return self.con.sql(field_len_query)

def mentions_rel(self) -> DuckDBPyRelation:
mentions_query = """
SELECT field, id, token, count(*) as mentions
FROM entries
GROUP BY field, id, token
"""
mentions = self.con.sql(mentions_query)
return self.con.sql(mentions_query)

def token_freq_rel(self) -> DuckDBPyRelation:
token_freq_query = """
SELECT field, token, count(*) as token_freq
FROM entries
GROUP BY field, token
"""
token_freq = self.con.sql(token_freq_query)
return self.con.sql(token_freq_query)

def frequencies_rel(self) -> DuckDBPyRelation:
field_len = self.field_len_rel() # noqa
mentions = self.mentions_rel() # noqa
token_freq = self.token_freq_rel() # noqa
term_frequencies_query = """
SELECT mentions.field, mentions.token, mentions.id, mentions/field_len as tf
FROM field_len
@@ -117,7 +129,12 @@ def pairs(
             ON token_freq.field = mentions.field AND token_freq.token = mentions.token
             where token_freq < 100
         """
-        term_frequencies = self.con.sql(term_frequencies_query)
+        return self.con.sql(term_frequencies_query)
+
+    def pairs(
+        self, max_pairs: int = BaseIndex.MAX_PAIRS
+    ) -> Iterable[Tuple[Pair, float]]:
+        term_frequencies = self.frequencies_rel()  # noqa
         pairs_query = """
             SELECT "left".id, "right".id, sum(("left".tf + "right".tf) * boost) as score
             FROM term_frequencies as "left"
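Reviewer's note on the # noqa locals in frequencies_rel() and pairs(): they look unused, but DuckDB's Python API resolves unknown table names in a query against variables in the calling scope (replacement scans), so field_len, mentions, token_freq and term_frequencies are pulled in as subquery inputs by name. A self-contained sketch of that technique on a toy table, with invented values (not from the commit):

    import duckdb

    con = duckdb.connect()
    con.execute("CREATE TABLE entries (id TEXT, field TEXT, token TEXT)")
    con.execute(
        "INSERT INTO entries VALUES "
        "('e1', 'name', 'acme'), ('e1', 'name', 'inc'), ('e2', 'name', 'acme')"
    )

    # Each con.sql() call returns a lazy DuckDBPyRelation; nothing runs yet.
    field_len = con.sql(
        "SELECT field, id, count(*) AS field_len FROM entries GROUP BY field, id"
    )  # noqa
    mentions = con.sql(
        "SELECT field, id, token, count(*) AS mentions "
        "FROM entries GROUP BY field, id, token"
    )  # noqa

    # 'field_len' and 'mentions' below resolve to the Python variables above.
    tf = con.sql(
        """
        SELECT mentions.field, mentions.token, mentions.id,
               mentions / field_len AS tf
        FROM field_len
        JOIN mentions
          ON field_len.field = mentions.field AND field_len.id = mentions.id
        ORDER BY mentions.id, mentions.token
        """
    )
    print(tf.fetchall())
    # [('name', 'acme', 'e1', 0.5), ('name', 'inc', 'e1', 0.5), ('name', 'acme', 'e2', 1.0)]

Because the relations stay lazy until results are fetched, the real index can compose four of them into one query plan without writing intermediate tables.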
tests/index/test_duckdb_index.py (71 changes: 44 additions & 27 deletions)
@@ -1,10 +1,5 @@
 from collections import defaultdict
-from pathlib import Path
-from tempfile import NamedTemporaryFile, TemporaryDirectory

-from nomenklatura.dataset import Dataset
-from nomenklatura.entity import CompositeEntity
-from nomenklatura.index import Index
 from nomenklatura.index.duckdb_index import DuckDBIndex
 from nomenklatura.resolver.identifier import Identifier
 from nomenklatura.store import SimpleMemoryStore
@@ -24,7 +19,7 @@
 def test_field_lengths(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
     field_names = set()
     ids = set()
-    for field_name, id, field_len in duckdb_index.field_lengths():
+    for field_name, id, field_len in duckdb_index.field_len_rel().fetchall():
         field_names.add(field_name)
         ids.add(id)
@@ -51,7 +46,7 @@ def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
     ids = set()
     field_tokens = defaultdict(set)

-    for field_name, id, token, count in duckdb_index.mentions():
+    for field_name, id, token, count in duckdb_index.mentions_rel().fetchall():
         ids.add(id)
         field_tokens[field_name].add(token)
@@ -62,25 +57,14 @@ def test_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
     assert "dortmund" in field_tokens["word"], field_tokens["word"]


-def test_id_grouped_mentions(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
-    ids = set()
-    field_tokens = defaultdict(set)
-    for field_name, id, field_len, mentions in duckdb_index.id_grouped_mentions():
-        ids.add(id)
-        for token, count in mentions:
-            field_tokens[field_name].add(token)
-
-    assert len(ids) == 184, len(ids)
-    assert "verband" in field_tokens["namepart"], field_tokens["namepart"]
-    assert "gb" in field_tokens["country"], field_tokens["country"]
-    assert "adolf wurth gmbh" in field_tokens["name"], field_tokens["name"]
-    assert "dortmund" in field_tokens["word"], field_tokens["word"]
-
-
 def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
     view = dstore.default_view()
-    pairs = duckdb_index.pairs()
-    assert len(pairs) > 0, pairs
+    pairs = list(duckdb_index.pairs())
+
+    # At least one pair is found
+    assert len(pairs) > 0, len(pairs)
+
+    # A pair has tokens which overlap
     tokenizer = duckdb_index.tokenizer
     pair, score = pairs[0]
     entity0 = view.get_entity(str(pair[0]))
@@ -89,7 +73,40 @@ def test_index_pairs(dstore: SimpleMemoryStore, duckdb_index: DuckDBIndex):
     tokens1 = set(tokenizer.entity(entity1))
     overlap = tokens0.intersection(tokens1)
     assert len(overlap) > 0, overlap
-    # assert "Schnabel" in (overlap, tokens0, tokens1)
-    # assert "Schnabel" in (entity0.caption, entity1.caption)
+
+    # A pair has non-zero score
     assert score > 0
-    # assert False
+
+    # pairs are in descending score order
+    last_score = pairs[0][1]
+    for pair in pairs[1:]:
+        assert pair[1] <= last_score
+        last_score = pair[1]
+
+    # Johanna Quandt <> Frau Johanna Quandt
+    jq = (
+        Identifier.get("9add84cbb7bb48c7552f8ec7ae54de54eed1e361"),
+        Identifier.get("2d3e50433e36ebe16f3d906b684c9d5124c46d76"),
+    )
+    jq_score = [score for pair, score in pairs if jq == pair][0]
+
+    # Bayerische Motorenwerke AG <> Bayerische Motorenwerke (BMW) AG
+    bmw = (
+        Identifier.get("21cc81bf3b960d2847b66c6c862e7aa9b5e4f487"),
+        Identifier.get("12570ee94b8dc23bcc080e887539d3742b2a5237"),
+    )
+    bmw_score = [score for pair, score in pairs if bmw == pair][0]
+
+    # More tokens in BMW means lower TF, reducing the score
+    assert jq_score > bmw_score, (jq_score, bmw_score)
+    assert jq_score == 19.0, jq_score
+    assert 3.3 < bmw_score < 3.4, bmw_score
+
+    # FERRING Arzneimittel GmbH <> Clou Container Leasing GmbH
+    false_pos = (
+        Identifier.get("f8867c433ba247cfab74096c73f6ff5e36db3ffe"),
+        Identifier.get("a061e760dfcf0d5c774fc37c74937193704807b5"),
+    )
+    false_pos_score = [score for pair, score in pairs if false_pos == pair][0]
+    assert 1.1 < false_pos_score < 1.2, false_pos_score
+    assert bmw_score > false_pos_score, (bmw_score, false_pos_score)
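Note that pairs = list(duckdb_index.pairs()) materializes the new lazy Iterable once before len(), indexing and the score lookups; iterating a generator a second time would yield nothing. A tiny stand-in illustration (hypothetical, not from the commit):

    from typing import Iterable, Tuple

    def fake_pairs() -> Iterable[Tuple[Tuple[str, str], float]]:
        # Stand-in for DuckDBIndex.pairs(): yields ((left_id, right_id), score)
        # in descending score order.
        yield ("a", "b"), 19.0
        yield ("c", "d"), 3.35

    pairs = list(fake_pairs())  # materialize once, then reuse freely
    assert len(pairs) > 0
    assert pairs[0][1] >= pairs[1][1]  # descending, as the test asserts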
tests/index/test_index.py (56 changes: 41 additions & 15 deletions)
@@ -27,17 +27,6 @@ def test_index_build(index_path: Path, dstore: SimpleMemoryStore):
     assert len(index) == 184, len(index)


-def test_frequencies(dstore: SimpleMemoryStore, dindex: Index):
-    view = dstore.default_view()
-
-    for field_name, field in dindex.fields.items():
-        for token, entry in field.tokens.items():
-            print(field_name, token)
-            for ident, tf in entry.frequencies(field):
-                print(" ", ident.id, tf)
-    assert False
-
-
 def test_index_persist(dstore: SimpleMemoryStore, dindex):
     view = dstore.default_view()
     with TemporaryDirectory() as tmpdir:
@@ -57,7 +46,11 @@ def test_index_persist(dstore: SimpleMemoryStore, dindex):
 def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index):
     view = dstore.default_view()
     pairs = dindex.pairs()
-    assert len(pairs) > 0, pairs
+
+    # At least one pair is found
+    assert len(pairs) > 0, len(pairs)
+
+    # A pair has tokens which overlap
     tokenizer = dindex.tokenizer
     pair, score = pairs[0]
     entity0 = view.get_entity(str(pair[0]))
@@ -66,10 +59,43 @@ def test_index_pairs(dstore: SimpleMemoryStore, dindex: Index):
     tokens1 = set(tokenizer.entity(entity1))
     overlap = tokens0.intersection(tokens1)
     assert len(overlap) > 0, overlap
-    # assert "Schnabel" in (overlap, tokens0, tokens1)
-    # assert "Schnabel" in (entity0.caption, entity1.caption)
+
+    # A pair has non-zero score
     assert score > 0
-    # assert False
+
+    # pairs are in descending score order
+    last_score = pairs[0][1]
+    for pair in pairs[1:]:
+        assert pair[1] <= last_score
+        last_score = pair[1]
+
+    # Johanna Quandt <> Frau Johanna Quandt
+    jq = (
+        Identifier.get("9add84cbb7bb48c7552f8ec7ae54de54eed1e361"),
+        Identifier.get("2d3e50433e36ebe16f3d906b684c9d5124c46d76"),
+    )
+    jq_score = [score for pair, score in pairs if jq == pair][0]
+
+    # Bayerische Motorenwerke AG <> Bayerische Motorenwerke (BMW) AG
+    bmw = (
+        Identifier.get("21cc81bf3b960d2847b66c6c862e7aa9b5e4f487"),
+        Identifier.get("12570ee94b8dc23bcc080e887539d3742b2a5237"),
+    )
+    bmw_score = [score for pair, score in pairs if bmw == pair][0]
+
+    # More tokens in BMW means lower TF, reducing the score
+    assert jq_score > bmw_score, (jq_score, bmw_score)
+    assert jq_score == 19.0, jq_score
+    assert 3.3 < bmw_score < 3.4, bmw_score
+
+    # FERRING Arzneimittel GmbH <> Clou Container Leasing GmbH
+    false_pos = (
+        Identifier.get("f8867c433ba247cfab74096c73f6ff5e36db3ffe"),
+        Identifier.get("a061e760dfcf0d5c774fc37c74937193704807b5"),
+    )
+    false_pos_score = [score for pair, score in pairs if false_pos == pair][0]
+    assert 1.1 < false_pos_score < 1.2, false_pos_score
+    assert bmw_score > false_pos_score, (bmw_score, false_pos_score)


 def test_match_score(dstore: SimpleMemoryStore, dindex: Index):
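The score expectations in both test modules follow from the SQL above: each token shared by a candidate pair contributes (left.tf + right.tf) * boost, tokens occurring 100 or more times in a field are dropped by the token_freq < 100 filter, and tf is mentions / field_len, a token's share of its field. A back-of-envelope sketch with invented tf and boost values (the real BOOSTS mapping is not shown in this diff):

    # Worked example of the pair scoring rule, all numbers invented.
    def pair_score(shared_tokens, boosts):
        # shared_tokens: one (field, tf_left, tf_right) triple per shared token;
        # boost defaults to 1.0 here for unboosted fields (an assumption).
        return sum(
            (tf_l + tf_r) * boosts.get(field, 1.0)
            for field, tf_l, tf_r in shared_tokens
        )

    # Short names that overlap completely (each token is a large share of its
    # field) outscore longer names sharing the same tokens, which is why the
    # Quandt pair beats the BMW pair in the assertions above.
    short_names = [("namepart", 0.5, 0.5), ("namepart", 0.5, 0.5)]
    long_names = [("namepart", 0.2, 0.25), ("namepart", 0.2, 0.25)]
    print(pair_score(short_names, {"namepart": 2.0}))  # 4.0
    print(pair_score(long_names, {"namepart": 2.0}))  # 1.8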
