-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathinference.py
128 lines (102 loc) · 4.15 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import json
import boto3
import json_logging
import os
import logging
import sys
import time
import tensorflow as tf
import tensorflow_hub as hub
from flask import Flask
from flask import request
# Settings
# Detection thresholds, overridable through the environment. An unset or empty
# variable falls back to the built-in default; any other value is parsed.
MIN_SCORE = float(os.environ.get("MIN_SCORE") or 0.1)
MAX_BOXES = int(os.environ.get("MAX_BOXES") or 15)
# Logger
# Enable JSON-formatted records (non-web mode), then attach a stdout handler
# to the "serving" logger used throughout this module.
json_logging.init_non_web(enable_json=True)
logger = logging.getLogger("serving")
_log_level = os.environ.get("LOG_LEVEL")
# logging level names are case-sensitive ("info" raises ValueError), so
# normalise the env value; unset/empty keeps the DEBUG default.
logger.setLevel(_log_level.strip().upper() if _log_level else logging.DEBUG)
logger.addHandler(logging.StreamHandler(sys.stdout))
# Flask app: the WSGI application that the routes below are registered on.
app = Flask(import_name=__name__)
# Load tensorflow hub model
# Loaded once at import time; `detector` is the model's 'default' signature,
# called per request by `invocations`. Progress goes through the module
# logger (not print) so it uses the same JSON handler as every other message.
logger.info("Loading model")
model = "https://tfhub.dev/google/faster_rcnn/openimages_v4/inception_resnet_v2/1"
detector = hub.load(model).signatures['default']
logger.info("Model loaded")
def build_result_object(boxes, class_names, scores, max_boxes=15, min_score=0.1):
    """Convert raw detector output into a list of JSON-serialisable dicts.

    Only the first ``max_boxes`` detections are examined, and of those only
    the ones whose score is at least ``min_score`` are kept.

    Args:
        boxes: per-detection boxes; ``boxes.shape[0]`` is the detection count
            and ``boxes[i]`` unpacks to ``(y_min, x_min, y_max, x_max)``
            (presumably a NumPy array from the detector -- confirm).
        class_names: per-detection class labels as ``bytes``.
        scores: per-detection confidence scores, indexed like ``boxes``.
        max_boxes: maximum number of leading detections to examine.
        min_score: inclusive minimum score for a detection to be reported.

    Returns:
        A list of dicts whose values are strings: ``ymin``, ``xmin``,
        ``ymax``, ``xmax``, ``class``, ``confidence``, plus an ``mp3``
        placeholder set to ``"_empty_"``.
    """
    logger.debug("Building %s boxes with min score of %s", max_boxes, min_score)
    result_object_list = []
    for i in range(min(boxes.shape[0], max_boxes)):
        score = scores[i]
        logger.debug("checking if %s is bigger than %s", score, min_score)
        # Keep the original `>=` test (a NaN score is therefore excluded).
        if score >= min_score:
            y_min, x_min, y_max, x_max = tuple(boxes[i])
            # UTF-8 is a superset of ASCII: identical result for ASCII labels,
            # but no UnicodeDecodeError on a non-ASCII one.
            class_name = class_names[i].decode("utf-8")
            confidence = str(score)
            logger.info("object detected", extra={'props': {'service': 'tf-hub',
                                                            'object': class_name,
                                                            'confidence': confidence}})
            result_object_list.append({"ymin": str(y_min),
                                       "xmin": str(x_min),
                                       "ymax": str(y_max),
                                       "xmax": str(x_max),
                                       "class": class_name,
                                       "confidence": confidence,
                                       "mp3": "_empty_",
                                       })
    logger.debug("Added %s objects to result", len(result_object_list))
    return result_object_list
