-
Notifications
You must be signed in to change notification settings - Fork 346
/
Copy pathseg.py
110 lines (82 loc) · 3.07 KB
/
seg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
from io import BytesIO
import numpy as np
from PIL import Image
import tensorflow as tf
import sys
import datetime
class DeepLabModel(object):
    """Load a frozen DeepLab graph and run segmentation inference with it.

    Attributes:
        INPUT_TENSOR_NAME: Name of the graph tensor that is fed the image batch.
        OUTPUT_TENSOR_NAME: Name of the graph tensor fetched as the result.
        INPUT_SIZE: Length, in pixels, that the longer image side is scaled to
            before inference.
        FROZEN_GRAPH_NAME: Base name (without ``.pb``) of the frozen graph file.
    """

    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 513
    FROZEN_GRAPH_NAME = 'frozen_inference_graph'

    def __init__(self, tarball_path):
        """Create the model from the frozen graph stored under *tarball_path*.

        Args:
            tarball_path: Despite its name, a directory that contains
                ``frozen_inference_graph.pb``; no tar archive is unpacked.

        Raises:
            IOError: If the graph file cannot be opened or read.
        """
        self.graph = tf.Graph()

        graph_path = os.path.join(tarball_path, self.FROZEN_GRAPH_NAME + '.pb')
        # Read through a context manager so the file handle is always closed.
        with open(graph_path, 'rb') as graph_file:
            # NOTE(review): tf.GraphDef / tf.Session are kept exactly as the
            # original used them; confirm the installed TensorFlow exposes them.
            graph_def = tf.GraphDef.FromString(graph_file.read())

        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')

        self.sess = tf.Session(graph=self.graph)

    def _target_size(self, width, height):
        """Return ``(width, height)`` scaled so the longer side is INPUT_SIZE.

        Each side is clamped to at least one pixel so that very elongated
        images never yield a zero-sized resize target.
        """
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        return (
            max(1, int(resize_ratio * width)),
            max(1, int(resize_ratio * height)),
        )

    def run(self, image):
        """Run inference on a single image.

        Args:
            image: A PIL.Image object, raw input image.

        Returns:
            resized_image: RGB image resized from the original input image.
            seg_map: Segmentation map of `resized_image` (first element of the
                batch returned by the session).
        """
        start = datetime.datetime.now()

        width, height = image.size
        target_size = self._target_size(width, height)
        # Image.LANCZOS is the filter that Image.ANTIALIAS used to alias;
        # the ANTIALIAS name no longer exists in Pillow 10+.
        resized_image = image.convert('RGB').resize(target_size, Image.LANCZOS)

        # The graph takes a batch, so wrap the single image in a list.
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]

        diff = datetime.datetime.now() - start
        print("Time taken to evaluate segmentation is : " + str(diff))

        return resized_image, seg_map
def composeCutout(rgbPixels, segMap):
    """Build an RGBA array that keeps only the segmented pixels.

    Args:
        rgbPixels: Array-like of shape (height, width, 3) with 8-bit RGB data.
        segMap: Array-like of shape (height, width); a value of 0 marks
            background, any other value marks a pixel to keep.

    Returns:
        A uint8 array of shape (height, width, 4). Kept pixels carry their RGB
        value with alpha 255; background pixels are all zeros (transparent).
    """
    rgbPixels = np.asarray(rgbPixels, dtype=np.uint8)
    foreground = np.asarray(segMap) != 0

    height, width = rgbPixels.shape[:2]
    cutout = np.zeros([height, width, 4], dtype=np.uint8)
    cutout[foreground, :3] = rgbPixels[foreground]
    cutout[foreground, 3] = 255
    return cutout


def drawSegment(baseImg, matImg, outputPath=None):
    """Save *baseImg* with every background pixel made fully transparent.

    Args:
        baseImg: PIL image supplying the colours.
        matImg: Segmentation map indexed as ``[y, x]``; 0 means background.
        outputPath: Destination file. When omitted, the module-level
            ``outputFilePath`` is used, as in the original behaviour.
    """
    if outputPath is None:
        outputPath = outputFilePath

    # Vectorised replacement for a per-pixel getpixel() loop.
    rgbPixels = np.asarray(baseImg.convert('RGB'), dtype=np.uint8)
    img = Image.fromarray(composeCutout(rgbPixels, matImg))
    img.save(outputPath)
# Command-line usage: <script> INPUT_IMAGE OUTPUT_IMAGE [1]
# Validate the argument count first: indexing sys.argv directly would raise
# IndexError before any friendly message could be shown.
if len(sys.argv) < 3:
    print("Bad parameters. Please specify input file path and output file path")
    sys.exit(1)

inputFilePath = sys.argv[1]
outputFilePath = sys.argv[2]

# The model type doubles as the directory the frozen graph is loaded from.
# A third argument of "1" selects the Xception directory.
modelType = "mobile_net_model"
if len(sys.argv) > 3 and sys.argv[3] == "1":
    modelType = "xception_model"

MODEL = DeepLabModel(modelType)
print('model loaded successfully : ' + modelType)
def run_visualization(filepath):
    """Run the DeepLab model on the image at *filepath* and save the cut-out.

    The file is read fully into memory, decoded with PIL, passed to
    ``MODEL.run`` and the result is handed to ``drawSegment``.

    Args:
        filepath: Path of the image file to segment.

    Returns:
        None. If the file cannot be read or decoded, a message is printed and
        the function returns without running the model.
    """
    try:
        # Report the path actually being opened, not sys.argv[1].
        print("Trying to open : " + filepath)
        with open(filepath, "rb") as image_file:
            jpeg_bytes = image_file.read()
        original_im = Image.open(BytesIO(jpeg_bytes))
    except IOError:
        print('Cannot retrieve image. Please check file: ' + filepath)
        return

    print('running deeplab on image %s...' % filepath)
    resized_im, seg_map = MODEL.run(original_im)
    drawSegment(resized_im, seg_map)
# Script entry point: segment the image named by the first command-line argument.
run_visualization(inputFilePath)