Skip to content

Commit

Permalink
return catalogue match data from the match_instruments_with_function func
Browse files Browse the repository at this point in the history
  • Loading branch information
zaironjacobs committed Jun 24, 2024
1 parent 706b5cb commit 7c9baad
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 120 deletions.
9 changes: 6 additions & 3 deletions src/harmony/matching/default_matcher.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
'''
"""
MIT License
Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk).
Expand All @@ -22,8 +22,7 @@
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
"""

import os
from typing import List
Expand Down Expand Up @@ -61,6 +60,8 @@ def match_instruments(
mhc_all_metadatas: List = [],
mhc_embeddings: np.ndarray = np.zeros((0, 0)),
texts_cached_vectors: dict[str, List[float]] = {},
include_catalogue_matches: bool = False,
catalogue_data: dict = {},
) -> tuple:
return match_instruments_with_function(
instruments=instruments,
Expand All @@ -70,4 +71,6 @@ def match_instruments(
mhc_all_metadatas=mhc_all_metadatas,
mhc_embeddings=mhc_embeddings,
texts_cached_vectors=texts_cached_vectors,
include_catalogue_matches=include_catalogue_matches,
catalogue_data=catalogue_data,
)
101 changes: 62 additions & 39 deletions src/harmony/matching/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ def vectors_pos_neg(text_vectors):


def create_full_text_vectors(
all_questions: List[str],
query: str | None,
vectorisation_function: Callable,
texts_cached_vectors: dict[str, list[float]],
all_questions: List[str],
query: str | None,
vectorisation_function: Callable,
texts_cached_vectors: dict[str, list[float]],
) -> tuple[List[TextVector], dict]:
"""
Create full text vectors.
Expand Down Expand Up @@ -143,11 +143,11 @@ def create_full_text_vectors(


def match_instruments_with_catalogue_instruments(
instruments: List[Instrument],
catalogue_data: dict,
vectorisation_function: Callable,
texts_cached_vectors: dict[str, List[float]],
) -> tuple[List[Instrument], List[CatalogueInstrument], dict[str, float]]:
instruments: List[Instrument],
catalogue_data: dict,
vectorisation_function: Callable,
texts_cached_vectors: dict[str, List[float]],
) -> tuple[List[Instrument], List[CatalogueInstrument]]:
"""
Match instruments with catalogue instruments.
Expand All @@ -156,9 +156,7 @@ def match_instruments_with_catalogue_instruments(
:param vectorisation_function: A function to vectorize a text.
:param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector).
:return: Index 0 in the tuple contains the list of instruments that now each contain the best instrument matches
from the catalog. Index 1 in the tuple contains a list of closest instrument matches from the catalog for all the
instruments. Index 2 contains the new text vectors to be cached.
:return: Index 0 in the tuple contains the list of instruments that now each contain the best instrument matches from the catalog. Index 1 in the tuple contains a list of closest instrument matches from the catalog for all the instruments.
"""

# Gather all questions
Expand All @@ -168,7 +166,7 @@ def match_instruments_with_catalogue_instruments(
all_questions = list(set(all_questions))

# Create text vectors for all questions in all the uploaded instruments
all_instruments_text_vectors, new_text_vectors_dict = create_full_text_vectors(
all_instruments_text_vectors, _ = create_full_text_vectors(
all_questions=all_questions,
query=None,
vectorisation_function=vectorisation_function,
Expand Down Expand Up @@ -197,14 +195,14 @@ def match_instruments_with_catalogue_instruments(
questions_are_from_one_instrument=False,
)

return instruments, closest_catalogue_instrument_matches, new_text_vectors_dict
return instruments, closest_catalogue_instrument_matches


def match_questions_with_catalogue_instruments(
questions: List[Question],
catalogue_data: dict,
all_instruments_text_vectors: List[TextVector],
questions_are_from_one_instrument: bool,
questions: List[Question],
catalogue_data: dict,
all_instruments_text_vectors: List[TextVector],
questions_are_from_one_instrument: bool,
) -> List[CatalogueInstrument]:
"""
Match questions with catalogue instruments.
Expand Down Expand Up @@ -286,13 +284,13 @@ def match_questions_with_catalogue_instruments(
input_question_idx
]
for instrument_idx, question_idxs_in_this_instrument in enumerate(
catalogue_instrument_idx_to_catalogue_questions_idx
catalogue_instrument_idx_to_catalogue_questions_idx
):
if top_match_catalogue_question_idx in question_idxs_in_this_instrument:
instrument_from_catalogue = all_catalogue_instruments[instrument_idx]
if not any(
x["instrument_name"] == instrument_from_catalogue["instrument_name"]
for x in input_question_idx_to_matching_instruments[input_question_idx]
x["instrument_name"] == instrument_from_catalogue["instrument_name"]
for x in input_question_idx_to_matching_instruments[input_question_idx]
):
input_question_idx_to_matching_instruments[
input_question_idx
Expand All @@ -301,7 +299,7 @@ def match_questions_with_catalogue_instruments(
# For each catalogue instrument get the total number of question matches in the query
# For each catalogue instrument get the total number of questions
for instrument_idx, question_idxs_in_this_instrument in enumerate(
catalogue_instrument_idx_to_catalogue_questions_idx
catalogue_instrument_idx_to_catalogue_questions_idx
):
catalogue_question_idxs_in_this_instrument_set = set(
question_idxs_in_this_instrument
Expand Down Expand Up @@ -339,19 +337,19 @@ def match_questions_with_catalogue_instruments(

# Instrument index to list of cosine similarities top question match
for input_question_idx, idx_top_input_question_match_in_catalogue in enumerate(
idxs_of_top_questions_matched_in_catalogue
idxs_of_top_questions_matched_in_catalogue
):
for (
catalogue_instrument_idx,
catalogue_question_idxs_in_this_instrument,
catalogue_instrument_idx,
catalogue_question_idxs_in_this_instrument,
) in enumerate(catalogue_instrument_idx_to_catalogue_questions_idx):
catalogue_question_idxs_set = set(
catalogue_question_idxs_in_this_instrument
)
if idx_top_input_question_match_in_catalogue in catalogue_question_idxs_set:
# Create the list if it doesn't exist yet
if not instrument_idx_to_cosine_similarities_top_match.get(
catalogue_instrument_idx
catalogue_instrument_idx
):
instrument_idx_to_cosine_similarities_top_match[
catalogue_instrument_idx
Expand All @@ -374,8 +372,8 @@ def match_questions_with_catalogue_instruments(
# Calculate the average for each list of cosine similarities from instruments
instrument_idx_to_cosine_similarities_average: dict[int, float] = {}
for (
instrument_idx,
cosine_similarities,
instrument_idx,
cosine_similarities,
) in instrument_idx_to_cosine_similarities_top_match.items():
instrument_idx_to_cosine_similarities_average[instrument_idx] = (
statistics.mean(cosine_similarities)
Expand Down Expand Up @@ -443,28 +441,36 @@ def match_instruments_with_function(
mhc_all_metadatas: List = [],
mhc_embeddings: np.ndarray = np.zeros((0, 0)),
texts_cached_vectors: dict[str, List[float]] = {},
include_catalogue_matches: bool = False,
catalogue_data: dict = {},
) -> tuple:
"""
Match instruments.
:param instruments: The instruments
:param query: The query
:param vectorisation_function: A function to vectorize a text
:param mhc_questions
:param mhc_all_metadatas
:param mhc_embeddings
:param mhc_questions: MHC questions.
:param mhc_all_metadatas: MHC metadatas.
:param mhc_embeddings: MHC embeddings.
:param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector)
:param include_catalogue_matches: Include catalogue instrument matches in the result.
:param catalogue_data: The catalogue data.
:return: Index 0 contains a list of all the questions from the instruments (if include_catalogue_matches is True, each question now contains the best matched question from the catalogue). Index 1 contains similarity with polarity. Index 2 contains the query similarity. Index 3 contains a dict with the new text vectors to be cached. Index 4 contains a list of all instruments (if include_catalogue_matches is True, each instrument now contains a list of closest instrument matches). Index 5 contains a list of closest catalogue instrument matches for all the instruments (only if include_catalogue_matches is True).
"""

all_questions = []
all_questions: List[Question] = []
for instrument in instruments:
all_questions.extend(instrument.questions)
# all_questions: List[Question] = all_questions
all_questions_str: List[str] = [q.question_text for q in all_questions]
# all_questions = [instrument["question_text"] for instrument in instruments]

text_vectors, new_vectors_dict = create_full_text_vectors(all_questions_str, query, vectorisation_function,
texts_cached_vectors)
text_vectors, new_vectors_dict = create_full_text_vectors(
all_questions=[q.question_text for q in all_questions],
query=query,
vectorisation_function=vectorisation_function,
texts_cached_vectors=texts_cached_vectors
)

# get vectors for all original texts and vectors for negated texts
vectors_pos, vectors_neg = vectors_pos_neg(text_vectors)

Expand All @@ -478,7 +484,7 @@ def match_instruments_with_function(
query_similarity = np.array([])

# Get similarity with polarity
if vectors_pos.any(): # NOTE: Should an error be thrown if vectors_pos is empty?
if vectors_pos.any(): # NOTE: Should an error be thrown if vectors_pos is empty?
pairwise_similarity = cosine_similarity(vectors_pos, vectors_pos)
# NOTE: Similarity of (vectors_neg, vectors_pos) & (vectors_pos, vectors_neg) should be the same
pairwise_similarity_neg1 = cosine_similarity(vectors_neg, vectors_pos)
Expand Down Expand Up @@ -536,4 +542,21 @@ def match_instruments_with_function(
for question in all_questions:
question.topics_auto = []

return all_questions, similarity_with_polarity, query_similarity, new_vectors_dict
# Get catalogue matches
closest_catalogue_instrument_matches = []
if include_catalogue_matches:
instruments, closest_catalogue_instrument_matches = match_instruments_with_catalogue_instruments(
instruments=instruments,
catalogue_data=catalogue_data,
vectorisation_function=vectorisation_function,
texts_cached_vectors=texts_cached_vectors,
)

return (
all_questions,
similarity_with_polarity,
query_similarity,
new_vectors_dict,
instruments,
closest_catalogue_instrument_matches
)
76 changes: 3 additions & 73 deletions src/harmony/schemas/requests/text.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
'''
"""
MIT License
Copyright (c) 2023 Ulster University (https://www.ulster.ac.uk).
Expand All @@ -22,8 +22,7 @@
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
"""