def download_s3_img(bucket_name, object_prefix, object_key):
    """Download an S3 object into the current working directory.

    The S3 key is ``object_prefix`` and ``object_key`` joined by exactly one
    ``/``; an empty prefix addresses ``object_key`` directly.

    Args:
        bucket_name: name of the S3 bucket.
        object_prefix: key prefix ("folder") inside the bucket.
        object_key: object name; also used verbatim as the local file path.

    Returns:
        The local path the object was written to (``object_key`` itself).
    """
    logger.debug("Downloading from {} image {}/{}".format(bucket_name, object_prefix, object_key))
    # Naive "prefix + '/' + key" yields "/key" for an empty prefix and
    # "a//key" for a prefix with a trailing slash -- both different S3 keys.
    prefix = object_prefix.rstrip('/')
    full_object_path = prefix + '/' + object_key if prefix else object_key
    # NOTE(review): object_key comes from the HTTP request and is used as a
    # local path unchanged, so "../x" writes outside the working directory
    # and concurrent requests for the same name share one file. The caller
    # deletes the file by this same name, so changing the local path needs a
    # coordinated change there.
    s3 = boto3.client('s3')
    s3.download_file(bucket_name, full_object_path, object_key)
    return object_key
def load_img(path):
    """Read the file at ``path`` and decode it with ``decode_jpeg`` into 3 channels."""
    logger.debug("Loading image {}".format(path))
    raw_bytes = tf.io.read_file(path)
    return tf.image.decode_jpeg(raw_bytes, channels=3)
def run_detector(func_detector, img):
    """Run ``func_detector`` on a single image and return its outputs as NumPy.

    Args:
        func_detector: callable taking a batched float32 image tensor and
            returning a dict of tensors (here the TF-Hub 'default' signature).
        img: image tensor, as produced by ``load_img``.

    Returns:
        A dict mapping each output name to ``value.numpy()``.
    """
    # Convert to float32 (integer images are scaled into [0, 1]) and add a
    # leading batch dimension of size 1.
    converted_img = tf.image.convert_image_dtype(img, tf.float32)[tf.newaxis, ...]
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the wall clock is adjusted, giving wrong (even negative) timings.
    start_time = time.perf_counter()
    result = func_detector(converted_img)
    elapsed_time = time.perf_counter() - start_time
    result = {key: value.numpy() for key, value in result.items()}
    # Counts every returned detection, before any min-score filtering.
    logger.info("Found %d objects.", len(result["detection_scores"]))
    logger.info("Inference time: %s", elapsed_time)
    return result
def delete_temp_image(file):
    """Remove the local file at path ``file`` (raises if it does not exist)."""
    os.unlink(file)
@app.route('/ping')
def ping():
    """Health-check endpoint; always replies with the body "pong"."""
    healthy_reply = "pong"
    return healthy_reply
@app.route('/invocations', methods=['GET', 'POST'])
def invocations():
    """Run object detection on an image stored in S3.

    Expects a POST whose JSON body contains ``s3_bucket``, ``key_prefix`` and
    ``file_name``. The image is downloaded locally, run through ``detector``,
    and the detections built by ``build_result_object`` are returned as a
    JSON string. The local copy is removed afterwards, even if loading or
    detection fails.
    """
    if request.method != "POST":
        # Previously fell through and returned None, which Flask turns into a 500.
        return "Only POST requests are supported", 405
    # silent=True: a missing or malformed JSON body yields None instead of
    # raising, and is then reported like any other request without file_name.
    request_body = request.get_json(silent=True)
    if not isinstance(request_body, dict) or "file_name" not in request_body:
        return "Missing file name in POST request"
    try:
        s3_bucket = request_body['s3_bucket']
        object_prefix = request_body['key_prefix']
    except KeyError as err:
        # Report the client error instead of an unhandled KeyError (500).
        return "Missing {} in POST request".format(err.args[0]), 400
    object_key = request_body['file_name']
    downloaded_image_name = download_s3_img(s3_bucket, object_prefix, object_key)
    try:
        img = load_img(downloaded_image_name)
        result = run_detector(detector, img)
        result_object = build_result_object(result["detection_boxes"], result["detection_class_entities"],
                                            result["detection_scores"], MAX_BOXES, MIN_SCORE)
    finally:
        # Always clean up the downloaded file, so failed requests don't leak it.
        delete_temp_image(downloaded_image_name)
    return json.dumps(result_object)
# Development entry point: serve on all interfaces, port 8080, debug off.
if __name__ == '__main__':
    app.run(host="0.0.0.0", port=8080, debug=False)