Skip to content

Commit

Permalink
experiment with less complex queries against the tantivy index
Browse files Browse the repository at this point in the history
  • Loading branch information
pudo committed Aug 26, 2024
1 parent ff196d7 commit 4501f5e
Showing 1 changed file with 120 additions and 80 deletions.
200 changes: 120 additions & 80 deletions nomenklatura/index/tantivy_index.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import math
import logging
from normality import WS
from pathlib import Path
from rigour.ids import StrictFormat
from functools import lru_cache
from normality import WS, ascii_text
from followthemoney.types import registry
from typing import Any, Dict, List, Tuple, Generator, Set
from typing import Any, Dict, List, Tuple, Generator, Set, Optional
from tantivy import Query, Occur, Index, SchemaBuilder, Document
import math
from collections import defaultdict
from rigour.text import metaphone
from rigour.text.scripts import is_modern_alphabet
from fingerprints.cleanup import clean_name_light

from nomenklatura.dataset import DS
from nomenklatura.entity import CE
Expand All @@ -17,6 +20,10 @@

log = logging.getLogger(__name__)

FIELD_ID = "entity_id"
FIELD_SCHEMA = "schemata"
FIELD_PHONETIC = "phonetic_name"
FIELD_TEXT = registry.text.name
INDEX_IGNORE = (
registry.entity,
registry.url,
Expand All @@ -30,22 +37,45 @@
registry.topic,
)
FULL_TEXT = {
# registry.text,
registry.text,
registry.string,
registry.address,
registry.identifier,
registry.name,
# registry.email,
}
BOOST_NAME_PHRASE = 4.0
BOOSTS = {
registry.name.name: 4.0,
registry.phone.name: 3.0,
registry.email.name: 3.0,
registry.address.name: 3.0,
registry.identifier.name: 5.0,
registry.phone.name: 5.0,
registry.email.name: 5.0,
# registry.address.name: 3.0,
registry.identifier.name: 7.0,
}


@lru_cache(maxsize=10000)
def _ascii_word(word: str) -> Optional[str]:
    """Transliterate *word* to ASCII, memoised for repeated tokens.

    Returns None when the text cannot be transliterated.
    """
    result = ascii_text(word)
    return result


@lru_cache(maxsize=10000)
def _phonetic_word(word: str) -> Optional[str]:
    """Return the metaphone key for *word*, or None when not applicable.

    Words shorter than three characters or outside modern alphabets are
    skipped, as are words that produce no phonetic key.
    """
    if not is_modern_alphabet(word) or len(word) < 3:
        return None
    # _ascii_word returns Optional[str]; passing None to metaphone would raise.
    ascii_form = _ascii_word(word)
    if ascii_form is None:
        return None
    phon = metaphone(ascii_form)
    if not len(phon):
        return None
    return phon


def _identifier_clean(value: str) -> Optional[str]:
chars = [c for c in value if c.isalnum()]
if len(chars) < 4:
return None
return "".join(chars).upper()


class TantivyIndex(BaseIndex[DS, CE]):
name = "tantivy"

