# BAISTools.py
import os
import time
import numpy as np
from PIL import Image
import tensorflow as tf
import tensorflow.contrib.slim as slim
class Tools(object):
    """Stateless helpers: filesystem, logging, label colouring, checkpoint restore.

    Every method is a ``@staticmethod``; the class is only used as a namespace.
    """

    # Class list carried over from the original comments:
    # 0 = road, 1 = sidewalk, 2 = building, 3 = wall, 4 = fence, 5 = pole,
    # 6 = traffic light, 7 = traffic sign, 8 = vegetation, 9 = terrain, 10 = sky,
    # 11 = person, 12 = rider, 13 = car, 14 = truck, 15 = bus,
    # 16 = train, 17 = motocycle, 18 = bicycle, 19 = void label
    # Palette used by decode_labels_test: label k -> entry k.
    _LABEL_COLOURS_TEST = (
        (128, 64, 128), (244, 35, 231), (69, 69, 69), (102, 102, 156), (190, 153, 153),
        (153, 153, 153), (250, 170, 29), (219, 219, 0), (106, 142, 35), (152, 250, 152),
        (69, 129, 180), (219, 19, 60), (255, 0, 0), (0, 0, 142), (0, 0, 69),
        (0, 60, 100), (0, 79, 100), (0, 0, 230), (119, 10, 32), (1, 1, 1),
    )
    # Palette used by decode_labels: the same colours with black prepended, so
    # every class colour sits one index later than in the list above.
    # NOTE(review): presumably label 0 is a background class here - confirm.
    _LABEL_COLOURS = ((0, 0, 0),) + _LABEL_COLOURS_TEST

    @staticmethod
    def new_dir(path):
        """Create ``path`` (including parents) if nothing exists there; return ``path``."""
        if not os.path.exists(path):
            os.makedirs(path)
        return path

    @staticmethod
    def print_info(info):
        """Print ``info`` prefixed with the local wall-clock time (HH:MM:SS)."""
        print("{} {}".format(time.strftime("%H:%M:%S", time.localtime()), info))

    @staticmethod
    def to_txt(data, file_name):
        """Write each item of ``data`` to ``file_name``, one item per line.

        The file is opened in ``"w"`` mode, so existing content is overwritten.
        """
        with open(file_name, "w") as f:
            for one_data in data:
                f.write("{}\n".format(one_data))

    @staticmethod
    def _colourise(mask, num_images, num_classes, palette):
        """Map the first ``num_images`` label maps of ``mask`` to RGB via ``palette``.

        Args:
            mask: 4-D array-like of integer labels; only channel 0 is read.
            num_images: number of leading batch entries to colour.
            num_classes: labels ``>= num_classes`` are left black.
            palette: sequence of ``(r, g, b)`` tuples indexed by label.

        Returns:
            ``uint8`` array of shape ``(num_images, h, w, 3)``.

        Raises:
            AssertionError: if the batch holds fewer than ``num_images`` entries.
            IndexError: if a label below ``num_classes`` is outside the palette.
        """
        mask = np.asarray(mask)
        n, h, w, c = mask.shape
        assert (n >= num_images), \
            'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images)
        labels = mask[:num_images, :, :, 0]
        outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8)
        # One vectorised palette lookup replaces the per-pixel Python loop.
        # Pixels failing the test keep the zero (black) fill.
        keep = labels < num_classes
        outputs[keep] = np.asarray(palette, dtype=np.uint8)[labels[keep]]
        return outputs

    # Colourise the output.
    @staticmethod
    def decode_labels(mask, num_images, num_classes):
        """Colour a batch of label maps using the palette that starts with black.

        Args:
            mask: array of shape ``(n, h, w, c)`` holding integer labels in channel 0.
            num_images: how many images from the batch to colour (must be ``<= n``).
            num_classes: labels ``>= num_classes`` are rendered black.

        Returns:
            ``uint8`` array of shape ``(num_images, h, w, 3)``.
        """
        return Tools._colourise(mask, num_images, num_classes, Tools._LABEL_COLOURS)

    # Colourise the output.
    @staticmethod
    def decode_labels_test(mask, num_images, num_classes):
        """Colour a batch of label maps using the palette without the leading black.

        Unlike ``decode_labels``, label 0 maps to ``(128, 64, 128)``. ``mask`` may be
        any array-like (it is converted with NumPy first).

        Args:
            mask: array-like of shape ``(n, h, w, c)`` holding integer labels in channel 0.
            num_images: how many images from the batch to colour (must be ``<= n``).
            num_classes: labels ``>= num_classes`` are rendered black.

        Returns:
            ``uint8`` array of shape ``(num_images, h, w, 3)``.
        """
        return Tools._colourise(mask, num_images, num_classes, Tools._LABEL_COLOURS_TEST)

    @staticmethod
    def prepare_label(input_batch, new_size, num_classes, one_hot=True):
        """Resize a label batch and optionally one-hot encode it.

        Args:
            input_batch: label tensor; axis 3 is squeezed away after resizing.
            new_size: target size passed to the nearest-neighbour resize.
            num_classes: depth of the one-hot encoding.
            one_hot: when true, return the one-hot encoding instead of raw labels.

        Returns:
            The resized (and possibly one-hot encoded) label tensor.
        """
        with tf.name_scope('label_encode'):
            # as labels are integer numbers, need to use NN interp.
            input_batch = tf.image.resize_nearest_neighbor(input_batch, new_size)
            input_batch = tf.squeeze(input_batch, squeeze_dims=[3])  # reducing the channel dimension.
            if one_hot:
                input_batch = tf.one_hot(input_batch, depth=num_classes)
        return input_batch

    # Restore the model if a checkpoint exists.
    @staticmethod
    def restore_if_y(sess, log_dir, pretrain=None):
        """Restore global variables from a checkpoint when one is available.

        The latest checkpoint recorded in ``log_dir`` takes precedence; ``pretrain``
        is used only when ``log_dir`` has none. If neither exists, a message is
        printed and nothing is restored.

        Args:
            sess: session in which the assignment ops are run.
            log_dir: directory searched for checkpoint state.
            pretrain: optional fallback checkpoint path.
        """
        # Load the model.
        ckpt = tf.train.get_checkpoint_state(log_dir)
        pretrain = ckpt.model_checkpoint_path if ckpt and ckpt.model_checkpoint_path else pretrain
        if pretrain:
            # ignore_missing_vars=True lets the restore proceed when some global
            # variables are absent from the checkpoint.
            slim.assign_from_checkpoint_fn(pretrain, var_list=tf.global_variables(), ignore_missing_vars=True)(sess)
            Tools.print_info("Restored model parameters from {}".format(pretrain))
        else:
            Tools.print_info('No checkpoint file found.')

    @staticmethod
    def get_shape(tensor):
        """Return ``tensor.shape`` without its first dimension, as a list of ints."""
        return [int(i) for i in list(tensor.shape)[1:]]