diff --git a/gluon/gluoncv2/models/resnet.py b/gluon/gluoncv2/models/resnet.py index dd8d0ce03..a77bc36d5 100644 --- a/gluon/gluoncv2/models/resnet.py +++ b/gluon/gluoncv2/models/resnet.py @@ -959,38 +959,38 @@ def _test(): import numpy as np import mxnet as mx - pretrained = True + pretrained = False models = [ - resnet10, - resnet12, - resnet14, - resnet16, - resnet18_wd4, - resnet18_wd2, - resnet18_w3d4, + # resnet10, + # resnet12, + # resnet14, + # resnet16, + # resnet18_wd4, + # resnet18_wd2, + # resnet18_w3d4, resnet18, - resnet34, - resnet50, - resnet50b, - resnet101, - resnet101b, - resnet152, - resnet152b, - # resnet200, - # resnet200b, - - # seresnet18, - # seresnet34, - seresnet50, - # seresnet50b, - seresnet101, - # seresnet101b, - seresnet152, - # seresnet152b, - # seresnet200, - # seresnet200b, + # resnet34, + # resnet50, + # resnet50b, + # resnet101, + # resnet101b, + # resnet152, + # resnet152b, + # # resnet200, + # # resnet200b, + # + # # seresnet18, + # # seresnet34, + # seresnet50, + # # seresnet50b, + # seresnet101, + # # seresnet101b, + # seresnet152, + # # seresnet152b, + # # seresnet200, + # # seresnet200b, ] for model in models: diff --git a/tensorflow_/models/common.py b/tensorflow_/models/common.py index ab9b7d9d5..f232c66f5 100644 --- a/tensorflow_/models/common.py +++ b/tensorflow_/models/common.py @@ -72,6 +72,29 @@ def conv2d(x, use_bias=use_bias, name=name) + # import numpy as np + # if tf_padding == 'same': + # tf_padding = 'SAME' + # elif tf_padding == 'valid': + # tf_padding = 'VALID' + # else: + # raise ValueError('Invalid padding: ' + str(tf_padding)) + # x = tf.nn.convolution( + # input=x, + # filter=tf.Variable(np.zeros(kernel_size + (int(x.shape[1]), out_channels), np.float32), + # dtype=tf.float32, name=name + "/kernel"), + # dilation_rate=(1, 1), + # strides=strides, + # padding=tf_padding, + # # name=name + "/kernel1", + # data_format='NCHW') + # if use_bias: + # x = tf.nn.bias_add( + # value=x, + # bias=tf.Variable(np.zeros((out_channels,), np.float32), dtype=tf.float32, name=name + "/bias"), + # data_format='NCHW') + # # name=name + "/bias1") + return x diff --git a/tensorflow_/models/model_store.py b/tensorflow_/models/model_store.py index 600c4ab8d..f3eede2e3 100644 --- a/tensorflow_/models/model_store.py +++ b/tensorflow_/models/model_store.py @@ -220,6 +220,8 @@ def load_model(sess, sess.run(dst_params[src_key].assign(src_params[src_key])) elif not ignore_extra: raise Exception("The file `{}` is incompatible with the model".format(file_path)) + else: + print("Key `{}` is ignored".format(src_key)) def download_model(sess, diff --git a/tensorflow_/models/resnet.py b/tensorflow_/models/resnet.py index a1189c212..5a4982202 100644 --- a/tensorflow_/models/resnet.py +++ b/tensorflow_/models/resnet.py @@ -999,40 +999,40 @@ def _test(): pretrained = False models = [ - resnet10, - resnet12, - resnet14, - resnet16, - resnet18_wd4, - resnet18_wd2, - resnet18_w3d4, + # resnet10, + # resnet12, + # resnet14, + # resnet16, + # resnet18_wd4, + # resnet18_wd2, + # resnet18_w3d4, resnet18, - resnet34, - resnet50, - resnet50b, - resnet101, - resnet101b, - resnet152, - resnet152b, - resnet200, - resnet200b, - - seresnet18, - seresnet34, - seresnet50, - seresnet50b, - seresnet101, - seresnet101b, - seresnet152, - seresnet152b, - seresnet200, - seresnet200b, + # resnet34, + # resnet50, + # resnet50b, + # resnet101, + # resnet101b, + # resnet152, + # resnet152b, + # resnet200, + # resnet200b, + # + # seresnet18, + # seresnet34, + # seresnet50, + # seresnet50b, + # seresnet101, + # seresnet101b, + # seresnet152, + # seresnet152b, + # seresnet200, + # seresnet200b, ] for model in models: - net, _ = model(pretrained=pretrained) + net = model(pretrained=pretrained) x = tf.placeholder( dtype=tf.float32, @@ -1072,7 +1072,8 @@ def _test(): with tf.Session() as sess: sess.run(tf.global_variables_initializer()) - y = sess.run(y_net, feed_dict={x: np.zeros((1, 3, 224, 224), np.float32)}) + x_value = np.zeros((1, 3, 224, 224), np.float32) + y = sess.run(y_net, feed_dict={x: x_value}) assert (y.shape == (1, 1000)) tf.reset_default_graph()