Expand All @@ -58,14 +88,14 @@ def __init__(
self.threshold = float(options.get("threshold", 1.0))

schema_builder = SchemaBuilder()
schema_builder.add_text_field("entity_id", tokenizer_name="raw", stored=True)
schema_builder.add_text_field("schemata", tokenizer_name="raw")
schema_builder.add_text_field(registry.name.name)
schema_builder.add_text_field(FIELD_ID, tokenizer_name="raw", stored=True)
schema_builder.add_text_field(FIELD_SCHEMA, tokenizer_name="raw")
schema_builder.add_text_field(FIELD_PHONETIC, tokenizer_name="raw")
schema_builder.add_text_field(FIELD_TEXT)
schema_builder.add_text_field(registry.name.name, tokenizer_name="raw")
schema_builder.add_text_field(registry.email.name)
schema_builder.add_text_field(registry.address.name)
schema_builder.add_text_field(registry.text.name)
# schema_builder.add_text_field(registry.address.name)
schema_builder.add_text_field(registry.identifier.name, tokenizer_name="raw")
schema_builder.add_text_field(registry.iban.name, tokenizer_name="raw")
schema_builder.add_text_field(registry.phone.name, tokenizer_name="raw")
schema_builder.add_text_field(registry.country.name, tokenizer_name="raw")
schema_builder.add_text_field(registry.date.name, tokenizer_name="raw")
Expand All @@ -81,15 +111,8 @@ def __init__(
self.build_index = True

@classmethod
def entity_fields(cls, entity: CE) -> Generator[Tuple[str, Set[str]], None, None]:
"""
A generator of each
- index field name and
- the set of normalised but not tokenised values for that field
for the given entity.
"""
def index_entity(cls, entity: CE) -> Dict[str, Set[str]]:
"""Convert the given entity's properties into fields in a format suitable for indexing."""
fields: Dict[str, Set[str]] = defaultdict(set)

for prop, value in entity.itervalues():
Expand All @@ -98,94 +121,111 @@ def entity_fields(cls, entity: CE) -> Generator[Tuple[str, Set[str]], None, None
continue

if type in FULL_TEXT:
fields[registry.text.name].add(value.lower())
fields[FIELD_TEXT].add(value)

if type == registry.name:
fields[type.name].add(value.lower())
norm = fingerprint_name(value)
if norm is not None:
fields[type.name].add(norm)
clean = clean_name_light(value)
if clean is None:
continue
for word in clean.split(WS):
fields[type.name].add(word)
ascii = _ascii_word(word)
if ascii is not None:
fields[type.name].add(ascii)
phonetic = _phonetic_word(word)
if phonetic is not None:
fields[FIELD_PHONETIC].add(phonetic)
continue

if type == registry.date and prop.matchable:
if len(value) > 4:
fields[type.name].add(value[:4])
fields[type.name].add(value[:10])
fields[type.name].add(value[:4])
if len(value) > 9:
fields[type.name].add(value[:10])
continue

if type == registry.identifier and prop.matchable:
clean_id = StrictFormat.normalize(value)
clean_id = _identifier_clean(value)
if clean_id is not None:
fields[type.name].add(clean_id)
continue

if type == registry.address and prop.matchable:
cleaned = clean_text_basic(value)
if cleaned is not None:
fields[type.name].add(cleaned)
continue

if prop.matchable and type in (
registry.phone,
registry.email,
registry.country,
# registry.address,
):
fields[type.name].add(value)
yield from fields.items()

def field_queries(
self, field: str, values: Set[str]
) -> Generator[Query, None, None]:
# print(dict(fields))
return fields

def field_queries(self, entity: CE) -> Generator[Query, None, None]:
"""
A generator of queries for the given index field and set of values.
"""
# Name phrase
if field == registry.name.name:
for value in values:
words = value.split(WS)
word_count = len(words)
if word_count > 1:
slop = math.ceil(2 * math.log(word_count))
names: Set[str] = set()
phonetics: Set[str] = set()

for prop, value in entity.itervalues():
type = prop.type
if type in INDEX_IGNORE:
continue

if prop.matchable and type in (
registry.phone,
registry.email,
registry.country,
):
yield Query.boost_query(
Query.term_query(self.schema, type.name, value),
BOOSTS.get(type.name, 1.0),
)

if type == registry.identifier and prop.matchable:
clean_id = _identifier_clean(value)
if clean_id is not None:
yield Query.boost_query(
Query.phrase_query(self.schema, field, words, slop), # type: ignore
BOOST_NAME_PHRASE,
Query.term_query(self.schema, type.name, clean_id),
BOOSTS.get(type.name, 1.0),
)
yield Query.term_query(self.schema, FIELD_TEXT, value)
continue

# Any of set of tokens in all values of the field
if field in {
registry.address.name,
registry.name.name,
registry.text.name,
registry.string.name,
}:
word_set: Set[str] = set()
for value in values:
word_set.update(value.split(WS))
term_queries: List[Query] = []
for word in word_set:
term_queries.append(Query.term_query(self.schema, field, word))
yield Query.boost_query(
Query.boolean_query([(Occur.Should, q) for q in term_queries]),
BOOSTS.get(field, 1.0),
)
return
if type == registry.name:
clean = clean_name_light(value)
if clean is None:
continue
for word in clean.split(WS):
names.add(word)
phonetic = _phonetic_word(word)
if phonetic is not None:
phonetics.add(phonetic)
ascii = _ascii_word(word)
if ascii is not None:
names.add(ascii)
continue

# entire value as a term
for value in values:
# TODO: not doing addresses at all for the moment.

for name in names:
yield Query.boost_query(
Query.term_query(self.schema, field, value),
BOOSTS.get(field, 1.0),
Query.term_query(self.schema, registry.name.name, name),
BOOSTS.get(registry.name.name, 1.0),
)
yield Query.term_query(self.schema, FIELD_TEXT, name)

for phonetic in phonetics:
yield Query.term_query(self.schema, FIELD_PHONETIC, phonetic)

def entity_query(self, entity: CE) -> Query:
schema_query = Query.term_query(self.schema, "schemata", entity.schema.name)
schema_query = Query.term_query(self.schema, FIELD_SCHEMA, entity.schema.name)
queries: List[Tuple[Occur, Query]] = [(Occur.Must, schema_query)]
if entity.id is not None:
id_query = Query.term_query(self.schema, "entity_id", entity.id)
id_query = Query.term_query(self.schema, FIELD_ID, entity.id)
queries.append((Occur.MustNot, id_query))
for field, value in self.entity_fields(entity):
for query in self.field_queries(field, value):
queries.append((Occur.Should, query))
for query in self.field_queries(entity):
queries.append((Occur.Should, query))
return Query.boolean_query(queries)

def build(self) -> None:
Expand All @@ -205,7 +245,7 @@ def build(self) -> None:
idx += 1
schemata = [s.name for s in entity.schema.matchable_schemata]
document = Document(entity_id=entity.id, schemata=schemata)
for field, values in self.entity_fields(entity):
for field, values in self.index_entity(entity).items():
for value in values:
document.add_text(field, value)
writer.add_document(document)
Expand Down

0 comments on commit 4501f5e

Please sign in to comment.