Skip to content

Commit

Permalink
Add tests for unlabeling
Browse files Browse the repository at this point in the history
  • Loading branch information
WarmCyan committed Aug 21, 2023
1 parent 98820e9 commit d4537d3
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 1 deletion.
3 changes: 2 additions & 1 deletion icat/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,8 @@ def apply_label(self, index: int | list[int], label: int | list[int]):
Args:
index (int | list[int]): Either a single index, or a list of indices.
label (int | list[int]): Either the single label to apply or a list of corresponding labels
for the provided indices. 1 is "interesting", 0 is "uninteresting".
for the provided indices. 1 is "interesting", 0 is "uninteresting". If a -1 is provided,
this resets or "unlabels", removing it from the container model's training set.
"""
self._handle_label_changed(index, label)

Expand Down
5 changes: 5 additions & 0 deletions icat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,11 @@ def _train_model(self):
# coverage stats even if it's not seeded, so compute and set here
coverage_info = self.compute_coverage()
self.anchor_list.set_coverage(coverage_info)

# if this model _was_ trained but we unlabel sufficiently many points, we need to consider
# it untrained
if self.is_trained():
del self.classifier.classes_
return False

if len(self.feature_names(in_model_only=True)) < 1:
Expand Down
47 changes: 47 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,3 +296,50 @@ def test_save_load_model(data_file_loc, fun_df, dummy_anchor):
assert model2.text_col == "text"
assert len(model2.anchor_list.anchors) == 1
assert len(model2.training_data) == 1


def test_unlabel_removes_from_training_set(fun_df, dummy_anchor):
    """A row that is labeled and then reset with -1 must drop out of the training data."""
    target_row = 1

    mdl = Model(fun_df, text_col="text")
    mdl.add_anchor(dummy_anchor)

    # Labeling the row as "interesting" (1) should add exactly one training entry.
    mdl.data.apply_label(target_row, 1)
    assert len(mdl.training_data) == 1

    # A label of -1 resets the row, so the training data should be empty again.
    mdl.data.apply_label(target_row, -1)
    assert len(mdl.training_data) == 0


def test_unlabel_when_no_corresp_row_does_not_break(fun_df, dummy_anchor):
    """Resetting (-1) a row that was never labeled must not raise.

    The training data is expected to remain empty afterwards.
    """
    mdl = Model(fun_df, text_col="text")
    mdl.add_anchor(dummy_anchor)

    # Row 1 has not been labeled on this model, so there is nothing to remove.
    never_labeled_row = 1
    mdl.data.apply_label(never_labeled_row, -1)

    assert len(mdl.training_data) == 0


def test_unlabel_after_seed_correctly_unseeds(fun_df, dummy_anchor):
    """Resetting one label right after the model becomes seeded should unseed it.

    The model must report itself as neither seeded nor trained afterwards, and
    the reset itself must not raise.
    """
    mdl = Model(fun_df, text_col="text")
    mdl.add_anchor(dummy_anchor)

    # Label the first ten rows: rows 3, 6 and 7 as interesting (1), the rest as
    # uninteresting (0).
    interesting_rows = (3, 6, 7)
    for row in range(10):
        mdl.data.apply_label(row, 1 if row in interesting_rows else 0)

    assert mdl.is_seeded()
    assert mdl.is_trained()

    # Resetting a single one of those labels should undo both states.
    mdl.data.apply_label(7, -1)

    assert not mdl.is_seeded()
    assert not mdl.is_trained()


def test_unlabel_multi(fun_df, dummy_anchor):
    """Passing a list of -1 labels should reset every corresponding row at once."""
    rows = (1, 2)

    mdl = Model(fun_df, text_col="text")
    mdl.add_anchor(dummy_anchor)

    # Label both rows as interesting in one call.
    mdl.data.apply_label(list(rows), [1, 1])
    assert len(mdl.training_data) == 2

    # Reset both rows in one call; nothing should remain in the training data.
    mdl.data.apply_label(list(rows), [-1, -1])
    assert len(mdl.training_data) == 0

0 comments on commit d4537d3

Please sign in to comment.