From af8447f6a98e4d864c4dc889c6af0d5e006395a4 Mon Sep 17 00:00:00 2001 From: Zairon Jacobs Date: Thu, 20 Jun 2024 08:50:53 +0200 Subject: [PATCH 1/7] add catalogue match in matcher.py --- src/harmony/matching/matcher.py | 312 +++++++++++++++++++- src/harmony/schemas/catalogue_instrument.py | 9 + src/harmony/schemas/catalogue_question.py | 10 + src/harmony/schemas/requests/text.py | 79 ++++- src/harmony/schemas/responses/text.py | 17 +- 5 files changed, 413 insertions(+), 14 deletions(-) create mode 100644 src/harmony/schemas/catalogue_instrument.py create mode 100644 src/harmony/schemas/catalogue_question.py diff --git a/src/harmony/matching/matcher.py b/src/harmony/matching/matcher.py index ebb9a3b..8a1b0ec 100644 --- a/src/harmony/matching/matcher.py +++ b/src/harmony/matching/matcher.py @@ -1,4 +1,4 @@ -''' +""" MIT License Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk). @@ -22,18 +22,25 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +""" -''' +import statistics from collections import Counter from typing import List, Callable import numpy as np -from harmony.matching.negator import negate -from harmony.schemas.requests.text import Instrument -from harmony.schemas.text_vector import TextVector from numpy import dot, mat, matmul, ndarray from numpy.linalg import norm +from harmony.matching.negator import negate +from harmony.schemas.catalogue_instrument import CatalogueInstrument +from harmony.schemas.catalogue_question import CatalogueQuestion +from harmony.schemas.requests.text import ( + Instrument, + Question, +) +from harmony.schemas.text_vector import TextVector + def cosine_similarity(vec1: ndarray, vec2: ndarray) -> ndarray: dp = dot(vec1, vec2.T) @@ -99,7 +106,16 @@ def vectors_pos_neg(text_vectors): return vectors_pos, vectors_neg -def create_full_text_vectors(all_questions, query, vectorisation_function, texts_cached_vectors): +def create_full_text_vectors( + all_questions: List[str], + query: str | None, + vectorisation_function: Callable, + texts_cached_vectors: dict[str, list[float]], +) -> tuple[List[TextVector], dict]: + """ + Create full text vectors. + """ + # Create a list of text vectors text_vectors = process_questions(all_questions, texts_cached_vectors) @@ -122,9 +138,288 @@ def create_full_text_vectors(all_questions, query, vectorisation_function, texts for index, text_dict in enumerate(text_vectors): if not text_dict.vector: text_vectors[index].vector = new_vectors_list.pop(0) + return text_vectors, new_vectors_dict +def match_instruments_with_catalogue_instruments( + instruments: List[Instrument], + catalogue_data: dict, + vectorisation_function: Callable, + texts_cached_vectors: dict[str, List[float]], +) -> tuple[List[Instrument], CatalogueInstrument]: + """ + Match instruments with catalogue instruments. + + :param instruments: The instruments. + :param catalogue_data: The catalogue data. + :param vectorisation_function: A function to vectorize a text. + :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector). + + :return: A tuple, index 0 contains the list of instruments that now each contain the best instrument match from the + catalog, and index 1 contains the closest instrument match from the catalog for all the instruments. + """ + + # For each instrument, find the best matching instrument for it in the catalogue + for instrument in instruments: + instrument.closest_catalogue_instrument_match = ( + match_questions_with_catalogue_instruments( + questions=instrument.questions, + catalogue_data=catalogue_data, + vectorisation_function=vectorisation_function, + texts_cached_vectors=texts_cached_vectors, + questions_are_from_one_instrument=True, + ) + ) + + # Gather all questions from all instruments and find the best matching instrument in the catalogue + all_instrument_questions: List[Question] = [] + for instrument in instruments: + all_instrument_questions.extend(instrument.questions) + closest_catalogue_instrument_match = match_questions_with_catalogue_instruments( + questions=all_instrument_questions, + catalogue_data=catalogue_data, + vectorisation_function=vectorisation_function, + texts_cached_vectors=texts_cached_vectors, + questions_are_from_one_instrument=False, + ) + + return instruments, closest_catalogue_instrument_match + + +def match_questions_with_catalogue_instruments( + questions: List[Question], + catalogue_data: dict, + vectorisation_function: Callable, + texts_cached_vectors: dict[str, List[float]], + questions_are_from_one_instrument: bool, +) -> CatalogueInstrument: + """ + Match questions with catalogue instruments. + Each question will receive a list of closest instrument matches, and at the end one closest instrument match for + all questions is returned. + + :param questions: The questions. + :param catalogue_data: The catalogue data. + :param vectorisation_function: A function to vectorize a text. + :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector). + :param questions_are_from_one_instrument: If the questions provided are coming from one instrument only. + + :return: The closest instrument match for the questions provided. + """ + + # Catalogue data + catalogue_instrument_idx_to_catalogue_questions_idx = catalogue_data[ + "instrument_idx_to_question_idx" + ] + all_catalogue_questions_embeddings_concatenated = catalogue_data[ + "all_embeddings_concatenated" + ] + all_catalogue_instruments = catalogue_data["all_instruments"] + all_catalogue_questions = catalogue_data["all_questions"] + + # Create text vectors + text_vectors, new_vectors_dict = create_full_text_vectors( + all_questions=[q.question_text for q in questions], + query=None, + vectorisation_function=vectorisation_function, + texts_cached_vectors=texts_cached_vectors, + ) + + # The total number of questions we received as input. + num_input_questions = len(questions) + + # Get an array of dimensions. + # (number of input questions) x (number of dimensions of LLM - typically 768, 384, 500, 512, etc.) + text_vectors_dict = { + text_vector.text: text_vector.vector for text_vector in text_vectors + } + vectors = np.array( + [text_vectors_dict[question.question_text] for question in questions] + ) + + # Get a 2D array of (number of input questions) x (number of questions in catalogue). + # E.g. index 0 (matches for the first input question) will contain a list of matches for each question in the + # catalogue. So the best match for the first input question is the highest similarity found in index 0. + catalogue_similarities = cosine_similarity( + vectors, all_catalogue_questions_embeddings_concatenated + ) + + # Get a 1D array of length (number of input questions). + # For each input question, this is the index of the single closest matching question text in our catalogues. + # Note that each question text in the catalogue (vector index) is unique, and we must later do a further mapping to + # find out which instrument(s) it occurs in. + idxs_top_input_questions_matches = np.argmax(catalogue_similarities, axis=1) + + # Get a set of all the top matching question text indices in our catalogue. + # idxs_top_input_questions_matches_set = set(idxs_top_input_questions_matches) + + # This keeps track of each instrument matches how many question items in the query + # e.g. if the first instrument in our catalogue (instrument 0) matches 4 items, then this dictionary will + # contain {0: 4}. + # instrument_idx_to_num_matching_items_with_query = {} + + # This dictionary will contain the index of the instrument and the cosine similarities to the top matched questions + # in that instrument e.g. {50: [ ... ]} + instrument_idx_to_cosine_similarities_top_match: dict[int, []] = {} + + # This keeps track of how many question items in total are contained in each instrument, irrespective of the + # number of matches. + # This is needed for stats such as precision and recall. + instrument_idx_to_total_num_question_items_present = {} + + # Find any instruments matching + input_question_idx_to_matching_instruments: List[List[dict]] = [] + for input_question_idx in range(len(questions)): + input_question_idx_to_matching_instruments.append([]) + for input_question_idx in range(len(questions)): + top_match_catalogue_question_idx = idxs_top_input_questions_matches[ + input_question_idx + ] + for instrument_idx, question_idxs_in_this_instrument in enumerate( + catalogue_instrument_idx_to_catalogue_questions_idx + ): + if top_match_catalogue_question_idx in question_idxs_in_this_instrument: + instrument_from_catalogue = all_catalogue_instruments[instrument_idx] + if not any( + x["instrument_name"] == instrument_from_catalogue["instrument_name"] + for x in input_question_idx_to_matching_instruments[input_question_idx] + ): + input_question_idx_to_matching_instruments[ + input_question_idx + ].append(instrument_from_catalogue) + + # For each catalogue instrument get the total number of question matches in the query + # For each catalogue instrument get the total number of questions + for instrument_idx, question_idxs_in_this_instrument in enumerate( + catalogue_instrument_idx_to_catalogue_questions_idx + ): + catalogue_question_idxs_in_this_instrument_set = set( + question_idxs_in_this_instrument + ) + # instrument_idx_to_num_matching_items_with_query[instrument_idx] = len( + # catalogue_question_idxs_in_this_instrument_set.intersection( + # idxs_top_input_questions_matches_set + # ) + # ) + instrument_idx_to_total_num_question_items_present[instrument_idx] = len( + catalogue_question_idxs_in_this_instrument_set + ) + + # Question similarity with catalogue questions + for idx, question in enumerate(questions): + + seen_in_instruments: List[CatalogueInstrument] = [] + for instrument in input_question_idx_to_matching_instruments[idx]: + instrument_name = instrument["instrument_name"] + instrument_url = instrument["metadata"].get("url", "") + source = instrument["metadata"]["source"].upper() + sweep = instrument["metadata"].get("sweep_id", "") + seen_in_instruments.append( + CatalogueInstrument( + instrument_name=instrument_name, + instrument_url=instrument_url, + source=source, + sweep=sweep, + ) + ) + + question.closest_catalogue_question_match = CatalogueQuestion( + question=all_catalogue_questions[idxs_top_input_questions_matches[idx]], + seen_in_instruments=seen_in_instruments, + ) + + # Instrument index to list of cosine similarities top question match + for input_question_idx, idx_top_input_question_match_in_catalogue in enumerate( + idxs_top_input_questions_matches + ): + for ( + catalogue_instrument_idx, + catalogue_question_idxs_in_this_instrument, + ) in enumerate(catalogue_instrument_idx_to_catalogue_questions_idx): + catalogue_question_idxs_set = set( + catalogue_question_idxs_in_this_instrument + ) + if idx_top_input_question_match_in_catalogue in catalogue_question_idxs_set: + # Create the list if it doesn't exist yet + if not instrument_idx_to_cosine_similarities_top_match.get( + catalogue_instrument_idx + ): + instrument_idx_to_cosine_similarities_top_match[ + catalogue_instrument_idx + ] = [] + + # Add the cosine similarity + instrument_idx_to_cosine_similarities_top_match[ + catalogue_instrument_idx + ].append( + catalogue_similarities[input_question_idx][ + idx_top_input_question_match_in_catalogue + ] + ) + + # Keep track of the instrument id and the count of top question matches that belong to it + instrument_idx_to_top_matches_ct = { + k: len(v) for k, v in instrument_idx_to_cosine_similarities_top_match.items() + } + + # Calculate the average for each list of cosine similarities from instruments + for ( + instrument_idx, + cosine_similarities, + ) in instrument_idx_to_cosine_similarities_top_match.items(): + instrument_idx_to_cosine_similarities_top_match[instrument_idx] = ( + statistics.mean(cosine_similarities) + ) + + # Find the index of the best instrument match + best_catalogue_instrument_idx = max( + instrument_idx_to_cosine_similarities_top_match, + key=instrument_idx_to_cosine_similarities_top_match.get, + ) + + # Get the best instrument match + best_catalogue_instrument = all_catalogue_instruments[best_catalogue_instrument_idx] + num_questions_in_ref_instrument = ( + instrument_idx_to_total_num_question_items_present[ + best_catalogue_instrument_idx + ] + ) + num_top_match_questions = instrument_idx_to_top_matches_ct[ + best_catalogue_instrument_idx + ] + + instrument_name = best_catalogue_instrument["instrument_name"] + instrument_url = best_catalogue_instrument["metadata"].get("url", "") + source = best_catalogue_instrument["metadata"]["source"].upper() + sweep = best_catalogue_instrument["metadata"].get("sweep_id", "") + + if questions_are_from_one_instrument: + info = ( + f"{instrument_name} Sweep {sweep if sweep else 'UNKNOWN'} matched {num_top_match_questions} " + f"question(s) in your instrument, your instrument contains {num_input_questions} question(s). " + f"The reference instrument contains {num_questions_in_ref_instrument} question(s)." + ) + else: + info = ( + f"{instrument_name} Sweep {sweep if sweep else 'UNKNOWN'} matched {num_top_match_questions} " + f"question(s) in all of your instruments, your instruments contains {num_input_questions} " + f"question(s). The reference instrument contains {num_questions_in_ref_instrument} question(s)." + ) + + return CatalogueInstrument( + instrument_name=instrument_name, + instrument_url=instrument_url, + source=source, + sweep=sweep, + metadata={ + "info": info, + "num_matched_questions": num_top_match_questions, + "num_ref_instrument_questions": num_questions_in_ref_instrument, + }, + ) + + # def match_instruments_with_function( instruments: List[Instrument], @@ -136,7 +431,7 @@ def match_instruments_with_function( texts_cached_vectors: dict[str, List[float]] = {}, ) -> tuple: """ - Match instruments + Match instruments. :param instruments: The instruments :param query: The query @@ -146,6 +441,7 @@ def match_instruments_with_function( :param mhc_embeddings :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector) """ + all_questions = [] for instrument in instruments: all_questions.extend(instrument.questions) @@ -155,7 +451,7 @@ def match_instruments_with_function( text_vectors, new_vectors_dict = create_full_text_vectors(all_questions_str, query, vectorisation_function, texts_cached_vectors) - # get vectors for all orignal texts and vectors for negated texts + # get vectors for all original texts and vectors for negated texts vectors_pos, vectors_neg = vectors_pos_neg(text_vectors) # Get similarity between the query (only one query?) and the questions diff --git a/src/harmony/schemas/catalogue_instrument.py b/src/harmony/schemas/catalogue_instrument.py new file mode 100644 index 0000000..505599c --- /dev/null +++ b/src/harmony/schemas/catalogue_instrument.py @@ -0,0 +1,9 @@ +from pydantic import BaseModel, Field + + +class CatalogueInstrument(BaseModel): + instrument_name: str = Field(description="Instrument name") + instrument_url: str = Field(description="Instrument URL") + source: str = Field(description="Source") + sweep: str = Field(description="Sweep") + metadata: dict = Field(default=None, description="Metadata") diff --git a/src/harmony/schemas/catalogue_question.py b/src/harmony/schemas/catalogue_question.py new file mode 100644 index 0000000..a18721a --- /dev/null +++ b/src/harmony/schemas/catalogue_question.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel, Field + +from harmony.schemas.catalogue_instrument import CatalogueInstrument + + +class CatalogueQuestion(BaseModel): + question: str = Field(description="The catalogue question") + seen_in_instruments: list[CatalogueInstrument] = Field( + description="The instruments from the catalogue were the question was seen in" + ) diff --git a/src/harmony/schemas/requests/text.py b/src/harmony/schemas/requests/text.py index 26b01eb..3d54c02 100644 --- a/src/harmony/schemas/requests/text.py +++ b/src/harmony/schemas/requests/text.py @@ -29,6 +29,8 @@ from pydantic import BaseModel, Field +from harmony.schemas.catalogue_instrument import CatalogueInstrument +from harmony.schemas.catalogue_question import CatalogueQuestion from harmony.schemas.enums.file_types import FileType from harmony.schemas.enums.languages import Language @@ -64,8 +66,13 @@ class Question(BaseModel): instrument_id: str = Field(None, description="Unique identifier for the instrument (UUID-4)") instrument_name: str = Field(None, description="Human readable name for the instrument") topics_auto: list = Field(None, description="Automated list of topics identified by model") - topics_strengths: dict = Field(None, description="Automated list of topics identified by model with strength of topic") + topics_strengths: dict = Field( + None, description="Automated list of topics identified by model with strength of topic" + ) nearest_match_from_mhc_auto: dict = Field(None, description="Automatically identified nearest MHC match") + closest_catalogue_question_match: CatalogueQuestion = Field( + None, description="The closest question match in the catalogue for the question" + ) class Config: schema_extra = { @@ -92,7 +99,10 @@ class Instrument(BaseModel): description="Optional metadata about the instrument (URL, citation, DOI, copyright holder)") language: Language = Field(Language.English, description="The ISO 639-2 (alpha-2) encoding of the instrument language") - questions: List[Question] = Field(description="the items inside the instrument") + questions: List[Question] = Field(description="The items inside the instrument") + closest_catalogue_instrument_match: CatalogueInstrument = Field( + None, description="The closest instrument match in the catalogue for the instrument" + ) class Config: schema_extra = { @@ -198,3 +208,68 @@ class Config: "model": DEFAULT_MODEL} } } + + +class MatchCatalogueBody(BaseModel): + instruments: List[Instrument] = Field(description="Instruments to match") + parameters: MatchParameters = Field(DEFAULT_MATCH_PARAMETERS, description="Parameters on how to match") + + class Config: + schema_extra = { + "example": { + "instruments": [{ + "file_id": "fd60a9a64b1b4078a68f4bc06f20253c", + "instrument_id": "7829ba96f48e4848abd97884911b6795", + "instrument_name": "GAD-7 English", + "file_name": "GAD-7 EN.pdf", + "file_type": "pdf", + "file_section": "GAD-7 English", + "language": "en", + "questions": [{"question_no": "1", + "question_intro": "Over the last two weeks, how often have you been bothered by the following problems?", + "question_text": "Feeling nervous, anxious, or on edge", + "options": ["Not at all", "Several days", "More than half the days", + "Nearly every day"], + "source_page": 0 + }, + {"question_no": "2", + "question_intro": "Over the last two weeks, how often have you been bothered by the following problems?", + "question_text": "Not being able to stop or control worrying", + "options": ["Not at all", "Several days", "More than half the days", + "Nearly every day"], + "source_page": 0 + } + + ] + }, + { + "file_id": "fd60a9a64b1b4078a68f4bc06f20253c", + "instrument_id": "7829ba96f48e4848abd97884911b6795", + "instrument_name": "GAD-7 Portuguese", + "file_name": "GAD-7 PT.pdf", + "file_type": "pdf", + "file_section": "GAD-7 Portuguese", + "language": "en", + "questions": [{"question_no": "1", + "question_intro": "Durante as últimas 2 semanas, com que freqüência você foi incomodado/a pelos problemas abaixo?", + "question_text": "Sentir-se nervoso/a, ansioso/a ou muito tenso/a", + "options": ["Nenhuma vez", "Vários dias", "Mais da metade dos dias", + "Quase todos os dias"], + "source_page": 0 + }, + {"question_no": "2", + "question_intro": "Durante as últimas 2 semanas, com que freqüência você foi incomodado/a pelos problemas abaixo?", + "question_text": " Não ser capaz de impedir ou de controlar as preocupações", + "options": ["Nenhuma vez", "Vários dias", "Mais da metade dos dias", + "Quase todos os dias"], + "source_page": 0 + } + + ] + } + + ], + "parameters": {"framework": DEFAULT_FRAMEWORK, + "model": DEFAULT_MODEL} + } + } \ No newline at end of file diff --git a/src/harmony/schemas/responses/text.py b/src/harmony/schemas/responses/text.py index 802e8f5..01ff9a7 100644 --- a/src/harmony/schemas/responses/text.py +++ b/src/harmony/schemas/responses/text.py @@ -1,4 +1,4 @@ -''' +""" MIT License Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk). @@ -22,14 +22,15 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -''' +""" from typing import List +from pydantic import BaseModel, Field + +from harmony.schemas.catalogue_instrument import CatalogueInstrument from harmony.schemas.requests.text import Instrument from harmony.schemas.requests.text import Question -from pydantic import BaseModel, Field class MatchResponse(BaseModel): @@ -42,6 +43,14 @@ class MatchResponse(BaseModel): ) +class MatchCatalogueResponse(BaseModel): + instruments: List[Instrument] = Field(description="A list of instruments") + closest_catalogue_instrument_match: CatalogueInstrument = Field( + default=None, + description="The closest catalogue instrument match" + ) + + class InstrumentList(BaseModel): __root__: List[Instrument] From 492208a8769f79aabcb6fcc2c7cfaa803029d817 Mon Sep 17 00:00:00 2001 From: Zairon Jacobs Date: Fri, 21 Jun 2024 08:15:52 +0200 Subject: [PATCH 2/7] check if embeddings array is empty in match catalogue func --- src/harmony/matching/matcher.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/harmony/matching/matcher.py b/src/harmony/matching/matcher.py index 8a1b0ec..1697adf 100644 --- a/src/harmony/matching/matcher.py +++ b/src/harmony/matching/matcher.py @@ -147,7 +147,7 @@ def match_instruments_with_catalogue_instruments( catalogue_data: dict, vectorisation_function: Callable, texts_cached_vectors: dict[str, List[float]], -) -> tuple[List[Instrument], CatalogueInstrument]: +) -> tuple[List[Instrument], CatalogueInstrument | None]: """ Match instruments with catalogue instruments. @@ -193,7 +193,7 @@ def match_questions_with_catalogue_instruments( vectorisation_function: Callable, texts_cached_vectors: dict[str, List[float]], questions_are_from_one_instrument: bool, -) -> CatalogueInstrument: +) -> CatalogueInstrument | None: """ Match questions with catalogue instruments. Each question will receive a list of closest instrument matches, and at the end one closest instrument match for @@ -209,14 +209,18 @@ def match_questions_with_catalogue_instruments( """ # Catalogue data - catalogue_instrument_idx_to_catalogue_questions_idx = catalogue_data[ + catalogue_instrument_idx_to_catalogue_questions_idx: List[List[int]] = catalogue_data[ "instrument_idx_to_question_idx" ] - all_catalogue_questions_embeddings_concatenated = catalogue_data[ + all_catalogue_questions_embeddings_concatenated: np.ndarray = catalogue_data[ "all_embeddings_concatenated" ] - all_catalogue_instruments = catalogue_data["all_instruments"] - all_catalogue_questions = catalogue_data["all_questions"] + all_catalogue_instruments: List[dict] = catalogue_data["all_instruments"] + all_catalogue_questions: List[str] = catalogue_data["all_questions"] + + # No embeddings = nothing to find + if len(all_catalogue_questions_embeddings_concatenated) == 0: + return None # Create text vectors text_vectors, new_vectors_dict = create_full_text_vectors( @@ -308,7 +312,6 @@ def match_questions_with_catalogue_instruments( # Question similarity with catalogue questions for idx, question in enumerate(questions): - seen_in_instruments: List[CatalogueInstrument] = [] for instrument in input_question_idx_to_matching_instruments[idx]: instrument_name = instrument["instrument_name"] From 60b3bbcc2398ad0811fa9623a54818b8d2d6a7fe Mon Sep 17 00:00:00 2001 From: Zairon Jacobs Date: Fri, 21 Jun 2024 08:16:37 +0200 Subject: [PATCH 3/7] allow passing sources list to match catalogue body --- src/harmony/schemas/requests/text.py | 15 ++++++++++----- src/harmony/schemas/responses/text.py | 2 +- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/harmony/schemas/requests/text.py b/src/harmony/schemas/requests/text.py index 3d54c02..40f7501 100644 --- a/src/harmony/schemas/requests/text.py +++ b/src/harmony/schemas/requests/text.py @@ -100,7 +100,7 @@ class Instrument(BaseModel): language: Language = Field(Language.English, description="The ISO 639-2 (alpha-2) encoding of the instrument language") questions: List[Question] = Field(description="The items inside the instrument") - closest_catalogue_instrument_match: CatalogueInstrument = Field( + closest_catalogue_instrument_match: CatalogueInstrument | None = Field( None, description="The closest instrument match in the catalogue for the instrument" ) @@ -211,8 +211,12 @@ class Config: class MatchCatalogueBody(BaseModel): - instruments: List[Instrument] = Field(description="Instruments to match") - parameters: MatchParameters = Field(DEFAULT_MATCH_PARAMETERS, description="Parameters on how to match") + instruments: List[Instrument] = Field(description="Instruments to match.") + parameters: MatchParameters = Field(DEFAULT_MATCH_PARAMETERS, description="Parameters on how to match.") + sources: List[str] = Field( + default=[], + description="The instrument sources to use for matching. If empty, all instrument sources will be considered." + ) class Config: schema_extra = { @@ -270,6 +274,7 @@ class Config: ], "parameters": {"framework": DEFAULT_FRAMEWORK, - "model": DEFAULT_MODEL} + "model": DEFAULT_MODEL}, + "sources": ["MHC"] } - } \ No newline at end of file + } diff --git a/src/harmony/schemas/responses/text.py b/src/harmony/schemas/responses/text.py index 01ff9a7..d058b69 100644 --- a/src/harmony/schemas/responses/text.py +++ b/src/harmony/schemas/responses/text.py @@ -45,7 +45,7 @@ class MatchResponse(BaseModel): class MatchCatalogueResponse(BaseModel): instruments: List[Instrument] = Field(description="A list of instruments") - closest_catalogue_instrument_match: CatalogueInstrument = Field( + closest_catalogue_instrument_match: CatalogueInstrument | None = Field( default=None, description="The closest catalogue instrument match" ) From 21785840c4ceac69e032cd2bd051df99acfca815 Mon Sep 17 00:00:00 2001 From: Zairon Jacobs Date: Fri, 21 Jun 2024 10:23:21 +0200 Subject: [PATCH 4/7] return up to 10 top instr matches in match catalogue func --- src/harmony/matching/matcher.py | 129 ++++++++++++++------------ src/harmony/schemas/requests/text.py | 6 +- src/harmony/schemas/responses/text.py | 5 +- 3 files changed, 75 insertions(+), 65 deletions(-) diff --git a/src/harmony/matching/matcher.py b/src/harmony/matching/matcher.py index 1697adf..b05b887 100644 --- a/src/harmony/matching/matcher.py +++ b/src/harmony/matching/matcher.py @@ -147,7 +147,7 @@ def match_instruments_with_catalogue_instruments( catalogue_data: dict, vectorisation_function: Callable, texts_cached_vectors: dict[str, List[float]], -) -> tuple[List[Instrument], CatalogueInstrument | None]: +) -> tuple[List[Instrument], List[CatalogueInstrument]]: """ Match instruments with catalogue instruments. @@ -156,13 +156,14 @@ def match_instruments_with_catalogue_instruments( :param vectorisation_function: A function to vectorize a text. :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector). - :return: A tuple, index 0 contains the list of instruments that now each contain the best instrument match from the - catalog, and index 1 contains the closest instrument match from the catalog for all the instruments. + :return: Index 0 in the tuple contains the list of instruments that now each contain the best instrument matches + from the catalog. Index 1 in the tuple contains a list of closest instrument matches from the catalog for all the + instruments. """ - # For each instrument, find the best matching instrument for it in the catalogue + # For each instrument, find the best instrument matches for it in the catalogue for instrument in instruments: - instrument.closest_catalogue_instrument_match = ( + instrument.closest_catalogue_instrument_matches = ( match_questions_with_catalogue_instruments( questions=instrument.questions, catalogue_data=catalogue_data, @@ -172,11 +173,11 @@ def match_instruments_with_catalogue_instruments( ) ) - # Gather all questions from all instruments and find the best matching instrument in the catalogue + # Gather all questions from all instruments and find the best instrument matches in the catalogue all_instrument_questions: List[Question] = [] for instrument in instruments: all_instrument_questions.extend(instrument.questions) - closest_catalogue_instrument_match = match_questions_with_catalogue_instruments( + closest_catalogue_instrument_matches = match_questions_with_catalogue_instruments( questions=all_instrument_questions, catalogue_data=catalogue_data, vectorisation_function=vectorisation_function, @@ -184,7 +185,7 @@ def match_instruments_with_catalogue_instruments( questions_are_from_one_instrument=False, ) - return instruments, closest_catalogue_instrument_match + return instruments, closest_catalogue_instrument_matches def match_questions_with_catalogue_instruments( @@ -193,7 +194,7 @@ def match_questions_with_catalogue_instruments( vectorisation_function: Callable, texts_cached_vectors: dict[str, List[float]], questions_are_from_one_instrument: bool, -) -> CatalogueInstrument | None: +) -> List[CatalogueInstrument]: """ Match questions with catalogue instruments. Each question will receive a list of closest instrument matches, and at the end one closest instrument match for @@ -205,7 +206,7 @@ def match_questions_with_catalogue_instruments( :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector). :param questions_are_from_one_instrument: If the questions provided are coming from one instrument only. - :return: The closest instrument match for the questions provided. + :return: A list of closest instrument matches for the questions provided. """ # Catalogue data @@ -220,7 +221,7 @@ def match_questions_with_catalogue_instruments( # No embeddings = nothing to find if len(all_catalogue_questions_embeddings_concatenated) == 0: - return None + return [] # Create text vectors text_vectors, new_vectors_dict = create_full_text_vectors( @@ -253,10 +254,10 @@ def match_questions_with_catalogue_instruments( # For each input question, this is the index of the single closest matching question text in our catalogues. # Note that each question text in the catalogue (vector index) is unique, and we must later do a further mapping to # find out which instrument(s) it occurs in. - idxs_top_input_questions_matches = np.argmax(catalogue_similarities, axis=1) + idxs_of_top_questions_matched_in_catalogue = np.argmax(catalogue_similarities, axis=1) # Get a set of all the top matching question text indices in our catalogue. - # idxs_top_input_questions_matches_set = set(idxs_top_input_questions_matches) + idxs_of_top_questions_matched_in_catalogue_set = set(idxs_of_top_questions_matched_in_catalogue) # This keeps track of each instrument matches how many question items in the query # e.g. if the first instrument in our catalogue (instrument 0) matches 4 items, then this dictionary will @@ -265,7 +266,7 @@ def match_questions_with_catalogue_instruments( # This dictionary will contain the index of the instrument and the cosine similarities to the top matched questions # in that instrument e.g. {50: [ ... ]} - instrument_idx_to_cosine_similarities_top_match: dict[int, []] = {} + instrument_idx_to_cosine_similarities_top_match: dict[int, [float]] = {} # This keeps track of how many question items in total are contained in each instrument, irrespective of the # number of matches. @@ -277,7 +278,7 @@ def match_questions_with_catalogue_instruments( for input_question_idx in range(len(questions)): input_question_idx_to_matching_instruments.append([]) for input_question_idx in range(len(questions)): - top_match_catalogue_question_idx = idxs_top_input_questions_matches[ + top_match_catalogue_question_idx = idxs_of_top_questions_matched_in_catalogue[ input_question_idx ] for instrument_idx, question_idxs_in_this_instrument in enumerate( @@ -303,7 +304,7 @@ def match_questions_with_catalogue_instruments( ) # instrument_idx_to_num_matching_items_with_query[instrument_idx] = len( # catalogue_question_idxs_in_this_instrument_set.intersection( - # idxs_top_input_questions_matches_set + # idxs_of_top_questions_matched_in_catalogue_set # ) # ) instrument_idx_to_total_num_question_items_present[instrument_idx] = len( @@ -328,13 +329,13 @@ def match_questions_with_catalogue_instruments( ) question.closest_catalogue_question_match = CatalogueQuestion( - question=all_catalogue_questions[idxs_top_input_questions_matches[idx]], + question=all_catalogue_questions[idxs_of_top_questions_matched_in_catalogue[idx]], seen_in_instruments=seen_in_instruments, ) # Instrument index to list of cosine similarities top question match for input_question_idx, idx_top_input_question_match_in_catalogue in enumerate( - idxs_top_input_questions_matches + idxs_of_top_questions_matched_in_catalogue ): for ( catalogue_instrument_idx, @@ -367,60 +368,66 @@ def match_questions_with_catalogue_instruments( } # Calculate the average for each list of cosine similarities from instruments + instrument_idx_to_cosine_similarities_average: dict[int, float] = {} for ( instrument_idx, cosine_similarities, ) in instrument_idx_to_cosine_similarities_top_match.items(): - instrument_idx_to_cosine_similarities_top_match[instrument_idx] = ( + instrument_idx_to_cosine_similarities_average[instrument_idx] = ( statistics.mean(cosine_similarities) ) - # Find the index of the best instrument match - best_catalogue_instrument_idx = max( - instrument_idx_to_cosine_similarities_top_match, - key=instrument_idx_to_cosine_similarities_top_match.get, - ) - - # Get the best instrument match - best_catalogue_instrument = all_catalogue_instruments[best_catalogue_instrument_idx] - num_questions_in_ref_instrument = ( - instrument_idx_to_total_num_question_items_present[ - best_catalogue_instrument_idx + # Find the top 10 best instrument idx matches, index 0 containing the best match etc. + top_10_catalogue_instrument_idxs = sorted( + instrument_idx_to_cosine_similarities_average, + key=instrument_idx_to_cosine_similarities_average.get, + reverse=True + )[:10] + + # Create a list of CatalogueInstrument for each top instrument + top_instruments: List[CatalogueInstrument] = [] + for top_catalogue_instrument_idx in top_10_catalogue_instrument_idxs: + top_catalogue_instrument = all_catalogue_instruments[top_catalogue_instrument_idx] + num_questions_in_ref_instrument = ( + instrument_idx_to_total_num_question_items_present[ + top_catalogue_instrument_idx + ] + ) + num_top_match_questions = instrument_idx_to_top_matches_ct[ + top_catalogue_instrument_idx ] - ) - num_top_match_questions = instrument_idx_to_top_matches_ct[ - best_catalogue_instrument_idx - ] - instrument_name = best_catalogue_instrument["instrument_name"] - instrument_url = best_catalogue_instrument["metadata"].get("url", "") - source = best_catalogue_instrument["metadata"]["source"].upper() - sweep = best_catalogue_instrument["metadata"].get("sweep_id", "") + instrument_name = top_catalogue_instrument["instrument_name"] + instrument_url = top_catalogue_instrument["metadata"].get("url", "") + source = top_catalogue_instrument["metadata"]["source"].upper() + sweep = top_catalogue_instrument["metadata"].get("sweep_id", "") - if questions_are_from_one_instrument: - info = ( - f"{instrument_name} Sweep {sweep if sweep else 'UNKNOWN'} matched {num_top_match_questions} " - f"question(s) in your instrument, your instrument contains {num_input_questions} question(s). " - f"The reference instrument contains {num_questions_in_ref_instrument} question(s)." - ) - else: - info = ( - f"{instrument_name} Sweep {sweep if sweep else 'UNKNOWN'} matched {num_top_match_questions} " - f"question(s) in all of your instruments, your instruments contains {num_input_questions} " - f"question(s). The reference instrument contains {num_questions_in_ref_instrument} question(s)." - ) + if questions_are_from_one_instrument: + info = ( + f"{instrument_name} Sweep {sweep if sweep else 'UNKNOWN'} matched {num_top_match_questions} " + f"question(s) in your instrument, your instrument contains {num_input_questions} question(s). " + f"The reference instrument contains {num_questions_in_ref_instrument} question(s)." + ) + else: + info = ( + f"{instrument_name} Sweep {sweep if sweep else 'UNKNOWN'} matched {num_top_match_questions} " + f"question(s) in all of your instruments, your instruments contains {num_input_questions} " + f"question(s). The reference instrument contains {num_questions_in_ref_instrument} question(s)." + ) - return CatalogueInstrument( - instrument_name=instrument_name, - instrument_url=instrument_url, - source=source, - sweep=sweep, - metadata={ - "info": info, - "num_matched_questions": num_top_match_questions, - "num_ref_instrument_questions": num_questions_in_ref_instrument, - }, - ) + top_instruments.append(CatalogueInstrument( + instrument_name=instrument_name, + instrument_url=instrument_url, + source=source, + sweep=sweep, + metadata={ + "info": info, + "num_matched_questions": num_top_match_questions, + "num_ref_instrument_questions": num_questions_in_ref_instrument, + }, + )) + + return top_instruments # diff --git a/src/harmony/schemas/requests/text.py b/src/harmony/schemas/requests/text.py index 40f7501..4ac2e9a 100644 --- a/src/harmony/schemas/requests/text.py +++ b/src/harmony/schemas/requests/text.py @@ -100,8 +100,10 @@ class Instrument(BaseModel): language: Language = Field(Language.English, description="The ISO 639-2 (alpha-2) encoding of the instrument language") questions: List[Question] = Field(description="The items inside the instrument") - closest_catalogue_instrument_match: CatalogueInstrument | None = Field( - None, description="The closest instrument match in the catalogue for the instrument" + closest_catalogue_instrument_matches: List[CatalogueInstrument] = Field( + None, + description="The closest instrument matches in the catalogue for the instrument, the first index " + "contains the best match etc" ) class Config: diff --git a/src/harmony/schemas/responses/text.py b/src/harmony/schemas/responses/text.py index d058b69..47f26ab 100644 --- a/src/harmony/schemas/responses/text.py +++ b/src/harmony/schemas/responses/text.py @@ -45,9 +45,10 @@ class MatchResponse(BaseModel): class MatchCatalogueResponse(BaseModel): instruments: List[Instrument] = Field(description="A list of instruments") - closest_catalogue_instrument_match: CatalogueInstrument | None = Field( + closest_catalogue_instrument_matches: List[CatalogueInstrument] = Field( default=None, - description="The closest catalogue instrument match" + description="The closest catalogue instrument matches in the catalogue for all the instruments, " + "the first index contains the best match etc." ) From 706b5cb35706eeb13549e13392139b3f990cf451 Mon Sep 17 00:00:00 2001 From: Zairon Jacobs Date: Fri, 21 Jun 2024 14:14:40 +0200 Subject: [PATCH 5/7] create all instruments question text vectors at one place in match catalogue func and return new vectors --- src/harmony/matching/matcher.py | 50 ++++++++++++++++++--------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/harmony/matching/matcher.py b/src/harmony/matching/matcher.py index b05b887..eb5fcd2 100644 --- a/src/harmony/matching/matcher.py +++ b/src/harmony/matching/matcher.py @@ -147,7 +147,7 @@ def match_instruments_with_catalogue_instruments( catalogue_data: dict, vectorisation_function: Callable, texts_cached_vectors: dict[str, List[float]], -) -> tuple[List[Instrument], List[CatalogueInstrument]]: +) -> tuple[List[Instrument], List[CatalogueInstrument], dict[str, float]]: """ Match instruments with catalogue instruments. @@ -158,17 +158,30 @@ def match_instruments_with_catalogue_instruments( :return: Index 0 in the tuple contains the list of instruments that now each contain the best instrument matches from the catalog. Index 1 in the tuple contains a list of closest instrument matches from the catalog for all the - instruments. + instruments. Index 2 contains the new text vectors to be cached. """ + # Gather all questions + all_questions: List[str] = [] + for instrument in instruments: + all_questions.extend([q.question_text for q in instrument.questions]) + all_questions = list(set(all_questions)) + + # Create text vectors for all questions in all the uploaded instruments + all_instruments_text_vectors, new_text_vectors_dict = create_full_text_vectors( + all_questions=all_questions, + query=None, + vectorisation_function=vectorisation_function, + texts_cached_vectors=texts_cached_vectors, + ) + # For each instrument, find the best instrument matches for it in the catalogue for instrument in instruments: instrument.closest_catalogue_instrument_matches = ( match_questions_with_catalogue_instruments( questions=instrument.questions, catalogue_data=catalogue_data, - vectorisation_function=vectorisation_function, - texts_cached_vectors=texts_cached_vectors, + all_instruments_text_vectors=all_instruments_text_vectors, questions_are_from_one_instrument=True, ) ) @@ -180,19 +193,17 @@ def match_instruments_with_catalogue_instruments( closest_catalogue_instrument_matches = match_questions_with_catalogue_instruments( questions=all_instrument_questions, catalogue_data=catalogue_data, - vectorisation_function=vectorisation_function, - texts_cached_vectors=texts_cached_vectors, + all_instruments_text_vectors=all_instruments_text_vectors, questions_are_from_one_instrument=False, ) - return instruments, closest_catalogue_instrument_matches + return instruments, closest_catalogue_instrument_matches, new_text_vectors_dict def match_questions_with_catalogue_instruments( questions: List[Question], catalogue_data: dict, - vectorisation_function: Callable, - texts_cached_vectors: dict[str, List[float]], + all_instruments_text_vectors: List[TextVector], questions_are_from_one_instrument: bool, ) -> List[CatalogueInstrument]: """ @@ -202,8 +213,7 @@ def match_questions_with_catalogue_instruments( :param questions: The questions. :param catalogue_data: The catalogue data. - :param vectorisation_function: A function to vectorize a text. - :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector). + :param all_instruments_text_vectors: A list of text vectors of all questions found in all the instruments uploaded. :param questions_are_from_one_instrument: If the questions provided are coming from one instrument only. :return: A list of closest instrument matches for the questions provided. @@ -223,24 +233,18 @@ def match_questions_with_catalogue_instruments( if len(all_catalogue_questions_embeddings_concatenated) == 0: return [] - # Create text vectors - text_vectors, new_vectors_dict = create_full_text_vectors( - all_questions=[q.question_text for q in questions], - query=None, - vectorisation_function=vectorisation_function, - texts_cached_vectors=texts_cached_vectors, - ) + # All instruments text vectors to dict + all_instruments_text_vectors_dict = { + text_vector.text: text_vector.vector for text_vector in all_instruments_text_vectors + } # The total number of questions we received as input. num_input_questions = len(questions) # Get an array of dimensions. # (number of input questions) x (number of dimensions of LLM - typically 768, 384, 500, 512, etc.) - text_vectors_dict = { - text_vector.text: text_vector.vector for text_vector in text_vectors - } vectors = np.array( - [text_vectors_dict[question.question_text] for question in questions] + [all_instruments_text_vectors_dict[question.question_text] for question in questions] ) # Get a 2D array of (number of input questions) x (number of questions in catalogue). @@ -257,7 +261,7 @@ def match_questions_with_catalogue_instruments( idxs_of_top_questions_matched_in_catalogue = np.argmax(catalogue_similarities, axis=1) # Get a set of all the top matching question text indices in our catalogue. - idxs_of_top_questions_matched_in_catalogue_set = set(idxs_of_top_questions_matched_in_catalogue) + # idxs_of_top_questions_matched_in_catalogue_set = set(idxs_of_top_questions_matched_in_catalogue) # This keeps track of each instrument matches how many question items in the query # e.g. if the first instrument in our catalogue (instrument 0) matches 4 items, then this dictionary will From 7c9baad3d115657eda549e19e47573e5dc8ebf04 Mon Sep 17 00:00:00 2001 From: Zairon Jacobs Date: Mon, 24 Jun 2024 14:26:45 +0200 Subject: [PATCH 6/7] return catalogue match data at match match_instruments_with_functio func --- src/harmony/matching/default_matcher.py | 9 ++- src/harmony/matching/matcher.py | 101 +++++++++++++++--------- src/harmony/schemas/requests/text.py | 76 +----------------- src/harmony/schemas/responses/text.py | 7 +- 4 files changed, 73 insertions(+), 120 deletions(-) diff --git a/src/harmony/matching/default_matcher.py b/src/harmony/matching/default_matcher.py index 946e510..9670d90 100644 --- a/src/harmony/matching/default_matcher.py +++ b/src/harmony/matching/default_matcher.py @@ -1,4 +1,4 @@ -''' +""" MIT License Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk). @@ -22,8 +22,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -''' +""" import os from typing import List @@ -61,6 +60,8 @@ def match_instruments( mhc_all_metadatas: List = [], mhc_embeddings: np.ndarray = np.zeros((0, 0)), texts_cached_vectors: dict[str, List[float]] = {}, + include_catalogue_matches: bool = False, + catalogue_data: dict = {}, ) -> tuple: return match_instruments_with_function( instruments=instruments, @@ -70,4 +71,6 @@ def match_instruments( mhc_all_metadatas=mhc_all_metadatas, mhc_embeddings=mhc_embeddings, texts_cached_vectors=texts_cached_vectors, + include_catalogue_matches=include_catalogue_matches, + catalogue_data=catalogue_data, ) diff --git a/src/harmony/matching/matcher.py b/src/harmony/matching/matcher.py index eb5fcd2..dab0641 100644 --- a/src/harmony/matching/matcher.py +++ b/src/harmony/matching/matcher.py @@ -107,10 +107,10 @@ def vectors_pos_neg(text_vectors): def create_full_text_vectors( - all_questions: List[str], - query: str | None, - vectorisation_function: Callable, - texts_cached_vectors: dict[str, list[float]], + all_questions: List[str], + query: str | None, + vectorisation_function: Callable, + texts_cached_vectors: dict[str, list[float]], ) -> tuple[List[TextVector], dict]: """ Create full text vectors. @@ -143,11 +143,11 @@ def create_full_text_vectors( def match_instruments_with_catalogue_instruments( - instruments: List[Instrument], - catalogue_data: dict, - vectorisation_function: Callable, - texts_cached_vectors: dict[str, List[float]], -) -> tuple[List[Instrument], List[CatalogueInstrument], dict[str, float]]: + instruments: List[Instrument], + catalogue_data: dict, + vectorisation_function: Callable, + texts_cached_vectors: dict[str, List[float]], +) -> tuple[List[Instrument], List[CatalogueInstrument]]: """ Match instruments with catalogue instruments. @@ -156,9 +156,7 @@ def match_instruments_with_catalogue_instruments( :param vectorisation_function: A function to vectorize a text. :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector). - :return: Index 0 in the tuple contains the list of instruments that now each contain the best instrument matches - from the catalog. Index 1 in the tuple contains a list of closest instrument matches from the catalog for all the - instruments. Index 2 contains the new text vectors to be cached. + :return: Index 0 in the tuple contains the list of instruments that now each contain the best instrument matches from the catalog. Index 1 in the tuple contains a list of closest instrument matches from the catalog for all the instruments. """ # Gather all questions @@ -168,7 +166,7 @@ def match_instruments_with_catalogue_instruments( all_questions = list(set(all_questions)) # Create text vectors for all questions in all the uploaded instruments - all_instruments_text_vectors, new_text_vectors_dict = create_full_text_vectors( + all_instruments_text_vectors, _ = create_full_text_vectors( all_questions=all_questions, query=None, vectorisation_function=vectorisation_function, @@ -197,14 +195,14 @@ def match_instruments_with_catalogue_instruments( questions_are_from_one_instrument=False, ) - return instruments, closest_catalogue_instrument_matches, new_text_vectors_dict + return instruments, closest_catalogue_instrument_matches def match_questions_with_catalogue_instruments( - questions: List[Question], - catalogue_data: dict, - all_instruments_text_vectors: List[TextVector], - questions_are_from_one_instrument: bool, + questions: List[Question], + catalogue_data: dict, + all_instruments_text_vectors: List[TextVector], + questions_are_from_one_instrument: bool, ) -> List[CatalogueInstrument]: """ Match questions with catalogue instruments. @@ -286,13 +284,13 @@ def match_questions_with_catalogue_instruments( input_question_idx ] for instrument_idx, question_idxs_in_this_instrument in enumerate( - catalogue_instrument_idx_to_catalogue_questions_idx + catalogue_instrument_idx_to_catalogue_questions_idx ): if top_match_catalogue_question_idx in question_idxs_in_this_instrument: instrument_from_catalogue = all_catalogue_instruments[instrument_idx] if not any( - x["instrument_name"] == instrument_from_catalogue["instrument_name"] - for x in input_question_idx_to_matching_instruments[input_question_idx] + x["instrument_name"] == instrument_from_catalogue["instrument_name"] + for x in input_question_idx_to_matching_instruments[input_question_idx] ): input_question_idx_to_matching_instruments[ input_question_idx @@ -301,7 +299,7 @@ def match_questions_with_catalogue_instruments( # For each catalogue instrument get the total number of question matches in the query # For each catalogue instrument get the total number of questions for instrument_idx, question_idxs_in_this_instrument in enumerate( - catalogue_instrument_idx_to_catalogue_questions_idx + catalogue_instrument_idx_to_catalogue_questions_idx ): catalogue_question_idxs_in_this_instrument_set = set( question_idxs_in_this_instrument @@ -339,11 +337,11 @@ def match_questions_with_catalogue_instruments( # Instrument index to list of cosine similarities top question match for input_question_idx, idx_top_input_question_match_in_catalogue in enumerate( - idxs_of_top_questions_matched_in_catalogue + idxs_of_top_questions_matched_in_catalogue ): for ( - catalogue_instrument_idx, - catalogue_question_idxs_in_this_instrument, + catalogue_instrument_idx, + catalogue_question_idxs_in_this_instrument, ) in enumerate(catalogue_instrument_idx_to_catalogue_questions_idx): catalogue_question_idxs_set = set( catalogue_question_idxs_in_this_instrument @@ -351,7 +349,7 @@ def match_questions_with_catalogue_instruments( if idx_top_input_question_match_in_catalogue in catalogue_question_idxs_set: # Create the list if it doesn't exist yet if not instrument_idx_to_cosine_similarities_top_match.get( - catalogue_instrument_idx + catalogue_instrument_idx ): instrument_idx_to_cosine_similarities_top_match[ catalogue_instrument_idx @@ -374,8 +372,8 @@ def match_questions_with_catalogue_instruments( # Calculate the average for each list of cosine similarities from instruments instrument_idx_to_cosine_similarities_average: dict[int, float] = {} for ( - instrument_idx, - cosine_similarities, + instrument_idx, + cosine_similarities, ) in instrument_idx_to_cosine_similarities_top_match.items(): instrument_idx_to_cosine_similarities_average[instrument_idx] = ( statistics.mean(cosine_similarities) @@ -443,6 +441,8 @@ def match_instruments_with_function( mhc_all_metadatas: List = [], mhc_embeddings: np.ndarray = np.zeros((0, 0)), texts_cached_vectors: dict[str, List[float]] = {}, + include_catalogue_matches: bool = False, + catalogue_data: dict = {}, ) -> tuple: """ Match instruments. @@ -450,21 +450,27 @@ def match_instruments_with_function( :param instruments: The instruments :param query: The query :param vectorisation_function: A function to vectorize a text - :param mhc_questions - :param mhc_all_metadatas - :param mhc_embeddings + :param mhc_questions: MHC questions. + :param mhc_all_metadatas: MHC metadatas. + :param mhc_embeddings: MHC embeddings. :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector) + :param include_catalogue_matches: Include catalogue instrument matches in the result. + :param catalogue_data: The catalogue data. + + :return: Index 0 contains a list of all the questions from the instruments (if include_catalogue_matches is True, each question now contains the best matched question from the catalogue). Index 1 contains similarity with polarity. Index 2 contains the query similarity. Index 3 contains a dict with the new text vectors to be cached. Index 4 contains a list of all instruments (if include_catalogue_matches is True, each instrument now contains a list of closest instrument matches). Index 5 contains a list of closest catalogue instrument matches for all the instruments (only if include_catalogue_matches is True). """ - all_questions = [] + all_questions: List[Question] = [] for instrument in instruments: all_questions.extend(instrument.questions) - # all_questions: List[Question] = all_questions - all_questions_str: List[str] = [q.question_text for q in all_questions] - # all_questions = [instrument["question_text"] for instrument in instruments] - text_vectors, new_vectors_dict = create_full_text_vectors(all_questions_str, query, vectorisation_function, - texts_cached_vectors) + text_vectors, new_vectors_dict = create_full_text_vectors( + all_questions=[q.question_text for q in all_questions], + query=query, + vectorisation_function=vectorisation_function, + texts_cached_vectors=texts_cached_vectors + ) + # get vectors for all original texts and vectors for negated texts vectors_pos, vectors_neg = vectors_pos_neg(text_vectors) @@ -478,7 +484,7 @@ def match_instruments_with_function( query_similarity = np.array([]) # Get similarity with polarity - if vectors_pos.any(): # NOTE: Should an error be thrown if vectors_pos is empty? + if vectors_pos.any(): # NOTE: Should an error be thrown if vectors_pos is empty? pairwise_similarity = cosine_similarity(vectors_pos, vectors_pos) # NOTE: Similarity of (vectors_neg, vectors_pos) & (vectors_pos, vectors_neg) should be the same pairwise_similarity_neg1 = cosine_similarity(vectors_neg, vectors_pos) @@ -536,4 +542,21 @@ def match_instruments_with_function( for question in all_questions: question.topics_auto = [] - return all_questions, similarity_with_polarity, query_similarity, new_vectors_dict + # Get catalogue matches + closest_catalogue_instrument_matches = [] + if include_catalogue_matches: + instruments, closest_catalogue_instrument_matches = match_instruments_with_catalogue_instruments( + instruments=instruments, + catalogue_data=catalogue_data, + vectorisation_function=vectorisation_function, + texts_cached_vectors=texts_cached_vectors, + ) + + return ( + all_questions, + similarity_with_polarity, + query_similarity, + new_vectors_dict, + instruments, + closest_catalogue_instrument_matches + ) diff --git a/src/harmony/schemas/requests/text.py b/src/harmony/schemas/requests/text.py index 4ac2e9a..b2bd0e3 100644 --- a/src/harmony/schemas/requests/text.py +++ b/src/harmony/schemas/requests/text.py @@ -1,4 +1,4 @@ -''' +""" MIT License Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk). @@ -22,8 +22,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -''' +""" from typing import List @@ -101,7 +100,7 @@ class Instrument(BaseModel): description="The ISO 639-2 (alpha-2) encoding of the instrument language") questions: List[Question] = Field(description="The items inside the instrument") closest_catalogue_instrument_matches: List[CatalogueInstrument] = Field( - None, + [], description="The closest instrument matches in the catalogue for the instrument, the first index " "contains the best match etc" ) @@ -211,72 +210,3 @@ class Config: } } - -class MatchCatalogueBody(BaseModel): - instruments: List[Instrument] = Field(description="Instruments to match.") - parameters: MatchParameters = Field(DEFAULT_MATCH_PARAMETERS, description="Parameters on how to match.") - sources: List[str] = Field( - default=[], - description="The instrument sources to use for matching. If empty, all instrument sources will be considered." - ) - - class Config: - schema_extra = { - "example": { - "instruments": [{ - "file_id": "fd60a9a64b1b4078a68f4bc06f20253c", - "instrument_id": "7829ba96f48e4848abd97884911b6795", - "instrument_name": "GAD-7 English", - "file_name": "GAD-7 EN.pdf", - "file_type": "pdf", - "file_section": "GAD-7 English", - "language": "en", - "questions": [{"question_no": "1", - "question_intro": "Over the last two weeks, how often have you been bothered by the following problems?", - "question_text": "Feeling nervous, anxious, or on edge", - "options": ["Not at all", "Several days", "More than half the days", - "Nearly every day"], - "source_page": 0 - }, - {"question_no": "2", - "question_intro": "Over the last two weeks, how often have you been bothered by the following problems?", - "question_text": "Not being able to stop or control worrying", - "options": ["Not at all", "Several days", "More than half the days", - "Nearly every day"], - "source_page": 0 - } - - ] - }, - { - "file_id": "fd60a9a64b1b4078a68f4bc06f20253c", - "instrument_id": "7829ba96f48e4848abd97884911b6795", - "instrument_name": "GAD-7 Portuguese", - "file_name": "GAD-7 PT.pdf", - "file_type": "pdf", - "file_section": "GAD-7 Portuguese", - "language": "en", - "questions": [{"question_no": "1", - "question_intro": "Durante as últimas 2 semanas, com que freqüência você foi incomodado/a pelos problemas abaixo?", - "question_text": "Sentir-se nervoso/a, ansioso/a ou muito tenso/a", - "options": ["Nenhuma vez", "Vários dias", "Mais da metade dos dias", - "Quase todos os dias"], - "source_page": 0 - }, - {"question_no": "2", - "question_intro": "Durante as últimas 2 semanas, com que freqüência você foi incomodado/a pelos problemas abaixo?", - "question_text": " Não ser capaz de impedir ou de controlar as preocupações", - "options": ["Nenhuma vez", "Vários dias", "Mais da metade dos dias", - "Quase todos os dias"], - "source_page": 0 - } - - ] - } - - ], - "parameters": {"framework": DEFAULT_FRAMEWORK, - "model": DEFAULT_MODEL}, - "sources": ["MHC"] - } - } diff --git a/src/harmony/schemas/responses/text.py b/src/harmony/schemas/responses/text.py index 47f26ab..496f364 100644 --- a/src/harmony/schemas/responses/text.py +++ b/src/harmony/schemas/responses/text.py @@ -34,6 +34,7 @@ class MatchResponse(BaseModel): + instruments: List[Instrument] = Field(description="A list of instruments") questions: List[Question] = Field( description="The questions which were matched, in an order matching the order of the matrix" ) @@ -41,12 +42,8 @@ class MatchResponse(BaseModel): query_similarity: List = Field( None, description="Similarity metric between query string and items" ) - - -class MatchCatalogueResponse(BaseModel): - instruments: List[Instrument] = Field(description="A list of instruments") closest_catalogue_instrument_matches: List[CatalogueInstrument] = Field( - default=None, + default=[], description="The closest catalogue instrument matches in the catalogue for all the instruments, " "the first index contains the best match etc." ) From bcaa94d5ebd28f5664ec386bbd667e73ee2e4792 Mon Sep 17 00:00:00 2001 From: Zairon Jacobs Date: Tue, 25 Jun 2024 10:44:02 +0200 Subject: [PATCH 7/7] remove extra data in tuple res from func match_instruments_with_function to prevent unpack err --- src/harmony/matching/default_matcher.py | 4 ---- src/harmony/matching/matcher.py | 22 ++-------------------- 2 files changed, 2 insertions(+), 24 deletions(-) diff --git a/src/harmony/matching/default_matcher.py b/src/harmony/matching/default_matcher.py index 9670d90..1f8ada7 100644 --- a/src/harmony/matching/default_matcher.py +++ b/src/harmony/matching/default_matcher.py @@ -60,8 +60,6 @@ def match_instruments( mhc_all_metadatas: List = [], mhc_embeddings: np.ndarray = np.zeros((0, 0)), texts_cached_vectors: dict[str, List[float]] = {}, - include_catalogue_matches: bool = False, - catalogue_data: dict = {}, ) -> tuple: return match_instruments_with_function( instruments=instruments, @@ -71,6 +69,4 @@ def match_instruments( mhc_all_metadatas=mhc_all_metadatas, mhc_embeddings=mhc_embeddings, texts_cached_vectors=texts_cached_vectors, - include_catalogue_matches=include_catalogue_matches, - catalogue_data=catalogue_data, ) diff --git a/src/harmony/matching/matcher.py b/src/harmony/matching/matcher.py index dab0641..b82f785 100644 --- a/src/harmony/matching/matcher.py +++ b/src/harmony/matching/matcher.py @@ -441,8 +441,6 @@ def match_instruments_with_function( mhc_all_metadatas: List = [], mhc_embeddings: np.ndarray = np.zeros((0, 0)), texts_cached_vectors: dict[str, List[float]] = {}, - include_catalogue_matches: bool = False, - catalogue_data: dict = {}, ) -> tuple: """ Match instruments. @@ -453,11 +451,7 @@ def match_instruments_with_function( :param mhc_questions: MHC questions. :param mhc_all_metadatas: MHC metadatas. :param mhc_embeddings: MHC embeddings. - :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector) - :param include_catalogue_matches: Include catalogue instrument matches in the result. - :param catalogue_data: The catalogue data. - - :return: Index 0 contains a list of all the questions from the instruments (if include_catalogue_matches is True, each question now contains the best matched question from the catalogue). Index 1 contains similarity with polarity. Index 2 contains the query similarity. Index 3 contains a dict with the new text vectors to be cached. Index 4 contains a list of all instruments (if include_catalogue_matches is True, each instrument now contains a list of closest instrument matches). Index 5 contains a list of closest catalogue instrument matches for all the instruments (only if include_catalogue_matches is True). + :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector). """ all_questions: List[Question] = [] @@ -542,21 +536,9 @@ def match_instruments_with_function( for question in all_questions: question.topics_auto = [] - # Get catalogue matches - closest_catalogue_instrument_matches = [] - if include_catalogue_matches: - instruments, closest_catalogue_instrument_matches = match_instruments_with_catalogue_instruments( - instruments=instruments, - catalogue_data=catalogue_data, - vectorisation_function=vectorisation_function, - texts_cached_vectors=texts_cached_vectors, - ) - return ( all_questions, similarity_with_polarity, query_similarity, - new_vectors_dict, - instruments, - closest_catalogue_instrument_matches + new_vectors_dict )