from typing import List

Expand Down Expand Up @@ -101,7 +100,7 @@ class Instrument(BaseModel):
description="The ISO 639-2 (alpha-2) encoding of the instrument language")
questions: List[Question] = Field(description="The items inside the instrument")
closest_catalogue_instrument_matches: List[CatalogueInstrument] = Field(
None,
[],
description="The closest instrument matches in the catalogue for the instrument, the first index "
"contains the best match etc"
)
Expand Down Expand Up @@ -211,72 +210,3 @@ class Config:
}
}


class MatchCatalogueBody(BaseModel):
    # Request-body model for catalogue matching (judging by its name and the
    # field descriptions below): the instruments to match, the matching
    # parameters, and an optional list of instrument sources to restrict to.
    #
    # Documented with comments rather than a class docstring on purpose:
    # pydantic copies a model's docstring into the generated JSON schema as
    # its description, which would change the published API schema.
    instruments: List[Instrument] = Field(description="Instruments to match.")
    parameters: MatchParameters = Field(DEFAULT_MATCH_PARAMETERS, description="Parameters on how to match.")
    sources: List[str] = Field(
        default=[],
        description="The instrument sources to use for matching. If empty, all instrument sources will be considered."
    )

    class Config:
        # pydantic (v1-style) hook for extra JSON-schema data; here it supplies
        # a sample payload with two GAD-7 instruments (English and Portuguese).
        # NOTE(review): both example instruments share the same file_id and
        # instrument_id - confirm that this is intended.
        schema_extra = {
            "example": {
                "instruments": [
                    {
                        "file_id": "fd60a9a64b1b4078a68f4bc06f20253c",
                        "instrument_id": "7829ba96f48e4848abd97884911b6795",
                        "instrument_name": "GAD-7 English",
                        "file_name": "GAD-7 EN.pdf",
                        "file_type": "pdf",
                        "file_section": "GAD-7 English",
                        "language": "en",
                        "questions": [
                            {
                                "question_no": "1",
                                "question_intro": "Over the last two weeks, how often have you been bothered by the following problems?",
                                "question_text": "Feeling nervous, anxious, or on edge",
                                "options": ["Not at all", "Several days", "More than half the days",
                                            "Nearly every day"],
                                "source_page": 0
                            },
                            {
                                "question_no": "2",
                                "question_intro": "Over the last two weeks, how often have you been bothered by the following problems?",
                                "question_text": "Not being able to stop or control worrying",
                                "options": ["Not at all", "Several days", "More than half the days",
                                            "Nearly every day"],
                                "source_page": 0
                            }

                        ]
                    },
                    {
                        "file_id": "fd60a9a64b1b4078a68f4bc06f20253c",
                        "instrument_id": "7829ba96f48e4848abd97884911b6795",
                        "instrument_name": "GAD-7 Portuguese",
                        "file_name": "GAD-7 PT.pdf",
                        "file_type": "pdf",
                        "file_section": "GAD-7 Portuguese",
                        # Fixed: was "en" (copied from the English instrument);
                        # the questions below are Portuguese.
                        "language": "pt",
                        "questions": [
                            {
                                "question_no": "1",
                                "question_intro": "Durante as últimas 2 semanas, com que freqüência você foi incomodado/a pelos problemas abaixo?",
                                "question_text": "Sentir-se nervoso/a, ansioso/a ou muito tenso/a",
                                "options": ["Nenhuma vez", "Vários dias", "Mais da metade dos dias",
                                            "Quase todos os dias"],
                                "source_page": 0
                            },
                            {
                                "question_no": "2",
                                "question_intro": "Durante as últimas 2 semanas, com que freqüência você foi incomodado/a pelos problemas abaixo?",
                                # Fixed: removed a stray leading space from the text.
                                "question_text": "Não ser capaz de impedir ou de controlar as preocupações",
                                "options": ["Nenhuma vez", "Vários dias", "Mais da metade dos dias",
                                            "Quase todos os dias"],
                                "source_page": 0
                            }

                        ]
                    }

                ],
                "parameters": {"framework": DEFAULT_FRAMEWORK,
                               "model": DEFAULT_MODEL},
                "sources": ["MHC"]
            }
        }
7 changes: 2 additions & 5 deletions src/harmony/schemas/responses/text.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,16 @@


class MatchResponse(BaseModel):
instruments: List[Instrument] = Field(description="A list of instruments")
questions: List[Question] = Field(
description="The questions which were matched, in an order matching the order of the matrix"
)
matches: List[List] = Field(description="Matrix of cosine similarity matches")
query_similarity: List = Field(
None, description="Similarity metric between query string and items"
)


class MatchCatalogueResponse(BaseModel):
instruments: List[Instrument] = Field(description="A list of instruments")
closest_catalogue_instrument_matches: List[CatalogueInstrument] = Field(
default=None,
default=[],
description="The closest catalogue instrument matches in the catalogue for all the instruments, "
"the first index contains the best match etc."
)
Expand Down

0 comments on commit 7c9baad

Please sign in to comment.