
Commit

update test cases
avalanchesiqi committed Oct 15, 2024
1 parent cc0498d commit 206e907
Showing 4 changed files with 97 additions and 5 deletions.
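The changed files all live under tests/, so the updated suites can be run with the standard library's unittest discovery. A minimal sketch, assuming the repository root as the working directory (this runner script is not part of the commit):

# Hypothetical runner, not included in this commit: discovers and runs the
# changed test modules shown below from the repository root.
import unittest

if __name__ == "__main__":
    suite = unittest.defaultTestLoader.discover(start_dir="tests", pattern="test_*.py")
    unittest.TextTestRunner(verbosity=2).run(suite)
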
5 changes: 3 additions & 2 deletions tests/test_binned_dud.py
@@ -20,8 +20,9 @@ def test_binned_dud_density(self):
 self.assertAlmostEqual(dud.get_density(label),
                        density,
                        places=3,
-                       msg=f'Test failed: density of label "{label}" does not match expected result')
+                       msg=f'Test failed: density of label "{label}" does not match expected result'
+                       )


 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
89 changes: 89 additions & 0 deletions tests/test_extrinsic_generator.py
@@ -0,0 +1,89 @@
import unittest
import numpy as np
from scipy.stats import beta

from pyquantifier.distributions import MixtureCUD
from pyquantifier.calibration_curve import PlattScaling
from pyquantifier.data import Item, Dataset

test_case_list = [
    (MixtureCUD(components=[beta(10, 2), beta(2, 5)], weights=[2, 8]),  # classifer_score_density_rv
     (21.92, -14.61),  # (w, b)
     1000000,  # num_base_size
     1000000,  # num_sample
     10),  # num_bin
    (MixtureCUD(components=[beta(10, 2), beta(2, 5)], weights=[2, 8]),
     (10, -5),
     1000000,
     1000000,
     10)
]


class TestExtrinsicGenerator(unittest.TestCase):
    def setUp(self):
        pass

    def test_extrinsic_generator(self):
        for case_num, test_case in enumerate(test_case_list):
            classifer_score_density_rv, (w, b), num_base_size, num_sample, num_bin = test_case

            # 1. Generate a dataset with `num_base_size` items,
            # whose pos scores follow the `classifer_score_density` distribution
            all_labels = ['pos', 'neg']
            simulated_pos_scores = classifer_score_density_rv.generate_data(num_base_size)

            calibration_curve = PlattScaling()
            calibration_curve.set_params(w, b)

            items = []
            gt_label_dict = {}
            for idx, pos_score in enumerate(simulated_pos_scores):
                item = Item(uid=idx+1, all_labels=all_labels, all_probs=[pos_score, 1-pos_score])
                items.append(item)

                calibrated_pos_score = calibration_curve.get_calibrated_prob(pos_score)[0]
                gt_label_dict[item.uid] = np.random.choice(all_labels, p=[calibrated_pos_score, 1-calibrated_pos_score])
            # print(gt_label_dict)

            dataset = Dataset(items=items)

            # 2. Select a subset of the dataset and annotate the labels
            selected_dataset, selection_weights = dataset.select_sample_for_annotation(
                n=num_sample,
                strategy='random',
                bins=num_bin
            )
            annotated_labels = [gt_label_dict[uid] for uid in selected_dataset.df['uid'].values.astype(int)]
            selected_dataset.annotate_sample(annotated_labels)

            # 3. Model the sampled dataset and retrieve the estimated w and b
            selected_dataset.update_dataset_model(num_bin=num_bin)
            selected_dataset.update_calibration_curve(method='platt scaling')
            estimated_w, estimated_b = selected_dataset.calibration_curve.get_params()
            estimated_w = estimated_w.item()
            estimated_b = estimated_b.item()

            # 4. Check if the estimated w and b are close to the ground truth synthetic w and b
            with self.subTest(msg=f'Checking case number {case_num+1}'):
                print(f'\nestimated w: {estimated_w:.4f}')
                print(f'estimated b: {estimated_b:.4f}')
                self.assertAlmostEqual(estimated_w,
                                       w,
                                       delta=0.2,
                                       msg=f'Test failed: estimated w does not match expected result'
                                       )
                self.assertAlmostEqual(estimated_b,
                                       b,
                                       delta=0.2,
                                       msg=f'Test failed: estimated b does not match expected result'
                                       )

    def test_extrinsic_generator2(self):
        with self.subTest(msg=f'Checking case 2'):
            pass


if __name__ == "__main__":
    unittest.main()
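
For context on what the estimated w and b above represent: the synthetic pos scores are drawn from a two-component Beta mixture (MixtureCUD over beta(10, 2) and beta(2, 5)), and Platt scaling maps each raw score to a calibrated probability through a logistic (sigmoid) curve. The sketch below is illustrative only; it assumes pyquantifier's PlattScaling follows the standard parameterization, and platt_calibrated_prob is a hypothetical helper, not a library function.

# Illustrative sketch, not library code: the textbook Platt scaling form that the
# synthetic (w, b) pairs in test_case_list are assumed to parameterize.
import numpy as np

def platt_calibrated_prob(pos_score, w, b):
    """Map a raw positive-class score to a calibrated probability via a sigmoid."""
    return 1.0 / (1.0 + np.exp(-(w * pos_score + b)))

# With the first case's parameters (w=21.92, b=-14.61), the sigmoid midpoint sits
# near a raw score of 14.61 / 21.92 ~ 0.67; scores well below that calibrate
# toward 0 and scores well above it toward 1.
print(round(platt_calibrated_prob(0.67, 21.92, -14.61), 3))  # ~0.52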
5 changes: 3 additions & 2 deletions tests/test_multinomial_dud.py
@@ -20,8 +20,9 @@ def test_multinomial_dud_density(self):
 for label, density in expected_result.items():
     self.assertEqual(dud.get_density(label),
                      density,
-                     msg=f'Test failed: density of label "{label}" does not match expected result')
+                     msg=f'Test failed: density of label "{label}" does not match expected result'
+                     )


 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
3 changes: 2 additions & 1 deletion tests/test_neyman.py
@@ -25,7 +25,8 @@ def test_get_neyman_allocation(self):
 with self.subTest(msg=f'Checking case number {case_num+1}'):
     self.assertDictEqual(result,
                          expected_dist,
-                         msg='Test failed: output of Neyman allocation does not match expected result')
+                         msg='Test failed: output of Neyman allocation does not match expected result'
+                         )


 if __name__ == "__main__":
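
For reference, test_get_neyman_allocation appears to exercise stratified sampling allocation. The snippet below is only a sketch of the textbook Neyman allocation rule (allocate the annotation budget in proportion to stratum size times stratum standard deviation); neyman_allocation, its arguments, and the example numbers are hypothetical and not taken from pyquantifier's API.

# Hypothetical sketch of the textbook Neyman allocation rule; not pyquantifier code.
# n_h = n * N_h * S_h / sum_k(N_k * S_k)
def neyman_allocation(n, stratum_sizes, stratum_stds):
    """Split a total budget n across strata in proportion to N_h * S_h."""
    weights = {h: stratum_sizes[h] * stratum_stds[h] for h in stratum_sizes}
    total = sum(weights.values())
    return {h: round(n * w / total) for h, w in weights.items()}

# Example with made-up numbers: strata of sizes 800 and 200 with standard
# deviations 0.2 and 0.5 receive roughly 615 and 385 of a budget of 1000.
print(neyman_allocation(1000, {'low': 800, 'high': 200}, {'low': 0.2, 'high': 0.5}))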
