Skip to content

Commit

Permalink
Fix 'untraining' when falling beneath seeding requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
WarmCyan committed Aug 21, 2023
1 parent d4537d3 commit b5e4d34
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
3 changes: 3 additions & 0 deletions icat/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,9 @@ def _train_model(self):
# it untrained
if self.is_trained():
del self.classifier.classes_
self.data.active_data = self.data.active_data.drop(
self.data.prediction_col, axis=1
)
return False

if len(self.feature_names(in_model_only=True)) < 1:
Expand Down
18 changes: 18 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,3 +343,21 @@ def test_unlabel_multi(fun_df, dummy_anchor):
assert len(model.training_data) == 2
model.data.apply_label([1, 2], [-1, -1])
assert len(model.training_data) == 0


def test_untraining_removes_pred_col(fun_df, dummy_anchor):
    """Check that 'untraining' a model also discards its earlier predictions.

    Ten rows are labelled so the model becomes seeded and trained. One of the
    positive labels is then reset, after which the model must report itself as
    untrained and the prediction column must be gone from ``active_data``.
    """
    positive_rows = (3, 6, 7)

    model = Model(fun_df, text_col="text")
    model.add_anchor(dummy_anchor)

    # Label rows 0-9: the rows listed above get 1, every other row gets 0.
    for row in range(10):
        label = 1 if row in positive_rows else 0
        model.data.apply_label(row, label)

    assert model.is_seeded()
    assert model.is_trained()

    # A label of -1 presumably clears the row's label (confirm against
    # apply_label); here it is expected to push the model back to untrained.
    model.data.apply_label(7, -1)

    assert not model.is_trained()
    assert model.data.prediction_col not in model.data.active_data.columns

0 comments on commit b5e4d34

Please sign in to comment.