Skip to content

Commit

Permalink
Put refactoring of matcher back
Browse files Browse the repository at this point in the history
  • Loading branch information
woodthom2 committed Jan 22, 2024
1 parent d353f7b commit dae7fd8
Showing 1 changed file with 95 additions and 95 deletions.
190 changes: 95 additions & 95 deletions src/harmony/matching/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,15 @@
SOFTWARE.
'''

from collections import Counter
from typing import List, Callable

import numpy as np
from numpy import dot, mat, matmul, ndarray
from numpy.linalg import norm

from harmony.matching.negator import negate
from harmony.schemas.requests.text import Instrument
from harmony.schemas.text_vector import TextVector
from numpy import dot, mat, matmul, ndarray
from numpy.linalg import norm


def cosine_similarity(vec1: ndarray, vec2: ndarray) -> ndarray:
Expand All @@ -45,105 +43,44 @@ def cosine_similarity(vec1: ndarray, vec2: ndarray) -> ndarray:
return np.asarray(dp / matmul(m1.T, m2))


def match_instruments_with_function(
instruments: List[Instrument],
query: str,
vectorisation_function: Callable,
mhc_questions: List = [],
mhc_all_metadatas: List = [],
mhc_embeddings: np.ndarray = np.zeros((0, 0)),
texts_cached_vectors: dict[str, List[float]] = {},
) -> tuple:
"""
Match instruments
:param instruments: The instruments
:param query: The query
:param vectorisation_function: A function to vectorize a text
:param mhc_questions
:param mhc_all_metadatas
:param mhc_embeddings
:param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector)
"""

# Create a list of text vectors
all_questions = []
text_vectors: List[TextVector] = []
for instrument in instruments:
for question in instrument.questions:
if question.question_text is None or question.question_text.strip() == "":
continue # skip empty questions

question.instrument_id = instrument.instrument_id
all_questions.append(question)

# Text
question_text = question.question_text
if question_text not in texts_cached_vectors.keys():
text_vectors.append(
TextVector(
text=question_text, vector=[], is_negated=False, is_query=False
)
)
else:
vector = texts_cached_vectors[question_text]
text_vectors.append(
TextVector(
text=question_text,
vector=vector,
is_negated=False,
is_query=False,
)
)

# Negated text
negated_text = negate(question_text, instrument.language)
if negated_text not in texts_cached_vectors.keys():
text_vectors.append(
TextVector(
text=negated_text, vector=[], is_negated=True, is_query=False
)
)
else:
vector = texts_cached_vectors[negated_text]
text_vectors.append(
TextVector(
text=negated_text,
vector=vector,
is_negated=True,
is_query=False,
)
)

# Add query
if query:
if query not in texts_cached_vectors.keys():
text_vectors.append(
TextVector(text=query, vector=[], is_negated=False, is_query=True)
def add_text_to_vec(text, texts_cached_vectors, text_vectors, is_negated_, is_query_):
    """
    Append a TextVector for a text, reusing its cached vector when there is one
    :param text: The text to add
    :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector)
    :param text_vectors: The list of TextVector objects to append to (modified in place)
    :param is_negated_: Value stored in the new TextVector's is_negated field
    :param is_query_: Value stored in the new TextVector's is_query field
    :return: The same text_vectors list, with one new entry appended
    """
    # An empty vector is used for texts that have no entry in the cache.
    vector = texts_cached_vectors.get(text, [])
    text_vectors.append(
        TextVector(
            text=text,
            vector=vector,
            is_negated=is_negated_,
            is_query=is_query_,
        )
    )
    return text_vectors

# Texts with no cached vector
texts_not_cached = [x.text for x in text_vectors if not x.vector]

# Get vectors for all texts not cached
new_vectors_list: List = vectorisation_function(texts_not_cached).tolist()
def process_questions(questions, texts_cached_vectors=None, language='en'):
    """
    Build the list of text vectors for a list of question texts.
    Each question contributes two entries, in this order: the question text
    itself (is_negated=False) and its negated form (is_negated=True).
    :param questions: The question texts
    :param texts_cached_vectors: Optional dictionary of already cached vectors from texts
        (key is the text and value is the vector). When omitted, an empty cache is used.
    :param language: Language code passed to negate(); defaults to 'en'
    :return: The list of TextVector objects
    """
    if texts_cached_vectors is None:
        # A fresh dict per call; never share a mutable default between calls.
        texts_cached_vectors = {}

    text_vectors: List[TextVector] = []
    for question_text in questions:
        # Text
        text_vectors = add_text_to_vec(question_text, texts_cached_vectors, text_vectors, False, False)
        # Negated text
        negated_text = negate(question_text, language)
        text_vectors = add_text_to_vec(negated_text, texts_cached_vectors, text_vectors, True, False)
    return text_vectors

# Create a dictionary with new vectors
new_vectors_dict = {}
for vector, text in zip(new_vectors_list, texts_not_cached):
new_vectors_dict[text] = vector

# Add new vectors to all_texts
def vectorise_texts(text_vectors, vectorisation_function):
    """
    Fill in the vector of every entry that does not have one yet.
    :param text_vectors: The list of text vector objects (modified in place)
    :param vectorisation_function: A function to vectorize a list of texts; its
        result must support .tolist()
    :return: The same text_vectors list
    """
    for text_vector in text_vectors:
        if not text_vector.vector:
            # The function is called with a one-element list, so the first
            # (and only) row of the result is this text's vector.
            text_vector.vector = vectorisation_function([text_vector.text]).tolist()[0]
    return text_vectors


# Create numpy array of texts vectors
def vectors_pos_neg(text_vectors):
vectors_pos = np.array(
[
x.vector
Expand All @@ -160,8 +97,68 @@ def match_instruments_with_function(
if (x.is_negated is True and x.is_query is False)
]
)
return vectors_pos, vectors_neg

# Get query similarity

def create_full_text_vectors(all_questions, query, vectorisation_function, texts_cached_vectors):
    """
    Build the text vectors for all questions (and the query, if any) and
    vectorise the ones that have no vector yet in a single call.
    :param all_questions: The question texts
    :param query: The query (skipped when empty or None)
    :param vectorisation_function: A function to vectorize a list of texts; its
        result must support .tolist()
    :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector)
    :return: A tuple (text_vectors, new_vectors_dict) where new_vectors_dict maps
        each newly vectorised text to its vector
    """
    # Create a list of text vectors
    # NOTE(review): process_questions is called with the questions only, so
    # texts_cached_vectors is consulted just for the query below.
    text_vectors = process_questions(all_questions)

    # Add query
    if query:
        text_vectors = add_text_to_vec(query, texts_cached_vectors, text_vectors, False, True)

    # Texts with no cached vector
    texts_not_cached = [x.text for x in text_vectors if not x.vector]

    # Get vectors for all texts not cached (skip the call when there are none)
    new_vectors_list: List = (
        vectorisation_function(texts_not_cached).tolist() if texts_not_cached else []
    )

    # Create a dictionary with new vectors
    new_vectors_dict = dict(zip(texts_not_cached, new_vectors_list))

    # Add new vectors to all_texts. New vectors are in the same order as the
    # entries that lack one, so a running index replaces the O(n) pop(0).
    next_new = 0
    for text_vector in text_vectors:
        if not text_vector.vector:
            text_vector.vector = new_vectors_list[next_new]
            next_new += 1
    return text_vectors, new_vectors_dict


#
def match_instruments_with_function(
instruments: List[Instrument],
query: str,
vectorisation_function: Callable,
mhc_questions: List = [],
mhc_all_metadatas: List = [],
mhc_embeddings: np.ndarray = np.zeros((0, 0)),
texts_cached_vectors: dict[str, List[float]] = {},
) -> tuple:
"""
Match instruments
:param instruments: The instruments
:param query: The query
:param vectorisation_function: A function to vectorize a text
:param mhc_questions
:param mhc_all_metadatas
:param mhc_embeddings
:param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector)
"""
all_questions = []
for instrument in instruments:
all_questions.extend(instrument.questions)
# all_questions: List[Question] = all_questions
all_questions_str: List[str] = [q.question_text for q in all_questions]
# all_questions = [instrument["question_text"] for instrument in instruments]

text_vectors, new_vectors_dict = create_full_text_vectors(all_questions_str, query, vectorisation_function,
texts_cached_vectors)
vectors_pos, vectors_neg = vectors_pos_neg(text_vectors)

# Get similarity between the query (only one query?) and the questions
if vectors_pos.any() and query:
vector_query = np.array(
[[x for x in text_vectors if x.is_query is True][0].vector]
Expand Down Expand Up @@ -219,5 +216,8 @@ def match_instruments_with_function(

for question in all_questions:
question.topics_auto = instrument_to_category[question.instrument_id]
else:
for question in all_questions:
question.topics_auto = []

return all_questions, similarity_with_polarity, query_similarity, new_vectors_dict

0 comments on commit dae7fd8

Please sign in to comment.