Skip to content

Commit

Permalink
Fix MHC test case
Browse files Browse the repository at this point in the history
  • Loading branch information
woodthom2 committed Apr 9, 2024
1 parent b822014 commit 554fac4
Showing 1 changed file with 8 additions and 163 deletions.
171 changes: 8 additions & 163 deletions tests/test_match_mhc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import unittest

import numpy as np
from sentence_transformers import SentenceTransformer

from harmony import match_instruments
from harmony.schemas.requests.text import Instrument, Question
Expand All @@ -43,169 +44,13 @@
'self-harm and suicide']}
]

mhc_questions = [
Question(
question_text="Have you ever felt annoyed by criticism of your drinking?"),
Question(
question_text="Have you recently")
]

mhc_embeddings = np.array([[0.31698248, 0.12777875, 0.04758111, 0.42555183, 0.39878827,
-0.08955862, 0.5192866, 0.30124685, 0.21197015, 0.00708418,
-0.2850551, 0.18860379, -0.1555653, 0.00297953, 0.00462213,
-0.12949726, -0.08035432, -0.26375496, 0.12412424, 0.10837679,
-0.5037259, -0.09385646, 0.3371012, 0.3913466, 0.20307001,
0.24039863, 0.1305629, -0.2821185, -0.5513103, 0.06190247,
-0.13765384, -0.19164908, 0.45563582, -0.14968751, -0.08522708,
-0.04938481, -0.05294258, 0.38618585, -0.15730491, 0.08312214,
0.47667167, -0.3356474, -0.17182834, 0.26071277, -0.17881565,
0.08396178, -0.00323634, 0.2311756, -0.12618396, -0.10611524,
0.43096817, 0.17048849, 0.21215717, -0.25437835, -0.13393787,
-0.53877395, -0.07101384, 0.01789423, -0.3065841, 0.29606512,
-0.08472076, -0.17864832, -0.02092675, 0.27630028, -0.01672868,
0.00911544, 0.05188185, 0.24766283, -0.34470555, 0.36253104,
0.4684341, -0.16652438, 0.51165503, -0.13537268, -0.16006592,
-0.03522887, -0.0707399, 0.33693355, -0.19236559, 0.23861784,
0.1801952, 0.6937744, -0.12062132, -0.06564505, -0.22319959,
0.05320398, 0.04514768, 0.22523028, 0.14264914, -0.16481638,
-0.17351519, -0.04900846, 0.3545729, 0.15396394, 0.04017958,
0.18319216, -0.27823684, -0.3195731, -0.29258227, -0.11263382,
0.15431456, 0.4166155, -0.1792731, -0.10493241, 0.48761463,
0.08412156, -0.3284959, 0.2170649, -0.04346004, -0.20890392,
-0.29599097, -0.2947356, 0.12064935, -0.42329106, -0.15400986,
0.44402364, -0.36696455, -0.05456911, 0.3541068, 0.11427474,
-0.23996827, 0.09099466, -0.29233944, 0.12837581, -0.21298626,
-0.10884392, -0.25267506, -0.2787788, -0.4845803, -0.13096423,
-0.4864495, 0.28124052, 0.14306262, -0.19078338, -0.17958394,
0.25551945, -0.11909194, -0.3314092, 0.4510464, 0.19852763,
0.01937219, 0.23598652, -0.12465857, 0.442798, -0.6327372,
0.070682, -0.33041057, -0.16026855, -0.15644713, 0.00449345,
0.26490465, 0.36309314, -0.10244551, -0.24318339, -0.28259054,
-0.10808916, 0.09426385, 0.09313136, -0.09232678, -0.15955387,
-0.03862188, -0.02023721, -0.04303666, -0.09471063, 0.4094802,
0.00823278, 0.177744, -0.4164041, 0.42607704, 0.15625669,
0.00993561, 0.10922246, -0.19608107, 0.12548625, -0.56741923,
-0.38538846, -0.47808996, 0.28678727, 0.05682391, 0.01955914,
0.3326408, 0.430999, -0.23403923, 0.02386556, -0.01759753,
-0.06427025, 0.04940457, -0.49631965, 0.02051971, 0.10658412,
-0.3540185, -0.7100885, -0.1473797, 0.4786842, -0.25432903,
0.2342263, 0.01315649, 0.13445738, -0.3768663, 0.39162108,
-0.31643704, -0.04117766, 0.22978652, -0.06744407, 0.03473052,
-0.2560465, 0.355641, -0.09752869, 0.2763238, -0.01599947,
0.09489536, 0.34597492, 0.06963903, -0.143612, 0.28069505,
-0.07262952, -0.556238, 0.36728564, -0.11330801, 0.20679858,
-0.14668477, 0.17844874, -0.04426139, -0.01047437, -0.02615638,
-0.20547685, 0.23042962, -0.0039904, -0.38805655, -0.02489062,
-0.02377239, -0.2765465, 0.16558783, -0.15966313, -0.13969666,
0.1478924, -0.61089206, 0.44975737, 0.12644823, 0.23916753,
0.0458822, -0.01550341, -0.19837566, -0.3952923, -0.0927781,
0.13077396, 0.26247647, -0.17566101, 0.21122949, 0.06102095,
0.5539473, -0.56075054, 0.07709154, 0.20823747, 0.16002455,
0.70210016, -0.05882035, -0.15420373, -0.05149654, 0.38249192,
-0.2946821, -0.35835972, -0.44311288, -0.24698798, 0.12509051,
-0.11229838, -0.08135978, -0.16638997, 0.03564276, -0.09275578,
-0.24160156, -0.30175573, 0.1027863, 0.08439031, 0.0099033,
0.02770073, 0.22322494, 0.311543, -0.13683608, -0.3894704,
0.49415538, 0.47102222, -0.3151316, 0.20377164, 0.09965412,
-0.10423395, -0.68527293, -0.473526, 0.6035066, -0.05493419,
-0.28567913, 0.29210818, 0.5681573, 0.12835786, -0.40651968,
0.16982934, -0.3874353, 0.15280275, -0.0236476, -0.03943689,
-0.10345548, -0.09391969, -0.1913922, -0.3044832, -0.13739675,
0.27429006, 0.50745153, -0.0597518, -0.21094711, -0.0732678,
0.04466628, 0.14001723, -0.18278922, 0.03852075, 0.23296797,
-0.42482314, -0.31042528, 0.5422277, -0.03087755, -0.1230187,
-0.07384335, -0.14684653, 0.42525747, 0.47017905, 0.27531362,
0.15183, 0.00558928, -0.08692063, -0.58131367, -0.3823981,
0.06546987, 0.2202636, -0.01623679, 0.1046277, -0.00911903,
-0.05312139, -0.10305279, -0.3665741, 0.11962718, 0.38077533,
0.3854031, 0.39518344, -0.30965704, 0.09311447, -0.12321192,
-0.18851212, 0.35448384, -0.20389782, -0.24292004, -0.2065611,
0.17573257, 0.30135575, -0.47104543, 0.14942494, 0.3500816,
-0.11687063, 0.12592225, -0.40719974, 0.26228407, -0.27288267,
-0.31435874, -0.0679815, 0.21934305, 0.3319977, 0.05899495,
0.26373258, -0.08150745, 0.4251593, -0.04965526, 0.15699777,
0.18204968, 0.06474901, -0.04183663, 0.06092591, -0.02925731,
-0.35751835, 0.05314383, 0.1297178, -0.5059059, -0.24174014,
0.48816782, 0.43312073, -0.20753404, -0.11278468],
[0.27486998, -0.2697892, 0.3188312, -0.12536432, 0.01304267,
0.33539486, -0.03457115, 0.42157128, 0.14488854, -0.2535716,
0.00203519, 0.19810879, -0.10463517, 0.37333468, -0.00255293,
-0.16467139, -0.14045683, -0.3814871, -0.48228467, -0.08481958,
-0.9221502, 0.21189533, -0.43577656, 0.16735111, 0.48340046,
-0.01760102, 0.27114028, 0.17150797, -0.34873474, -0.4152926,
-0.25706002, 0.03578025, -0.02686583, 0.14918645, -0.21884163,
0.38330188, 0.04579284, 0.05233987, -0.1042789, -0.04755146,
-0.24653739, -0.35496214, 0.07765934, 0.283041, 0.03939558,
-0.07236537, 0.29317337, 0.08163625, 0.16927677, -0.07933656,
-0.05621653, 0.00285317, 0.1326318, -0.23114794, 0.1844616,
-0.0594918, -0.3353115, 0.42261744, 0.29715118, 0.0768225,
-0.01329702, -0.13062464, -0.40871128, 0.1458279, -0.15691017,
-0.0560152, -0.08101902, -0.11061958, -0.10817435, 0.34984335,
-0.06367751, 0.16960506, -0.30438364, 0.05792395, -0.3264931,
0.26885664, 0.44736207, 0.14939362, 0.0298011, -0.54847616,
0.30995524, 0.56500894, -0.28844061, -0.13042337, -0.2105248,
0.3470874, -0.09324706, -0.0062361, -0.12570363, -0.53825516,
0.27010298, 0.48290047, 0.14266248, -0.06316312, 0.21796148,
0.08413079, 0.10383511, -0.14182593, 0.10216405, 0.9572182,
0.06990123, 0.11582927, 0.02274915, -0.34581697, 0.10836975,
-0.12180965, -0.33869275, 0.28453517, -0.06676594, -0.09463757,
0.00802726, -0.14916064, 0.14053154, -0.38230038, 0.16950493,
0.13753466, -0.26661664, 0.16837998, 0.30741367, 0.59256136,
-0.36832786, 0.04534454, -0.21152854, -0.28456149, -0.23686779,
-0.5134756, 0.31583533, -0.3019316, 0.10003791, -0.11501426,
-0.05164363, 0.12522875, -0.2860269, -0.16071294, -0.25988027,
-0.1528313, 0.21414776, -0.33567128, 0.19986635, -0.16722803,
-0.21049023, 0.3136991, 0.04985164, 0.11231628, -0.00206487,
-0.0749305, 0.2528408, 0.22653161, 0.2092252, 0.03754649,
0.27962247, 0.09704119, 0.2611406, -0.16651891, 0.07974467,
0.00834405, 0.34989062, 0.15298776, -0.0019381, -0.08982801,
0.08453012, -0.08024079, -0.14020282, 0.10008207, 0.22012894,
-0.15018322, -0.10299787, -0.1733045, 0.566371, 0.2072993,
-0.34126177, 0.07938533, -0.13098265, -0.08013391, -0.2878633,
0.00973775, 0.13913797, 0.0927425, 0.10339385, 0.21635972,
-0.20797694, 0.15161236, -0.11198922, 0.35578677, -0.02738264,
-0.26994178, -0.12715785, -0.09629931, 0.15184326, 0.29649338,
-0.17322885, -0.38323998, 0.12125704, 0.2510554, 0.04652905,
0.6565736, -0.04273917, 0.16851999, 0.11479741, 0.09580388,
-0.6911337, -0.3516644, 0.6060403, -0.31186315, 0.14358425,
-0.23110956, 0.504636, -0.0194466, 0.061623, -0.1000816,
-0.05559966, 0.3858693, 0.03753303, 0.09085271, 0.17007722,
0.17224549, -0.0962081, 0.32467312, -0.22916205, -0.23865096,
0.0587332, -0.11103028, 0.06267146, -0.32311437, -0.1277771,
-0.09335798, -0.10572634, 0.17335922, 0.03074208, 0.10909478,
0.03907165, 0.04682253, -0.17600848, -0.15760909, -0.01004206,
-0.49257722, 0.27541164, 0.03481415, -0.20166273, 0.4372687,
-0.17118098, -0.22499757, -0.11060498, -0.00291519, 0.26334196,
-0.20152582, -0.08343456, -0.00866678, -0.31192753, 0.4213688,
0.30551866, -0.12213267, -0.32327032, -0.15804684, -0.18317492,
0.13432573, -0.40005973, -0.15876392, -0.09548694, 0.2725558,
-0.1870125, -0.41536883, 0.10239711, -0.25523803, 0.16100271,
-0.01001115, -0.39272046, -0.02715501, -0.01456893, 0.10823169,
-0.14799145, -0.00846969, 0.2624863, 0.30831105, -0.08509161,
-0.38923135, -0.07503062, 0.16235362, -0.22095658, -0.1585248,
0.12947415, -0.05640597, 0.34181905, 0.48939648, -0.7805238,
-0.17587404, -0.01248726, -0.13006216, 0.12443904, 0.00942586,
0.06076963, 0.05074999, 0.02869129, -0.19378237, -0.35253748,
0.07046331, -0.519916, -0.06380238, 0.13962676, 0.1303891,
-0.01947186, -0.151776, -0.2864162, -0.05781687, 0.23414886,
0.09762021, 0.0017703, 0.03055813, 0.0319084, 0.20591748,
0.24638842, -0.04396889, 0.08936361, 0.21831344, 0.02940887,
-0.04536785, -0.5190801, 0.24697958, -0.15234096, -0.21200204,
0.37511286, 0.28909883, -0.04750488, 0.2628198, 0.5198975,
0.29022488, 0.22573914, 0.31517568, 0.05905829, -0.37448406,
-0.01358803, 0.25261813, -0.19083244, -0.26631537, 0.05716723,
0.20410937, -0.36892363, -0.19007134, 0.2187003, 0.09327414,
0.04022982, 0.01385638, -0.00511172, 0.01945869, -0.32149813,
-0.07335124, -0.12126502, 0.5104156, -0.17086674, 0.12854597,
-0.01081597, 0.28688022, -0.4756173, -0.30881587, -0.03464577,
-0.10276201, 0.34286246, -0.2315859, -0.05564657, -0.09513279,
-0.04309574, 0.18460624, 0.09360301, 0.46265283, -0.2623327,
-0.3211086, -0.3473561, 0.15126844, -0.14696191, -0.02683708,
-0.3956083, -0.07471086, 0.02117106, 0.05153611, 0.23008172,
-0.0815743, -0.03742188, -0.03611618, -0.01844736, 0.03614095,
0.3976692, 0.23438324, -0.23447269, 0.3847343]
],
dtype=np.float32)
mhc_questions_as_text = ["Have you ever felt annoyed by criticism of your drinking?", "Have you recently"]

model = SentenceTransformer('sentence-transformers/paraphrase-multilingual-mpnet-base-v2')

mhc_embeddings = model.encode(np.asarray(mhc_questions_as_text))

mhc_questions = [Question(question_text=t) for t in mhc_questions_as_text]


class TestMatchMhc(unittest.TestCase):
Expand Down

0 comments on commit 554fac4

Please sign in to comment.