# TensorFlow VGG

This is a TensorFlow implementation of VGG, forked from the tensorflow-vgg repo.

The main changes from the original repo are:

- The `Vgg` class no longer loads the VGG model in its constructor. You load the model once and pass it in, so a single model can be shared among multiple `Vgg` instances.
- The `Vgg` class supports both training and prediction. Together with model sharing, this gives you more flexibility in designing your algorithm's infrastructure.
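The first point is a form of dependency injection: the caller loads the weights once and hands the same object to every network instead of each instance reading the file itself. The framework-free sketch below illustrates only that idea; `SharedWeightsNet` and the layer names are hypothetical stand-ins, not this repo's `Vgg` class.

```python
class SharedWeightsNet:
    """Hypothetical stand-in for a network that receives pre-loaded weights."""

    def __init__(self, model):
        # Keep a reference to the caller's dict; nothing is loaded or copied here.
        self.model = model

    def layer_names(self):
        # Read-only access to the shared weights.
        return sorted(self.model)


# Load (or build) the weights once...
weights = {"conv1_1": [[0.1]], "fc8": [[0.2]]}  # placeholder data, not real VGG weights

# ...then share them, e.g. one instance for training and one for prediction.
train_net = SharedWeightsNet(model=weights)
predict_net = SharedWeightsNet(model=weights)

print(train_net.model is predict_net.model)  # True: both refer to the same dict
```

Because both instances hold a reference rather than a copy, the roughly 500 MB of weights is kept in memory only once.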

To use the VGG networks, download the `*.npy` files for VGG16 or VGG19 and put them in the `models` directory.
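A small helper can fail early with a clear message when a weights file has not been downloaded yet. This is a sketch under assumptions: `find_vgg_weights` is a hypothetical name, and the file names `vgg16.npy` and `vgg19.npy` assume the downloads keep those names (only `vgg19.npy` appears in this README).

```python
from pathlib import Path


def find_vgg_weights(variant, models_dir="models"):
    """Return the path to <models_dir>/<variant>.npy, raising if it is missing."""
    if variant not in ("vgg16", "vgg19"):
        raise ValueError(f"unknown VGG variant: {variant!r}")
    path = Path(models_dir) / f"{variant}.npy"
    if not path.is_file():
        raise FileNotFoundError(
            f"{path} not found; download {variant}.npy and place it in {models_dir}/"
        )
    return path
```

For example, `find_vgg_weights("vgg19")` returns `Path("models/vgg19.npy")` once the file is in place.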

## Usage

This project is still under development, so please stay tuned. Below is a rough example of prediction.

### Use It for Predicting

First, load the model (around 500 MB) into memory.

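```python
import numpy as np

# The .npy file holds a pickled Python dict, so NumPy >= 1.16.3 needs allow_pickle=True.
# encoding="latin1" lets Python 3 read pickles written by Python 2.
model = np.load("/path/to/your/vgg19.npy", allow_pickle=True, encoding="latin1").item()
print("The VGG model is loaded.")
```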

Create the `Vgg19` instance within the scope of your `tf.Graph`, then run the graph in a `tf.Session`.

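```python
import tensorflow as tf  # uses the TensorFlow 1.x graph/session API

# Vgg19 comes from this repo; `model` is the dict loaded above and
# `imgs` is your batch of input RGB images (neither is defined in this snippet).

# Design the graph.
graph = tf.Graph()
with graph.as_default():
    nn = Vgg19(model=model)

# Run the graph in the session.
with tf.Session(graph=graph) as sess:
    # tf.global_variables_initializer() replaces the deprecated tf.initialize_all_variables().
    tf.global_variables_initializer().run()
    print("TensorFlow initialized all variables.")

    # Feed the images and fetch the predictions.
    preds = sess.run(nn.preds, feed_dict={nn.inputRGB: imgs})
```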
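To turn the predictions into labels, you typically take the highest-scoring classes for each image. The sketch below assumes, as in the upstream tensorflow-vgg repo, that each row of `preds` holds one score per class (1000 ImageNet classes for the released VGG weights); the exact format of this fork's `nn.preds` is not documented here, and `top_k_classes` is a hypothetical helper.

```python
def top_k_classes(preds, k=5):
    """Return, for each image, the k class indices with the highest scores (best first)."""
    results = []
    for row in preds:
        scores = list(row)  # works for Python lists and NumPy rows alike
        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
        results.append(ranked[:k])
    return results


# Toy batch of 2 images over 4 classes (real VGG outputs have 1000 columns).
toy_preds = [[0.1, 0.6, 0.2, 0.1], [0.7, 0.05, 0.05, 0.2]]
print(top_k_classes(toy_preds, k=2))  # [[1, 2], [0, 3]]
```

Mapping the indices to human-readable names requires a class-label list matching the weights, which this README does not provide.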

You can also add a `tf.summary.FileWriter` (called `tf.train.SummaryWriter` before TensorFlow 1.0) to record whatever summaries you want. Then run `tensorboard --logdir=/path/to/your/log/dir` and view the summaries in your web browser.

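```python
# Run the graph in the session.
with tf.Session(graph=graph) as sess:
    tf.global_variables_initializer().run()
    print("TensorFlow initialized all variables.")

    # The op that writes logs (including the graph) for TensorBoard.
    # In TensorFlow < 1.0 this class was tf.train.SummaryWriter.
    summary_writer = tf.summary.FileWriter("/path/to/your/log/dir",
                                           graph=sess.graph)

    preds = sess.run(nn.preds, feed_dict={nn.inputRGB: imgs})

    # Flush pending events to disk before the session closes.
    summary_writer.close()
```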
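TensorBoard treats each subdirectory of `--logdir` that contains event files as a separate run, so giving every session its own timestamped subdirectory keeps runs from mixing. A minimal sketch; `make_run_logdir` is a hypothetical helper whose return value would be passed as the log directory to the summary writer in the snippet above.

```python
import os
from datetime import datetime


def make_run_logdir(base_dir, now=None):
    """Create and return base_dir/run-YYYYmmdd-HHMMSS for one run."""
    if now is None:
        now = datetime.now()
    run_dir = os.path.join(base_dir, now.strftime("run-%Y%m%d-%H%M%S"))
    os.makedirs(run_dir, exist_ok=True)  # creates base_dir too if needed
    return run_dir
```

Point `tensorboard --logdir` at the base directory (not the run subdirectory) to compare runs side by side.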

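Installation: `pip install tensorflow` (see https://www.tensorflow.org/). The examples use the TensorFlow 1.x graph-and-session API (for example `tf.Session`); in TensorFlow 2.x, use `tf.compat.v1` instead.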
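Because the examples rely on `tf.Session`, checking the installed major version up front can save a confusing `AttributeError`. This is a sketch only: `tf_major_version` and `needs_compat_v1` are hypothetical helpers that parse a version string such as `tf.__version__`.

```python
def tf_major_version(version):
    """Return the major version number from a string like '1.15.5' or '2.16.1'."""
    return int(version.split(".")[0])


def needs_compat_v1(version):
    """True when tf.Session and related APIs live under tf.compat.v1 (TensorFlow 2.x+)."""
    return tf_major_version(version) >= 2
```

For example, if `needs_compat_v1(tf.__version__)` returns `True`, you would switch to `tf.compat.v1` and call `tf.compat.v1.disable_eager_execution()` before building the graph. Whether this repo's `Vgg19` class runs unchanged under `tf.compat.v1` is not verified here.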