Skip to content

Commit

Permalink
Add tests for model save load and add support for none data
Browse files Browse the repository at this point in the history
  • Loading branch information
WarmCyan committed Aug 14, 2023
1 parent 0646fa0 commit 3269307
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 6 deletions.
5 changes: 5 additions & 0 deletions icat/anchorlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,10 +652,15 @@ def highlight_regex(self) -> str:
return kw_regex

def build_tfidf_features(self):
# TODO: this eventually needs to use the anchorlist cache, and also
# this may need to change when tfidf just becomes a sim function instead
# of a dedicated anchor
if self.model is None:
raise RuntimeError(
"The anchorlist has no associated model to get a dataset from."
)
if self.model.data.active_data is None:
return

self.tfidf_vectorizer = TfidfVectorizer(stop_words="english")
# TODO: coupling: accessing active_data through model
Expand Down
2 changes: 2 additions & 0 deletions icat/histograms.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def _set_layout(self):
# ============================================================

def refresh_data(self, data: DataManager):
if data.active_data is None:
return
if data.prediction_col in data.active_data.columns:
self.hist_local.set_data(
data.active_data.loc[
Expand Down
3 changes: 3 additions & 0 deletions icat/instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def populate(self, index: int):
"""Fill or update all of the fields for the given index. This
should be called anytime the model updates, or when the user
clicks/requests to view a new instance."""
if self.data.active_data is None:
return

self.index = index
self.index_display.children = [str(index)]
row = self.data.active_data.iloc[index]
Expand Down
3 changes: 3 additions & 0 deletions icat/view.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ def _send_anchorlist_anchor_modification_to_viz(
def _serialize_data_to_dicts(self) -> dict:
feature_names = self.model.feature_names()

if self.model.data.active_data is None:
return []

# first we need to see if we are filtering data based on prediction range
# we reference the sample indices a lot, so just compute this once and reference
# throughout.
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,5 +64,5 @@ def fun_df():
def data_file_loc():
shutil.rmtree("test/exampledata", ignore_errors=True)
os.makedirs("test/exampledata", exist_ok=True)
yield
yield "test/exampledata/thing"
shutil.rmtree("test/exampledata", ignore_errors=True)
8 changes: 5 additions & 3 deletions tests/test_anchorlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,9 @@ def test_highlight_regex(kw_set1, kw_set2, expected_result):
"""Special characters should be escaped, esp *"""


def test_save_load(data_file_loc, fun_df):
def test_save_load_anchorlist(data_file_loc, fun_df):
"""When we save an anchorlist and then reload it, all the anchors should reload into
the same spot with the same parameters."""
model = Model(fun_df, "text")
a1 = DictionaryAnchor(anchor_name="thing1")
a1.keywords = ["hello", "there"]
Expand All @@ -183,10 +185,10 @@ def test_save_load(data_file_loc, fun_df):
model.add_anchor(a1)
model.add_anchor(a2)

model.anchor_list.save("test/exampledata/thing")
model.anchor_list.save(data_file_loc)

model2 = Model(fun_df, "text")
model2.anchor_list.load("test/exampledata/thing")
model2.anchor_list.load(data_file_loc)
assert model2.anchor_list.cache["test_cache"] == 13

a21 = model2.anchor_list.anchors[0]
Expand Down
4 changes: 2 additions & 2 deletions tests/test_anchors.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,10 @@ def test_save_load_similarity_anchor(data_file_loc):
a1.text_col = "my_text"
a1.reference_texts = ["I am a powerful potato"]
a1.reference_short = ["I am a"]
a1.save("test/exampledata/test")
a1.save(data_file_loc)

a2 = SimilarityFunctionAnchor()
a2.load("test/exampledata/test")
a2.load(data_file_loc)
assert a2.anchor_name == "I am an anchor"
assert a2.weight == 1.2
assert not a2.in_view
Expand Down
16 changes: 16 additions & 0 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,19 @@ def test_changing_anchor_name_twice_before_model_trained_modifies_data(
model.anchor_list.anchors[0]._anchor_name_input.fire_event("blur", "testing123")
assert "_testing123" in model.data.active_data.columns
assert "_testing" not in model.data.active_data.columns


def test_save_load_model(data_file_loc, fun_df, dummy_anchor):
"""Saving a model and then reloading it should load in all the same data and anchors."""

model = Model(fun_df, text_col="text")
model.anchor_list.add_anchor(dummy_anchor)
model.data.apply_label(0, 1)
model.save(data_file_loc)

model2 = Model(None, text_col="")
model2.load(data_file_loc)
assert len(model2.data.active_data) == 12
assert model2.text_col == "text"
assert len(model2.anchor_list.anchors) == 1
assert len(model2.training_data) == 1

0 comments on commit 3269307

Please sign in to comment.