diff --git a/src/dataset_test.py b/src/dataset_test.py index 0ad8a9a..3256708 100644 --- a/src/dataset_test.py +++ b/src/dataset_test.py @@ -12,20 +12,22 @@ class TestDatasetandDataLoader(unittest.TestCase): dir_env["test_good_dir"] = "test/good" dir_env["test_bad_dir"] = None - mvtec_dataset = dataset.MVTecDataset(is_train=False, dir_env=dir_env) - dataloader = dataset.DataLoader( - mvtec_dataset, + mvtec_dataset_train = dataset.MVTecDataset(is_train=True, dir_env=dir_env) + mvtec_dataset_test = dataset.MVTecDataset(is_train=False, dir_env=dir_env) + dataloader_test = dataset.DataLoader( + mvtec_dataset_test, batch_size=2, shuffle=True, drop_last=False, ) def test_dataset(self): - self.assertEqual(len(self.mvtec_dataset), 2) - self.assertEqual(len(self.mvtec_dataset[0]), 3) + self.assertEqual(len(self.mvtec_dataset_train), 10) + self.assertEqual(len(self.mvtec_dataset_test), 2) + self.assertEqual(len(self.mvtec_dataset_test[0]), 3) def test_dataloader(self): - self.assertEqual(len(self.dataloader), 2) + self.assertEqual(len(self.dataloader_test), 2) ret = 0 for _ in self.dataloader: ret += 1 diff --git a/src/models_test.py b/src/models_test.py index 6a324d3..461f0b1 100644 --- a/src/models_test.py +++ b/src/models_test.py @@ -17,23 +17,22 @@ class TestCalculateScore(unittest.TestCase): model_params = config["model_params"] model = models.SparseCodingWithMultiDict([], model_params) + batch_path = "test" + batch_name = "image.png" - def output_np_array(self): - batch_path = "test" - batch_name = "sample.png" + def test_output_np_array(self): f_diff = np.ones([1, 10, 10]) - self.model.output_np_array(batch_path, batch_name, f_diff) - self.assertTrue(os.path.exists("test/sample.npy")) - shutil.rmtree("test") + self.model.output_np_array(self.batch_path, self.batch_name, f_diff) + self.assertTrue(os.path.exists("visualized_results/test/image.npy")) + shutil.rmtree("visualized_results") def test_output_image(self): - batch_path = "good" - batch_name = "image.png" ch_err = np.ones([896]) output_img = np.zeros([10, 10, 3]) - self.model.output_image(batch_path, batch_name, ch_err, output_img) + self.model.output_image( + self.batch_path, self.batch_name, ch_err, output_img) self.assertTrue(os.path.exists( - "visualized_results/good/image-896.png")) + "visualized_results/test/image-896.png")) shutil.rmtree("visualized_results") def test_calclate_ssim(self):