Skip to content

Commit

Permalink
add output npy func
Browse files Browse the repository at this point in the history
  • Loading branch information
a-hanamoto committed Apr 26, 2020
1 parent fd4e385 commit 1c6fe9f
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 23 deletions.
1 change: 1 addition & 0 deletions cfg/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ model_params:
num_of_nonzero: 3 # number of nonzero (default: 3)
cutoff_edge_width: 2
num_of_ch: 896 # number of channels used for anomaly detection (1<=num_of_ch<=896)
output_npy: True
1 change: 1 addition & 0 deletions cfg/sample_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ model_params:
num_of_nonzero: 3 # number of nonzero (default: 3)
cutoff_edge_width: 2
num_of_ch: 100 # number of channels used for anomaly detection (1<=num_of_ch<=896)
output_npy: True
30 changes: 18 additions & 12 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,30 @@ def __init__(

if dir_name:
dir_path = os.path.abspath(os.path.join(root, dir_name))
self.dataset = self.load_dataset(dir_path, ext)
dir_parent_path = os.path.dirname(dir_path)
dir_name = os.path.basename(dir_path)
self.dataset = self.load_dataset(dir_parent_path, dir_name, ext)
else:
dir_path = os.path.abspath(os.path.join(root, excp_name))
dir_parent_path = os.path.dirname(dir_path)
dir_paths = [
os.path.join(dir_parent_path, d)
dirs = [
d
for d in os.listdir(dir_parent_path)
if d not in excp_name
]
self.dataset = []
for path in dir_paths:
self.dataset.extend(self.load_dataset(path, ext))
for dir_name in dirs:
self.dataset.extend(self.load_dataset(
dir_parent_path, dir_name, ext))

def load_dataset(self, dir_parent_path, dir_name, ext):

def load_dataset(self, dir_path, ext):
return [
(f, cv2.imread(os.path.join(dir_path, f))[:, :, [2, 1, 0]])
(dir_name, f, cv2.imread(os.path.join(
dir_parent_path, dir_name, f))
[:, :, [2, 1, 0]])
for f in tqdm(
os.listdir(dir_path),
os.listdir(os.path.join(dir_parent_path, dir_name)),
desc="loading images"
)
if ext in f
Expand All @@ -59,13 +65,13 @@ def __len__(self):
return len(self.dataset)

def __getitem__(self, idx):
sample = self.dataset[idx][1]
sample = self.dataset[idx][2]

if self.preprocessor:
for p in self.preprocessor:
sample = p(sample)

return (self.dataset[idx][0], sample)
return (self.dataset[idx][0], self.dataset[idx][1], sample)


class DataLoader(object):
Expand Down Expand Up @@ -100,8 +106,8 @@ def __next__(self):

batch = []
for idx in self.idxs[self.counter: self.counter + self.batch_size]:
batch.append(self.dataset[idx][1])
batch.append(self.dataset[idx][2])

self.counter += self.batch_size

return self.dataset[idx][0], numpy.stack(batch)
return self.dataset[idx][0], self.dataset[idx][1], numpy.stack(batch)
27 changes: 16 additions & 11 deletions src/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
self.patch_size = model_env["patch_size"]
self.stride = model_env["stride"]
self.num_of_ch = model_env["num_of_ch"]
self.output_npy = model_env["output_npy"]

self.org_l = int(256 / 8.0) - self.cutoff_edge_width * 2

Expand All @@ -47,7 +48,7 @@ def __init__(
def train(self):
arrs = []
for batch_data in self.train_loader:
batch_img = batch_data[1]
batch_img = batch_data[2]
for p in self.preprocesses:
batch_img = p(batch_img)
N, P, C, H, W = batch_img.shape
Expand Down Expand Up @@ -111,7 +112,7 @@ def calculate_error(self, coders, is_positive):

for batch_data in tqdm(loader, desc="testing"):

batch_name, batch_img = batch_data[0], batch_data[1]
batch_path, batch_name, batch_img = batch_data
p_batch_img = batch_img
for p in self.preprocesses:
p_batch_img = p(p_batch_img)
Expand Down Expand Up @@ -144,18 +145,22 @@ def calculate_error(self, coders, is_positive):
top_5[numpy.argsort(ch_err)[::-1][:5]] += 1
errs.append(numpy.sum(ch_err))
f_diff /= self.num_of_ch
visualized_out = self.visualize(org_img, f_diff)
self.output_image(is_positive, batch_name,
ch_err, visualized_out)
if self.output_npy:
self.output_np_array(batch_path, batch_name, f_diff)
else:
visualized_out = self.visualize(org_img, f_diff)
self.output_image(batch_path, batch_name,
ch_err, visualized_out)
return errs

def output_image(self, is_positive, batch_name, ch_err, visualized_out):
if is_positive:
mode = "pos"
else:
mode = "neg"
def output_np_array(self, batch_path, batch_name, f_diff):
output_path = os.path.join("visualized_results", batch_path)
os.makedirs(output_path, exist_ok=True)
numpy.save(os.path.join(
output_path, batch_name.split(".")[0] + ".npy"), f_diff)

output_path = os.path.join("visualized_results", mode)
def output_image(self, batch_path, batch_name, ch_err, visualized_out):
output_path = os.path.join("visualized_results", batch_path)
os.makedirs(output_path, exist_ok=True)

cv2.imwrite(
Expand Down

0 comments on commit 1c6fe9f

Please sign in to comment.