diff --git a/src/harmony/matching/matcher.py b/src/harmony/matching/matcher.py index b838752..0da7436 100644 --- a/src/harmony/matching/matcher.py +++ b/src/harmony/matching/matcher.py @@ -24,17 +24,15 @@ SOFTWARE. ''' - from collections import Counter from typing import List, Callable import numpy as np -from numpy import dot, mat, matmul, ndarray -from numpy.linalg import norm - from harmony.matching.negator import negate from harmony.schemas.requests.text import Instrument from harmony.schemas.text_vector import TextVector +from numpy import dot, mat, matmul, ndarray +from numpy.linalg import norm def cosine_similarity(vec1: ndarray, vec2: ndarray) -> ndarray: @@ -45,105 +43,44 @@ def cosine_similarity(vec1: ndarray, vec2: ndarray) -> ndarray: return np.asarray(dp / matmul(m1.T, m2)) -def match_instruments_with_function( - instruments: List[Instrument], - query: str, - vectorisation_function: Callable, - mhc_questions: List = [], - mhc_all_metadatas: List = [], - mhc_embeddings: np.ndarray = np.zeros((0, 0)), - texts_cached_vectors: dict[str, List[float]] = {}, -) -> tuple: - """ - Match instruments - - :param instruments: The instruments - :param query: The query - :param vectorisation_function: A function to vectorize a text - :param mhc_questions - :param mhc_all_metadatas - :param mhc_embeddings - :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector) - """ - - # Create a list of text vectors - all_questions = [] - text_vectors: List[TextVector] = [] - for instrument in instruments: - for question in instrument.questions: - if question.question_text is None or question.question_text.strip() == "": - continue # skip empty questions - - question.instrument_id = instrument.instrument_id - all_questions.append(question) - - # Text - question_text = question.question_text - if question_text not in texts_cached_vectors.keys(): - text_vectors.append( - TextVector( - text=question_text, vector=[], is_negated=False, is_query=False - ) - ) - else: - vector = texts_cached_vectors[question_text] - text_vectors.append( - TextVector( - text=question_text, - vector=vector, - is_negated=False, - is_query=False, - ) - ) - - # Negated text - negated_text = negate(question_text, instrument.language) - if negated_text not in texts_cached_vectors.keys(): - text_vectors.append( - TextVector( - text=negated_text, vector=[], is_negated=True, is_query=False - ) - ) - else: - vector = texts_cached_vectors[negated_text] - text_vectors.append( - TextVector( - text=negated_text, - vector=vector, - is_negated=True, - is_query=False, - ) - ) - - # Add query - if query: - if query not in texts_cached_vectors.keys(): - text_vectors.append( - TextVector(text=query, vector=[], is_negated=False, is_query=True) +def add_text_to_vec(text, texts_cached_vectors, text_vectors, is_negated_, is_query_): + if text not in texts_cached_vectors.keys(): + text_vectors.append( + TextVector( + text=text, vector=[], is_negated=is_negated_, is_query=is_query_ ) - else: - vector = texts_cached_vectors[query] - text_vectors.append( - TextVector(text=query, vector=vector, is_negated=False, is_query=True) + ) + else: + vector = texts_cached_vectors[text] + text_vectors.append( + TextVector( + text=text, + vector=vector, + is_negated=is_negated_, + is_query=is_query_, ) + ) + return text_vectors - # Texts with no cached vector - texts_not_cached = [x.text for x in text_vectors if not x.vector] - # Get vectors for all texts not cached - new_vectors_list: List = vectorisation_function(texts_not_cached).tolist() +def process_questions(questions): + texts_cached_vectors: dict[str, List[float]] = {} + text_vectors: List[TextVector] = [] + for question_text in questions: + text_vectors = add_text_to_vec(question_text, texts_cached_vectors, text_vectors, False, False) + negated_text = negate(question_text, 'en') + text_vectors = add_text_to_vec(negated_text, texts_cached_vectors, text_vectors, True, False) + return text_vectors - # Create a dictionary with new vectors - new_vectors_dict = {} - for vector, text in zip(new_vectors_list, texts_not_cached): - new_vectors_dict[text] = vector - # Add new vectors to all_texts +def vectorise_texts(text_vectors, vectorisation_function): for index, text_dict in enumerate(text_vectors): if not text_dict.vector: - text_vectors[index].vector = new_vectors_list.pop(0) + text_vectors[index].vector = vectorisation_function([text_dict.text]).tolist()[0] + return text_vectors + - # Create numpy array of texts vectors +def vectors_pos_neg(text_vectors): vectors_pos = np.array( [ x.vector @@ -160,8 +97,68 @@ def match_instruments_with_function( if (x.is_negated is True and x.is_query is False) ] ) + return vectors_pos, vectors_neg - # Get query similarity + +def create_full_text_vectors(all_questions, query, vectorisation_function, texts_cached_vectors): + # Create a list of text vectors + text_vectors = process_questions(all_questions) + + # Add query + if query: + text_vectors = add_text_to_vec(query, texts_cached_vectors, text_vectors, False, True) + + # Texts with no cached vector + texts_not_cached = [x.text for x in text_vectors if not x.vector] + + # Get vectors for all texts not cached + new_vectors_list: List = vectorisation_function(texts_not_cached).tolist() + + # Create a dictionary with new vectors + new_vectors_dict = {} + for vector, text in zip(new_vectors_list, texts_not_cached): + new_vectors_dict[text] = vector + + # Add new vectors to all_texts + for index, text_dict in enumerate(text_vectors): + if not text_dict.vector: + text_vectors[index].vector = new_vectors_list.pop(0) + return text_vectors, new_vectors_dict + + +# +def match_instruments_with_function( + instruments: List[Instrument], + query: str, + vectorisation_function: Callable, + mhc_questions: List = [], + mhc_all_metadatas: List = [], + mhc_embeddings: np.ndarray = np.zeros((0, 0)), + texts_cached_vectors: dict[str, List[float]] = {}, +) -> tuple: + """ + Match instruments + + :param instruments: The instruments + :param query: The query + :param vectorisation_function: A function to vectorize a text + :param mhc_questions + :param mhc_all_metadatas + :param mhc_embeddings + :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector) + """ + all_questions = [] + for instrument in instruments: + all_questions.extend(instrument.questions) + # all_questions: List[Question] = all_questions + all_questions_str: List[str] = [q.question_text for q in all_questions] + # all_questions = [instrument["question_text"] for instrument in instruments] + + text_vectors, new_vectors_dict = create_full_text_vectors(all_questions_str, query, vectorisation_function, + texts_cached_vectors) + vectors_pos, vectors_neg = vectors_pos_neg(text_vectors) + + # Get similarity between the query (only one query?) and the questions if vectors_pos.any() and query: vector_query = np.array( [[x for x in text_vectors if x.is_query is True][0].vector] @@ -219,5 +216,8 @@ def match_instruments_with_function( for question in all_questions: question.topics_auto = instrument_to_category[question.instrument_id] + else: + for question in all_questions: + question.topics_auto = [] return all_questions, similarity_with_polarity, query_similarity, new_vectors_dict