Skip to content

Commit

Permalink
Merge pull request #39 from LeapMind/dev
Browse files Browse the repository at this point in the history
add output npy func
  • Loading branch information
a-hanamoto authored Apr 27, 2020
2 parents fd4e385 + 0dea771 commit 4384f7b
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 34 deletions.
1 change: 1 addition & 0 deletions cfg/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ model_params:
num_of_nonzero: 3 # number of nonzero (default: 3)
cutoff_edge_width: 2
num_of_ch: 896 # number of channels used for anomaly detection (1<=num_of_ch<=896)
output_npy: False
1 change: 1 addition & 0 deletions cfg/sample_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ model_params:
num_of_nonzero: 3 # number of nonzero (default: 3)
cutoff_edge_width: 2
num_of_ch: 100 # number of channels used for anomaly detection (1<=num_of_ch<=896)
output_npy: False
30 changes: 18 additions & 12 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,30 @@ def __init__(

if dir_name:
dir_path = os.path.abspath(os.path.join(root, dir_name))
self.dataset = self.load_dataset(dir_path, ext)
dir_parent_path = os.path.dirname(dir_path)
dir_name = os.path.basename(dir_path)
self.dataset = self.load_dataset(dir_parent_path, dir_name, ext)
else:
dir_path = os.path.abspath(os.path.join(root, excp_name))
dir_parent_path = os.path.dirname(dir_path)
dir_paths = [
os.path.join(dir_parent_path, d)
dirs = [
d
for d in os.listdir(dir_parent_path)
if d not in excp_name
]
self.dataset = []
for path in dir_paths:
self.dataset.extend(self.load_dataset(path, ext))
for dir_name in dirs:
self.dataset.extend(self.load_dataset(
dir_parent_path, dir_name, ext))

def load_dataset(self, dir_parent_path, dir_name, ext):

def load_dataset(self, dir_path, ext):
return [
(f, cv2.imread(os.path.join(dir_path, f))[:, :, [2, 1, 0]])
(dir_name, f, cv2.imread(os.path.join(
dir_parent_path, dir_name, f))
[:, :, [2, 1, 0]])
for f in tqdm(
os.listdir(dir_path),
os.listdir(os.path.join(dir_parent_path, dir_name)),
desc="loading images"
)
if ext in f
Expand All @@ -59,13 +65,13 @@ def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
sample = self.dataset[idx][1]
sample = self.dataset[idx][2]

if self.preprocessor:
for p in self.preprocessor:
sample = p(sample)

return (self.dataset[idx][0], sample)
return (self.dataset[idx][0], self.dataset[idx][1], sample)


class DataLoader(object):
Expand Down Expand Up @@ -100,8 +106,8 @@ def __next__(self):

batch = []
for idx in self.idxs[self.counter: self.counter + self.batch_size]:
batch.append(self.dataset[idx][1])
batch.append(self.dataset[idx][2])

self.counter += self.batch_size

return self.dataset[idx][0], numpy.stack(batch)
return self.dataset[idx][0], self.dataset[idx][1], numpy.stack(batch)
16 changes: 9 additions & 7 deletions src/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,23 @@ class TestDatasetandDataLoader(unittest.TestCase):
dir_env["test_good_dir"] = "test/good"
dir_env["test_bad_dir"] = None

mvtec_dataset = dataset.MVTecDataset(is_train=False, dir_env=dir_env)
dataloader = dataset.DataLoader(
mvtec_dataset,
mvtec_dataset_train = dataset.MVTecDataset(is_train=True, dir_env=dir_env)
mvtec_dataset_test = dataset.MVTecDataset(is_train=False, dir_env=dir_env)
dataloader_test = dataset.DataLoader(
mvtec_dataset_test,
batch_size=2,
shuffle=True,
drop_last=False,
)

def test_dataset(self):
self.assertEqual(len(self.mvtec_dataset), 2)
self.assertEqual(len(self.mvtec_dataset[0]), 2)
self.assertEqual(len(self.mvtec_dataset_train), 10)
self.assertEqual(len(self.mvtec_dataset_test), 2)
self.assertEqual(len(self.mvtec_dataset_test[0]), 3)

def test_dataloader(self):
self.assertEqual(len(self.dataloader), 2)
self.assertEqual(len(self.dataloader_test), 2)
ret = 0
for _ in self.dataloader:
for _ in self.dataloader_test:
ret += 1
self.assertEqual(ret, 2 // 2)
27 changes: 16 additions & 11 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
self.patch_size = model_env["patch_size"]
self.stride = model_env["stride"]
self.num_of_ch = model_env["num_of_ch"]
self.output_npy = model_env["output_npy"]

self.org_l = int(256 / 8.0) - self.cutoff_edge_width * 2

Expand All @@ -47,7 +48,7 @@ def __init__(
def train(self):
arrs = []
for batch_data in self.train_loader:
batch_img = batch_data[1]
batch_img = batch_data[2]
for p in self.preprocesses:
batch_img = p(batch_img)
N, P, C, H, W = batch_img.shape
Expand Down Expand Up @@ -111,7 +112,7 @@ def calculate_error(self, coders, is_positive):

for batch_data in tqdm(loader, desc="testing"):

batch_name, batch_img = batch_data[0], batch_data[1]
batch_path, batch_name, batch_img = batch_data
p_batch_img = batch_img
for p in self.preprocesses:
p_batch_img = p(p_batch_img)
Expand Down Expand Up @@ -144,18 +145,22 @@ def calculate_error(self, coders, is_positive):
top_5[numpy.argsort(ch_err)[::-1][:5]] += 1
errs.append(numpy.sum(ch_err))
f_diff /= self.num_of_ch
visualized_out = self.visualize(org_img, f_diff)
self.output_image(is_positive, batch_name,
ch_err, visualized_out)
if self.output_npy:
self.output_np_array(batch_path, batch_name, f_diff)
else:
visualized_out = self.visualize(org_img, f_diff)
self.output_image(batch_path, batch_name,
ch_err, visualized_out)
return errs

def output_image(self, is_positive, batch_name, ch_err, visualized_out):
if is_positive:
mode = "pos"
else:
mode = "neg"
def output_np_array(self, batch_path, batch_name, f_diff):
output_path = os.path.join("visualized_results", batch_path)
os.makedirs(output_path, exist_ok=True)
numpy.save(os.path.join(
output_path, batch_name.split(".")[0] + ".npy"), f_diff)

output_path = os.path.join("visualized_results", mode)
def output_image(self, batch_path, batch_name, ch_err, visualized_out):
output_path = os.path.join("visualized_results", batch_path)
os.makedirs(output_path, exist_ok=True)

cv2.imwrite(
Expand Down
16 changes: 12 additions & 4 deletions src/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,22 @@ class TestCalculateScore(unittest.TestCase):
model_params = config["model_params"]

model = models.SparseCodingWithMultiDict([], model_params)
batch_path = "test"
batch_name = "image.png"

def test_output_np_array(self):
f_diff = np.ones([1, 10, 10])
self.model.output_np_array(self.batch_path, self.batch_name, f_diff)
self.assertTrue(os.path.exists("visualized_results/test/image.npy"))
shutil.rmtree("visualized_results")

def test_output_image(self):
is_positive = True
batch_name = "image.png"
ch_err = np.ones([896])
output_img = np.zeros([10, 10, 3])
self.model.output_image(is_positive, batch_name, ch_err, output_img)
self.assertTrue(os.path.exists("visualized_results/pos/image-896.png"))
self.model.output_image(
self.batch_path, self.batch_name, ch_err, output_img)
self.assertTrue(os.path.exists(
"visualized_results/test/image-896.png"))
shutil.rmtree("visualized_results")

def test_calclate_ssim(self):
Expand Down

0 comments on commit 4384f7b

Please sign in to comment.