Skip to content

Commit

Permalink
Merge pull request #9 from EveWCheng/eve_pr
Browse files Browse the repository at this point in the history
broke down matcher function into smaller functions  (Eve Cheng PR)
  • Loading branch information
woodthom2 authored Dec 22, 2023
2 parents c2ecb9a + b7dd2f0 commit 6dfb2c6
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 94 deletions.
5 changes: 3 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ ipython_config.py
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

Pipfile.lock
Pipfile
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

Expand Down Expand Up @@ -129,3 +129,4 @@ dmypy.json
.pyre/
.idea/

src/log.txt
192 changes: 100 additions & 92 deletions src/harmony/matching/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
SOFTWARE.
'''

from collections import Counter
from typing import List, Callable

Expand All @@ -45,105 +44,44 @@ def cosine_similarity(vec1: ndarray, vec2: ndarray) -> ndarray:
return np.asarray(dp / matmul(m1.T, m2))


def match_instruments_with_function(
instruments: List[Instrument],
query: str,
vectorisation_function: Callable,
mhc_questions: List = [],
mhc_all_metadatas: List = [],
mhc_embeddings: np.ndarray = np.zeros((0, 0)),
texts_cached_vectors: dict[str, List[float]] = {},
) -> tuple:
"""
Match instruments
:param instruments: The instruments
:param query: The query
:param vectorisation_function: A function to vectorize a text
:param mhc_questions
:param mhc_all_metadatas
:param mhc_embeddings
:param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector)
"""

# Create a list of text vectors
all_questions = []
text_vectors: List[TextVector] = []
for instrument in instruments:
for question in instrument.questions:
if question.question_text is None or question.question_text.strip() == "":
continue # skip empty questions

question.instrument_id = instrument.instrument_id
all_questions.append(question)

# Text
question_text = question.question_text
if question_text not in texts_cached_vectors.keys():
text_vectors.append(
TextVector(
text=question_text, vector=[], is_negated=False, is_query=False
)
)
else:
vector = texts_cached_vectors[question_text]
text_vectors.append(
TextVector(
text=question_text,
vector=vector,
is_negated=False,
is_query=False,
)
)

# Negated text
negated_text = negate(question_text, instrument.language)
if negated_text not in texts_cached_vectors.keys():
text_vectors.append(
TextVector(
text=negated_text, vector=[], is_negated=True, is_query=False
)
)
else:
vector = texts_cached_vectors[negated_text]
text_vectors.append(
TextVector(
text=negated_text,
vector=vector,
is_negated=True,
is_query=False,
)
)

# Add query
if query:
if query not in texts_cached_vectors.keys():
text_vectors.append(
TextVector(text=query, vector=[], is_negated=False, is_query=True)
def add_text_to_vec(text, texts_cached_vectors, text_vectors, is_negated_, is_query_):
    """
    Append a TextVector for ``text`` to ``text_vectors`` and return the list.

    :param text: The text to wrap in a TextVector
    :param texts_cached_vectors: A dictionary of already cached vectors (key is the text and value is the vector)
    :param text_vectors: The list to append to; it is mutated in place
    :param is_negated_: Value stored in the TextVector's ``is_negated`` field
    :param is_query_: Value stored in the TextVector's ``is_query`` field
    :return: The same ``text_vectors`` list, with one more element
    """
    # Use the cached vector when there is one; otherwise store an empty list,
    # which is falsy and can therefore be detected later with ``not x.vector``.
    vector = texts_cached_vectors.get(text, [])
    text_vectors.append(
        TextVector(
            text=text,
            vector=vector,
            is_negated=is_negated_,
            is_query=is_query_,
        )
    )
    return text_vectors

# Texts with no cached vector
texts_not_cached = [x.text for x in text_vectors if not x.vector]

# Get vectors for all texts not cached
new_vectors_list: List = vectorisation_function(texts_not_cached).tolist()
def process_questions(questions, texts_cached_vectors=None, language="en"):
    """
    Build the list of TextVectors for a list of question texts.

    Each question contributes two consecutive entries: the text itself
    (``is_negated=False``) followed by its negation (``is_negated=True``).
    Neither entry is marked as a query.

    :param questions: The question texts
    :param texts_cached_vectors: Optional dictionary of already cached vectors
        (key is the text and value is the vector). When omitted, an empty
        cache is used, so no vectors are pre-filled.
    :param language: Language passed to ``negate``; defaults to ``"en"``
    :return: The list of TextVectors, two per question, in input order
    """
    cache = {} if texts_cached_vectors is None else texts_cached_vectors
    text_vectors: List[TextVector] = []
    for question_text in questions:
        # Original wording first, then its negated counterpart.
        text_vectors = add_text_to_vec(question_text, cache, text_vectors, False, False)
        negated_text = negate(question_text, language)
        text_vectors = add_text_to_vec(negated_text, cache, text_vectors, True, False)
    return text_vectors

# Create a dictionary with new vectors
new_vectors_dict = {}
for vector, text in zip(new_vectors_list, texts_not_cached):
new_vectors_dict[text] = vector

# Add new vectors to all_texts
def vectorise_texts(text_vectors, vectorisation_function):
    """
    Fill in the vector of every entry of ``text_vectors`` that has none yet.

    :param text_vectors: Objects with ``text`` and ``vector`` attributes; entries
        whose ``vector`` is falsy (e.g. an empty list) are updated in place
    :param vectorisation_function: A function that takes a list of texts and
        returns an object with a ``tolist()`` method yielding one vector per text
    :return: The same ``text_vectors`` list
    """
    missing = [entry for entry in text_vectors if not entry.vector]
    if not missing:
        # Nothing to compute: do not call the vectoriser with an empty batch.
        return text_vectors

    # One batched call instead of one call per text.
    new_vectors = vectorisation_function([entry.text for entry in missing]).tolist()
    for entry, vector in zip(missing, new_vectors):
        entry.vector = vector
    return text_vectors

# Create numpy array of texts vectors

def vectors_pos_neg(text_vectors):
vectors_pos = np.array(
[
x.vector
Expand All @@ -160,8 +98,78 @@ def match_instruments_with_function(
if (x.is_negated is True and x.is_query is False)
]
)
return vectors_pos, vectors_neg


def create_full_text_vectors(all_questions, query, vectorisation_function, texts_cached_vectors):
    """
    Build the TextVectors for all questions (plus the query) and vectorise them.

    :param all_questions: The question texts
    :param query: The query text; skipped when empty or None
    :param vectorisation_function: A function that takes a list of texts and
        returns an object with a ``tolist()`` method yielding one vector per text
    :param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector)
    :return: A tuple ``(text_vectors, new_vectors_dict)`` where ``new_vectors_dict``
        maps each newly vectorised text to its vector
    """
    text_vectors: List[TextVector] = process_questions(all_questions)

    # process_questions is called without the cache here, so apply the
    # caller's cached vectors to any entry that is still empty.
    for text_vector in text_vectors:
        if not text_vector.vector and text_vector.text in texts_cached_vectors:
            text_vector.vector = texts_cached_vectors[text_vector.text]

    # Add query
    if query:
        text_vectors = add_text_to_vec(query, texts_cached_vectors, text_vectors, False, True)

    # Entries with no cached vector, and their texts in the same order
    pending = [text_vector for text_vector in text_vectors if not text_vector.vector]
    texts_not_cached = [text_vector.text for text_vector in pending]

    # Get vectors for all texts not cached (skip the call for an empty batch)
    new_vectors_list: List = (
        vectorisation_function(texts_not_cached).tolist() if texts_not_cached else []
    )

    # Create a dictionary with new vectors
    new_vectors_dict = dict(zip(texts_not_cached, new_vectors_list))

    # Assign the new vectors positionally; avoids repeated list.pop(0)
    for text_vector, vector in zip(pending, new_vectors_list):
        text_vector.vector = vector
    return text_vectors, new_vectors_dict


def process_instruments(instruments):
    """
    Keep the instruments that have at least one non-empty question.

    A question counts as non-empty when its ``question_text`` is not None and
    is not blank after stripping whitespace. Each qualifying instrument is
    returned exactly once, in input order, so a caller that flattens
    ``instrument.questions`` does not see the same questions repeated.

    :param instruments: The instruments
    :return: The list of instruments that have a non-empty question
    """
    return [
        instrument
        for instrument in instruments
        if any(
            # ``and`` (not ``or``) so that ``.strip()`` is never called on None
            q.question_text is not None and q.question_text.strip() != ""
            for q in instrument.questions
        )
    ]


# in_ = []
# for instrument in instruments:
# for question in instrument.questions:
# if question.question_text is not None and question.question_text.strip() != "":
# in_.append(instrument)
# return in_

#
def match_instruments_with_function(
instruments: List[Instrument],
query: str,
vectorisation_function: Callable,
mhc_questions: List = [],
mhc_all_metadatas: List = [],
mhc_embeddings: np.ndarray = np.zeros((0, 0)),
texts_cached_vectors: dict[str, List[float]] = {},
) -> tuple:
"""
Match instruments
:param instruments: The instruments
:param query: The query
:param vectorisation_function: A function to vectorize a text
:param mhc_questions
:param mhc_all_metadatas
:param mhc_embeddings
:param texts_cached_vectors: A dictionary of already cached vectors from texts (key is the text and value is the vector)
"""
instruments = process_instruments(instruments)
all_questions: List[TextVector] = [q.question_text for instrument in instruments for q in instrument.questions]
# all_questions = [instrument["question_text"] for instrument in instruments]

text_vectors, new_vectors_dict = create_full_text_vectors(all_questions, query, vectorisation_function,
texts_cached_vectors)
vectors_pos, vectors_neg = vectors_pos_neg(text_vectors)

# Get query similarity
# Get similarity between the query (only one query?) and the questions
if vectors_pos.any() and query:
vector_query = np.array(
[[x for x in text_vectors if x.is_query is True][0].vector]
Expand Down

0 comments on commit 6dfb2c6

Please sign in to comment.