Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for batch nms and export checkpoint to frozen graph #174

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ The pretrained darknet weights file can be downloaded [here](https://pjreddie.co
python convert_weight.py
```

Then the converted TensorFlow checkpoint file will be saved to `./data/darknet_weights/` directory.
To additionally export the converted weights as a frozen graph, run:
```shell
python export_frozen_graph.py
```
Then the converted TensorFlow checkpoint (and frozen graph) files will be saved to the `./data/darknet_weights/` directory.

You can also download the TensorFlow checkpoint file that I converted via [[Google Drive link](https://drive.google.com/drive/folders/1mXbNgNxyXPi7JNsnBaxEv1-nWr7SVoQt?usp=sharing)] or [[Github Release](https://github.com/wizyoung/YOLOv3_TensorFlow/releases/)] and then place it in the same directory.

Expand All @@ -41,6 +45,10 @@ Single image test demo:
```shell
python test_single_image.py ./data/demo_data/messi.jpg
```
Batch image test demo (using the frozen graph):
```shell
python test_batch_images.py ./data/demo_data/
```

Video test demo:

Expand Down
Binary file added data/demo_data/results/batch_output_dog.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/demo_data/results/batch_output_kite.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added data/demo_data/results/batch_output_messi.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
51 changes: 51 additions & 0 deletions export_frozen_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import tensorflow as tf
import numpy as np
import argparse
import cv2
import glob
import matplotlib.pyplot as plt
from utils.misc_utils import parse_anchors, read_class_names
from utils.nms_utils import batch_nms
from utils.plot_utils import get_color_table, plot_one_box
from utils.data_aug import letterbox_resize
from model import yolov3

# Export the YOLOv3 checkpoint, together with batched NMS post-processing,
# as a single frozen-graph (.pb) file.
anchors = parse_anchors("./data/yolo_anchors.txt")
classes = read_class_names("./data/coco.names")
num_class = len(classes)

# Paths are named once so the restore, write and log steps cannot drift apart.
checkpoint_path = "./data/darknet_weights/yolov3.ckpt"
frozen_graph_path = './data/darknet_weights/yolov3_frozen_graph_batch.pb'

# Names of the nodes that must survive freezing. The four "output/..." names
# are the tf.identity ops created by batch_nms; "input" is the placeholder
# defined below, listed so it is kept under that exact name.
output_node_names = [
    "output/boxes",
    "output/scores",
    "output/labels",
    "output/num_detections",
    "input",
]

with tf.Session() as sess:
    # build graph: batch size and spatial dimensions are left dynamic
    input_data = tf.placeholder(tf.float32, [None, None, None, 3], name='input')
    yolo_model = yolov3(num_class, anchors)
    with tf.variable_scope('yolov3'):
        pred_feature_maps = yolo_model.forward(input_data, False)
    pred_boxes, pred_confs, pred_probs = yolo_model.predict(pred_feature_maps)
    pred_scores = pred_confs * pred_probs
    boxes, scores, labels, num_dects = batch_nms(
        pred_boxes, pred_scores, max_boxes=20, score_thresh=0.5, nms_thresh=0.5)

    # restore weight
    saver = tf.train.Saver()
    saver.restore(sess, checkpoint_path)

    # save: fold the restored variables into constants, keeping only the
    # sub-graph needed to compute the listed nodes
    output_graph_def = tf.graph_util.convert_variables_to_constants(
        sess,
        tf.get_default_graph().as_graph_def(),
        output_node_names
    )

    with tf.gfile.GFile(frozen_graph_path, "wb") as f:
        f.write(output_graph_def.SerializeToString())

    print("{} ops written to {}.".format(len(output_graph_def.node), frozen_graph_path))

2 changes: 1 addition & 1 deletion model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

class yolov3(object):

def __init__(self, class_num, anchors, use_label_smooth=False, use_focal_loss=False, batch_norm_decay=0.999, weight_decay=5e-4, use_static_shape=True):
def __init__(self, class_num, anchors, use_label_smooth=False, use_focal_loss=False, batch_norm_decay=0.999, weight_decay=5e-4, use_static_shape=False):

# self.anchors = [[10, 13], [16, 30], [33, 23],
# [30, 61], [62, 45], [59, 119],
Expand Down
63 changes: 63 additions & 0 deletions test_batch_images.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import numpy as np
import tensorflow as tf

class YoloV3:
    """Load a frozen YOLOv3 graph and run batched inference with it.

    The tensor names below must match the node names written by the export
    script ("input" placeholder and the "output/..." identity ops).

    The instance owns a ``tf.Session``; call :meth:`close` (or use the
    instance as a context manager) to release it when done.
    """
    INPUT_NAME = 'input:0'
    BOXES_NAME = 'output/boxes:0'
    CLASSES_NAME = 'output/labels:0'
    SCORES_NAME = 'output/scores:0'
    NUM_DETECTIONS_NAME = 'output/num_detections:0'

    def __init__(self, frozen_graph):
        """Import the serialized GraphDef at path ``frozen_graph``.

        Raises:
            RuntimeError: if the file parses to a graph with no nodes.
        """
        self.graph = tf.Graph()
        graph_def = tf.GraphDef()
        with tf.gfile.GFile(frozen_graph, 'rb') as fid:
            graph_def.ParseFromString(fid.read())

        # A freshly constructed GraphDef is never None, so test for an empty
        # graph instead; otherwise the failure only surfaces later in run().
        if not graph_def.node:
            raise RuntimeError('Cannot find inference graph.')

        with self.graph.as_default():
            # name='' keeps the original node names (no "import/" prefix),
            # which the *_NAME constants above rely on.
            tf.import_graph_def(graph_def, name='')

        self.sess = tf.Session(graph=self.graph)

    def close(self):
        """Release the underlying TensorFlow session."""
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
        return False

    def run(self, image):
        """Run detection on a batch of images.

        ``image`` should be normalized to [0, 1] and in RGB order; it is fed
        to the 'input' placeholder unchanged.

        Returns:
            (boxes, classes, scores, num_detections) as fetched from the
            graph, with ``classes`` and ``num_detections`` cast to int64.
        """
        boxes, classes, scores, num_detections = self.sess.run(
            [self.BOXES_NAME, self.CLASSES_NAME, self.SCORES_NAME, self.NUM_DETECTIONS_NAME],
            feed_dict={self.INPUT_NAME: image})
        return boxes, classes.astype(np.int64), scores, num_detections.astype(np.int64)

if __name__ == '__main__':
    # Demo driver: run the frozen graph on every .jpg in a directory as one
    # batch and write annotated copies to ./data/demo_data/results.
    import glob
    import os
    import sys

    import cv2
    import numpy as np

    from utils.misc_utils import read_class_names
    from utils.plot_utils import get_color_table, plot_one_box

    # Optional first argument selects the image directory, matching the
    # README usage `python test_batch_images.py ./data/demo_data/`.
    image_dir = sys.argv[1] if len(sys.argv) > 1 else './data/demo_data'
    result_dir = './data/demo_data/results'

    model = YoloV3('./data/darknet_weights/yolov3_frozen_graph_batch.pb')
    classes = read_class_names("./data/coco.names")
    # Size the colour table from the class list instead of hard-coding 80.
    color_table = get_color_table(len(classes))

    files = []
    images = []
    vis_images = []
    for file in sorted(glob.glob(os.path.join(image_dir, '*.jpg'))):
        image = cv2.imread(file)
        if image is None:
            # cv2.imread returns None for unreadable files; skip them so
            # `files` stays index-aligned with the batch.
            print('Skipping unreadable image: {}'.format(file))
            continue
        # All images are resized to one fixed size so they stack into a batch.
        image = cv2.resize(image, (640, 640))
        files.append(file)
        vis_images.append(image)
        # important! the graph expects RGB input scaled to [0, 1]
        images.append(cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.)

    if not images:
        raise SystemExit('No readable .jpg images found in {}'.format(image_dir))
    images = np.array(images)

    # inference
    boxes_, labels_, scores_, num_dect_ = model.run(images)

    # visualize
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)
    for idx, image in enumerate(vis_images):
        # The NMS outputs are zero-padded to a fixed length; only the first
        # num_dect_[idx] entries of each row are real detections.
        for i in range(int(num_dect_[idx])):
            x0, y0, x1, y1 = boxes_[idx][i]
            label_id = labels_[idx][i]
            plot_one_box(image, [x0, y0, x1, y1],
                         label=classes[label_id] + ', {:.2f}%'.format(scores_[idx][i] * 100),
                         color=color_table[label_id])
        out_name = os.path.join(result_dir, 'batch_output_' + os.path.basename(files[idx]))
        cv2.imwrite(out_name, image)
28 changes: 26 additions & 2 deletions utils/nms_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def gpu_nms(boxes, scores, num_classes, max_boxes=50, score_thresh=0.5, nms_thre
score = tf.concat(score_list, axis=0)
label = tf.concat(label_list, axis=0)

return boxes, score, label
return tf.identity(boxes, name='output/boxes'), tf.identity(score,name='output/scores'), tf.identity(label, name='output/labels')


def py_nms(boxes, scores, max_boxes=50, iou_thresh=0.5):
Expand Down Expand Up @@ -120,4 +120,28 @@ def cpu_nms(boxes, scores, num_classes, max_boxes=50, score_thresh=0.5, iou_thre
score = np.concatenate(picked_score, axis=0)
label = np.concatenate(picked_label, axis=0)

return boxes, score, label
return boxes, score, label

def batch_nms(boxes, scores, max_boxes=20, score_thresh=0.5, nms_thresh=0.5, max_boxes_per_class=5):
    """
    Perform batch NMS on GPU using TensorFlow.

    params:
        boxes: tensor of shape [batch_size, num_anchors, 4] # num_anchors=10647=(13*13+26*26+52*52)*3, for input 416*416 image
        scores: tensor of shape [batch_size, num_anchors, num_classes], score=conf*prob
        max_boxes: integer, maximum number of predicted boxes you'd like per image, default is 20
        score_thresh: if [ highest class probability score < score_threshold]
                        then get rid of the corresponding box
        nms_thresh: real value, "intersection over union" threshold used for NMS filtering
        max_boxes_per_class: integer, maximum number of boxes kept per class
                        per image, default is 5 (the value that used to be
                        hard-coded)

    returns:
        (boxes, scores, labels, num_detections), each wrapped in tf.identity
        so it is addressable by a fixed name ('output/boxes', 'output/scores',
        'output/labels', 'output/num_detections') in an exported graph.
        Per the tf.image.combined_non_max_suppression documentation the first
        three are zero-padded to max_boxes entries per image, labels are
        float32, and num_detections holds the count of valid entries per image.
    """
    # make shape valid for tf.image.combined_non_max_suppression():
    # a size-1 third axis means the same box is shared by all classes
    boxes = tf.expand_dims(boxes, 2)  # boxes now is [bs, na, 1, 4]
    boxes, scores, labels, valid_num_detections = tf.image.combined_non_max_suppression(
        boxes=boxes,
        scores=scores,
        max_output_size_per_class=max_boxes_per_class,
        max_total_size=max_boxes,
        iou_threshold=nms_thresh,
        score_threshold=score_thresh,
        pad_per_class=False,
        # clip_boxes=True would clip coordinates to [0, 1]; keep them as-is
        # (presumably pixel coordinates here -- confirm against yolov3.predict)
        clip_boxes=False)
    return (tf.identity(boxes, name='output/boxes'),
            tf.identity(scores, name='output/scores'),
            tf.identity(labels, name='output/labels'),
            tf.identity(valid_num_detections, name='output/num_detections'))