
Work on TF branch, 3
osmr committed Oct 15, 2018
1 parent 7f99c7a commit 92e65dd
Showing 4 changed files with 83 additions and 57 deletions.
56 changes: 28 additions & 28 deletions gluon/gluoncv2/models/resnet.py
@@ -959,38 +959,38 @@ def _test():
     import numpy as np
     import mxnet as mx
 
-    pretrained = True
+    pretrained = False
 
     models = [
-        resnet10,
-        resnet12,
-        resnet14,
-        resnet16,
-        resnet18_wd4,
-        resnet18_wd2,
-        resnet18_w3d4,
+        # resnet10,
+        # resnet12,
+        # resnet14,
+        # resnet16,
+        # resnet18_wd4,
+        # resnet18_wd2,
+        # resnet18_w3d4,
 
         resnet18,
-        resnet34,
-        resnet50,
-        resnet50b,
-        resnet101,
-        resnet101b,
-        resnet152,
-        resnet152b,
-        # resnet200,
-        # resnet200b,
-
-        # seresnet18,
-        # seresnet34,
-        seresnet50,
-        # seresnet50b,
-        seresnet101,
-        # seresnet101b,
-        seresnet152,
-        # seresnet152b,
-        # seresnet200,
-        # seresnet200b,
+        # resnet34,
+        # resnet50,
+        # resnet50b,
+        # resnet101,
+        # resnet101b,
+        # resnet152,
+        # resnet152b,
+        # # resnet200,
+        # # resnet200b,
+        #
+        # # seresnet18,
+        # # seresnet34,
+        # seresnet50,
+        # # seresnet50b,
+        # seresnet101,
+        # # seresnet101b,
+        # seresnet152,
+        # # seresnet152b,
+        # # seresnet200,
+        # # seresnet200b,
     ]
 
     for model in models:
23 changes: 23 additions & 0 deletions tensorflow_/models/common.py
@@ -72,6 +72,29 @@ def conv2d(x,
         use_bias=use_bias,
         name=name)
 
+    # import numpy as np
+    # if tf_padding == 'same':
+    #     tf_padding = 'SAME'
+    # elif tf_padding == 'valid':
+    #     tf_padding = 'VALID'
+    # else:
+    #     raise ValueError('Invalid padding: ' + str(tf_padding))
+    # x = tf.nn.convolution(
+    #     input=x,
+    #     filter=tf.Variable(np.zeros(kernel_size + (int(x.shape[1]), out_channels), np.float32),
+    #                        dtype=tf.float32, name=name + "/kernel"),
+    #     dilation_rate=(1, 1),
+    #     strides=strides,
+    #     padding=tf_padding,
+    #     # name=name + "/kernel1",
+    #     data_format='NCHW')
+    # if use_bias:
+    #     x = tf.nn.bias_add(
+    #         value=x,
+    #         bias=tf.Variable(np.zeros((out_channels,), np.float32), dtype=tf.float32, name=name + "/bias"),
+    #         data_format='NCHW')
+    #         # name=name + "/bias1")
+
     return x
 
 
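The commented-out block above prototypes replacing the tf.layers.conv2d call with the lower-level tf.nn.convolution / tf.nn.bias_add pair driven by explicit kernel and bias variables. Below is a minimal, standalone sketch of that comparison, assuming TensorFlow 1.x and NCHW layout; the input shape, 3x3 kernel, and variable names are illustrative only and not taken from the repository. It only builds the graph and checks static shapes, so it runs without a GPU.

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=(1, 16, 32, 32), name="data")  # NCHW input (illustrative shape)
out_channels = 8

# High-level path: the layer wrapper currently used by the conv2d helper.
y_layers = tf.layers.conv2d(
    inputs=x,
    filters=out_channels,
    kernel_size=(3, 3),
    strides=(1, 1),
    padding="same",
    data_format="channels_first",
    use_bias=True,
    name="conv_layers")

# Low-level path sketched in the commented block: explicit kernel/bias variables.
kernel = tf.Variable(np.zeros((3, 3, 16, out_channels), np.float32), name="conv_nn/kernel")
bias = tf.Variable(np.zeros((out_channels,), np.float32), name="conv_nn/bias")
y_nn = tf.nn.convolution(
    input=x,
    filter=kernel,
    strides=(1, 1),
    dilation_rate=(1, 1),
    padding="SAME",
    data_format="NCHW")
y_nn = tf.nn.bias_add(y_nn, bias, data_format="NCHW")

# Both graphs expose the same static output shape.
assert y_layers.shape.as_list() == y_nn.shape.as_list() == [1, 8, 32, 32]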
2 changes: 2 additions & 0 deletions tensorflow_/models/model_store.py
@@ -220,6 +220,8 @@ def load_model(sess,
             sess.run(dst_params[src_key].assign(src_params[src_key]))
         elif not ignore_extra:
             raise Exception("The file `{}` is incompatible with the model".format(file_path))
+        else:
+            print("Key `{}` is ignored".format(src_key))
 
 
 def download_model(sess,
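The added else branch makes load_model report any source key it skips instead of dropping it silently when ignore_extra is set. A minimal sketch of this loading pattern follows, assuming TensorFlow 1.x; dst_params, src_params, and ignore_extra mirror the names in the diff, while the toy keys and shapes are invented for illustration.

import numpy as np
import tensorflow as tf

# Destination graph variables keyed by name, and source weights as loaded from a file.
dst_params = {"output/weight": tf.Variable(np.zeros((4, 4), np.float32), name="output/weight")}
src_params = {"output/weight": np.ones((4, 4), np.float32),
              "stem/extra_stat": np.zeros((1,), np.float32)}
ignore_extra = True

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for src_key in src_params.keys():
        if src_key in dst_params:
            # Copy the numpy weights into the matching TF variable.
            sess.run(dst_params[src_key].assign(src_params[src_key]))
        elif not ignore_extra:
            raise Exception("Source parameters are incompatible with the model")
        else:
            print("Key `{}` is ignored".format(src_key))
    # The matching key was copied; the extra key was only reported.
    assert np.allclose(sess.run(dst_params["output/weight"]), 1.0)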
59 changes: 30 additions & 29 deletions tensorflow_/models/resnet.py
@@ -999,40 +999,40 @@ def _test():
     pretrained = False
 
     models = [
-        resnet10,
-        resnet12,
-        resnet14,
-        resnet16,
-        resnet18_wd4,
-        resnet18_wd2,
-        resnet18_w3d4,
+        # resnet10,
+        # resnet12,
+        # resnet14,
+        # resnet16,
+        # resnet18_wd4,
+        # resnet18_wd2,
+        # resnet18_w3d4,
 
         resnet18,
-        resnet34,
-        resnet50,
-        resnet50b,
-        resnet101,
-        resnet101b,
-        resnet152,
-        resnet152b,
-        resnet200,
-        resnet200b,
-
-        seresnet18,
-        seresnet34,
-        seresnet50,
-        seresnet50b,
-        seresnet101,
-        seresnet101b,
-        seresnet152,
-        seresnet152b,
-        seresnet200,
-        seresnet200b,
+        # resnet34,
+        # resnet50,
+        # resnet50b,
+        # resnet101,
+        # resnet101b,
+        # resnet152,
+        # resnet152b,
+        # resnet200,
+        # resnet200b,
+        #
+        # seresnet18,
+        # seresnet34,
+        # seresnet50,
+        # seresnet50b,
+        # seresnet101,
+        # seresnet101b,
+        # seresnet152,
+        # seresnet152b,
+        # seresnet200,
+        # seresnet200b,
     ]
 
     for model in models:
 
-        net, _ = model(pretrained=pretrained)
+        net = model(pretrained=pretrained)
 
         x = tf.placeholder(
             dtype=tf.float32,
@@ -1072,7 +1072,8 @@ def _test():
 
         with tf.Session() as sess:
             sess.run(tf.global_variables_initializer())
-            y = sess.run(y_net, feed_dict={x: np.zeros((1, 3, 224, 224), np.float32)})
+            x_value = np.zeros((1, 3, 224, 224), np.float32)
+            y = sess.run(y_net, feed_dict={x: x_value})
             assert (y.shape == (1, 1000))
         tf.reset_default_graph()
 
