From 326930706c9588e6eb7ddd5bd4fdbd394c33784f Mon Sep 17 00:00:00 2001 From: "Martindale, Nathan" Date: Mon, 14 Aug 2023 09:33:14 -0400 Subject: [PATCH] Add tests for model save load and add support for none data --- icat/anchorlist.py | 5 +++++ icat/histograms.py | 2 ++ icat/instance.py | 3 +++ icat/view.py | 3 +++ tests/conftest.py | 2 +- tests/test_anchorlist.py | 8 +++++--- tests/test_anchors.py | 4 ++-- tests/test_model.py | 16 ++++++++++++++++ 8 files changed, 37 insertions(+), 6 deletions(-) diff --git a/icat/anchorlist.py b/icat/anchorlist.py index 1a30088..f7a3610 100644 --- a/icat/anchorlist.py +++ b/icat/anchorlist.py @@ -652,10 +652,15 @@ def highlight_regex(self) -> str: return kw_regex def build_tfidf_features(self): + # TODO: this eventually needs to use the anchorlist cache, and also + # this may need to change when tfidf just becomes a sim function instead + # of a dedicated anchor if self.model is None: raise RuntimeError( "The anchorlist has no associated model to get a dataset from." ) + if self.model.data.active_data is None: + return self.tfidf_vectorizer = TfidfVectorizer(stop_words="english") # TODO: coupling: accessing active_data through model diff --git a/icat/histograms.py b/icat/histograms.py index a0a1c01..e46aeda 100644 --- a/icat/histograms.py +++ b/icat/histograms.py @@ -100,6 +100,8 @@ def _set_layout(self): # ============================================================ def refresh_data(self, data: DataManager): + if data.active_data is None: + return if data.prediction_col in data.active_data.columns: self.hist_local.set_data( data.active_data.loc[ diff --git a/icat/instance.py b/icat/instance.py index b4d7070..8aa3967 100644 --- a/icat/instance.py +++ b/icat/instance.py @@ -153,6 +153,9 @@ def populate(self, index: int): """Fill or update all of the fields for the given index. This should be called anytime the model updates, or when the user clicks/requests to view a new instance.""" + if self.data.active_data is None: + return + self.index = index self.index_display.children = [str(index)] row = self.data.active_data.iloc[index] diff --git a/icat/view.py b/icat/view.py index c53a7c2..2b0b9db 100644 --- a/icat/view.py +++ b/icat/view.py @@ -153,6 +153,9 @@ def _send_anchorlist_anchor_modification_to_viz( def _serialize_data_to_dicts(self) -> dict: feature_names = self.model.feature_names() + if self.model.data.active_data is None: + return [] + # first we need to see if we are filtering data based on prediction range # we reference the sample indices a lot, so just compute this once and reference # throughout. diff --git a/tests/conftest.py b/tests/conftest.py index 085a37d..b0b7f53 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,5 +64,5 @@ def fun_df(): def data_file_loc(): shutil.rmtree("test/exampledata", ignore_errors=True) os.makedirs("test/exampledata", exist_ok=True) - yield + yield "test/exampledata/thing" shutil.rmtree("test/exampledata", ignore_errors=True) diff --git a/tests/test_anchorlist.py b/tests/test_anchorlist.py index 53721c0..005d6ec 100644 --- a/tests/test_anchorlist.py +++ b/tests/test_anchorlist.py @@ -169,7 +169,9 @@ def test_highlight_regex(kw_set1, kw_set2, expected_result): """Special characters should be escaped, esp *""" -def test_save_load(data_file_loc, fun_df): +def test_save_load_anchorlist(data_file_loc, fun_df): + """When we save an anchorlist and then reload it, all the anchors should reload into + the same spot with the same parameters.""" model = Model(fun_df, "text") a1 = DictionaryAnchor(anchor_name="thing1") a1.keywords = ["hello", "there"] @@ -183,10 +185,10 @@ def test_save_load(data_file_loc, fun_df): model.add_anchor(a1) model.add_anchor(a2) - model.anchor_list.save("test/exampledata/thing") + model.anchor_list.save(data_file_loc) model2 = Model(fun_df, "text") - model2.anchor_list.load("test/exampledata/thing") + model2.anchor_list.load(data_file_loc) assert model2.anchor_list.cache["test_cache"] == 13 a21 = model2.anchor_list.anchors[0] diff --git a/tests/test_anchors.py b/tests/test_anchors.py index 4fe0a98..724626a 100644 --- a/tests/test_anchors.py +++ b/tests/test_anchors.py @@ -380,10 +380,10 @@ def test_save_load_similarity_anchor(data_file_loc): a1.text_col = "my_text" a1.reference_texts = ["I am a powerful potato"] a1.reference_short = ["I am a"] - a1.save("test/exampledata/test") + a1.save(data_file_loc) a2 = SimilarityFunctionAnchor() - a2.load("test/exampledata/test") + a2.load(data_file_loc) assert a2.anchor_name == "I am an anchor" assert a2.weight == 1.2 assert not a2.in_view diff --git a/tests/test_model.py b/tests/test_model.py index 518b751..2cd3ebd 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -280,3 +280,19 @@ def test_changing_anchor_name_twice_before_model_trained_modifies_data( model.anchor_list.anchors[0]._anchor_name_input.fire_event("blur", "testing123") assert "_testing123" in model.data.active_data.columns assert "_testing" not in model.data.active_data.columns + + +def test_save_load_model(data_file_loc, fun_df, dummy_anchor): + """Saving a model and then reloading it should load in all the same data and anchors.""" + + model = Model(fun_df, text_col="text") + model.anchor_list.add_anchor(dummy_anchor) + model.data.apply_label(0, 1) + model.save(data_file_loc) + + model2 = Model(None, text_col="") + model2.load(data_file_loc) + assert len(model2.data.active_data) == 12 + assert model2.text_col == "text" + assert len(model2.anchor_list.anchors) == 1 + assert len(model2.training_data) == 1