forked from ganyc717/LeNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lenet.py
42 lines (34 loc) · 1.96 KB
/
lenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import tensorflow as tf
import tensorflow.contrib.slim as slim
import config as cfg
class Lenet:
    """LeNet-style 10-class image classifier built as a TF1 graph with TF-Slim.

    Constructing an instance builds the whole graph: input placeholders, a
    training tower (dropout enabled) and a prediction tower (dropout
    disabled) that share variables under the ``"Lenet"`` variable scope, the
    loss, and an Adam training op.

    Attributes:
        raw_input_image: float32 placeholder of shape ``[None, 784]``.
        input_images: ``raw_input_image`` reshaped to ``[-1, 28, 28, 1]``.
        raw_input_label: float placeholder of shape ``[None, 10]``.
        input_labels: ``raw_input_label`` cast to int32.
        dropout: keep probability for dropout, read from ``cfg.KEEP_PROB``.
        train_digits: logits from the tower built with dropout enabled.
        pred_digits: logits from the tower built with dropout disabled.
        prediction: argmax of ``pred_digits`` along axis 1.
        correct_prediction: element-wise equality of ``prediction`` and the
            argmax of ``input_labels``.
        train_accuracy: mean of ``correct_prediction`` cast to float.
        cross_entropy: softmax cross-entropy of ``train_digits`` against
            ``input_labels`` (no regularization term).
        loss: ``cross_entropy`` plus the L2 weight-regularization losses
            registered by the conv layers; this is what ``train_op`` minimizes.
        lr: learning rate, read from ``cfg.LEARNING_RATE``.
        train_op: Adam optimizer step minimizing ``loss``.
    """

    def __init__(self):
        # Images are fed flat and reshaped to NHWC for the conv layers.
        self.raw_input_image = tf.placeholder(tf.float32, [None, 784])
        self.input_images = tf.reshape(self.raw_input_image, [-1, 28, 28, 1])

        # Labels are fed as 10-wide rows (presumably one-hot; argmax over
        # axis 1 is used below to recover the class index).
        self.raw_input_label = tf.placeholder("float", [None, 10])
        self.input_labels = tf.cast(self.raw_input_label, tf.int32)

        self.dropout = cfg.KEEP_PROB

        # Build the network twice over the same variables: once with dropout
        # active for training, once with dropout disabled for inference.
        with tf.variable_scope("Lenet") as scope:
            self.train_digits = self.construct_net(True)
            scope.reuse_variables()
            self.pred_digits = self.construct_net(False)

        self.prediction = tf.argmax(self.pred_digits, 1)
        self.correct_prediction = tf.equal(
            self.prediction, tf.argmax(self.input_labels, 1))
        self.train_accuracy = tf.reduce_mean(
            tf.cast(self.correct_prediction, "float"))

        self.cross_entropy = slim.losses.softmax_cross_entropy(
            self.train_digits, self.input_labels)

        # The conv layers register L2 penalties in the REGULARIZATION_LOSSES
        # collection; they only influence training if they are added to the
        # loss that is actually minimized.
        regularization_losses = tf.get_collection(
            tf.GraphKeys.REGULARIZATION_LOSSES, scope=scope.name)
        if regularization_losses:
            self.loss = self.cross_entropy + tf.add_n(regularization_losses)
        else:
            self.loss = self.cross_entropy

        self.lr = cfg.LEARNING_RATE
        self.train_op = tf.train.AdamOptimizer(self.lr).minimize(self.loss)

    def construct_net(self, is_trained=True):
        """Build the conv/pool/fc stack on ``self.input_images``.

        Must be called inside a variable scope; a second call under the same
        scope with reuse enabled shares the weights of the first.

        Args:
            is_trained: forwarded to dropout as ``is_training``; ``True``
                applies dropout with keep probability ``self.dropout``,
                ``False`` makes the dropout layer a pass-through.

        Returns:
            Tensor of 10 unnormalized class scores (logits) per input image.
        """
        # Defaults below apply to slim.conv2d only (not the fully connected
        # layers): VALID padding, small truncated-normal init, L2 weight decay.
        with slim.arg_scope(
                [slim.conv2d],
                padding='VALID',
                weights_initializer=tf.truncated_normal_initializer(stddev=0.01),
                weights_regularizer=slim.l2_regularizer(0.0005)):
            # conv1 overrides padding to SAME so the 28x28 extent is kept.
            net = slim.conv2d(self.input_images, 6, [5, 5], 1,
                              padding='SAME', scope='conv1')
            net = slim.max_pool2d(net, [2, 2], scope='pool2')
            net = slim.conv2d(net, 16, [5, 5], 1, scope='conv3')
            net = slim.max_pool2d(net, [2, 2], scope='pool4')
            net = slim.conv2d(net, 120, [5, 5], 1, scope='conv5')
            net = slim.flatten(net, scope='flat6')
            net = slim.fully_connected(net, 84, scope='fc7')
            net = slim.dropout(net, self.dropout, is_training=is_trained,
                               scope='dropout8')
            # Output layer must be linear: slim.fully_connected applies ReLU
            # by default, which would clamp the logits to be non-negative
            # before softmax/argmax.
            digits = slim.fully_connected(net, 10, activation_fn=None,
                                          scope='fc9')
        return digits