1 import tensorflow as tf;
2 from tensorflow.examples.tutorials.mnist import input_data
3
# --- Network topology ---
input_nodes = 784      # 28x28 flattened MNIST image
output_nodes = 10      # digit classes 0-9
layer1_nodes = 500     # hidden-layer width

# --- Learning-rate schedule (exponential decay) ---
learning_rate_base = 0.8
learning_decay = 0.99
decay_step = 100

# --- Moving average, regularization, and training-loop settings ---
moving_average__decay = 0.99
regularizer_rate = 0.0001
train_step = 30000
batch_size = 100
19
20
def inference(tensor1, weight1, bias1, weight2, bias2, average_class=None):
    """Forward pass of the one-hidden-layer network.

    Args:
        tensor1: input batch of shape [None, input_nodes].
        weight1, bias1: hidden-layer parameters.
        weight2, bias2: output-layer parameters.
        average_class: optional tf.train.ExponentialMovingAverage; when given,
            the shadow (moving-average) copies of the parameters are used
            instead of the raw variables.

    Returns:
        Unnormalized logits of shape [None, output_nodes].
    """
    # FIX: compare with `is None` (identity), not `== None` — `==` is
    # unidiomatic and can be overloaded by tensor types.
    if average_class is None:
        layer1 = tf.nn.relu(tf.matmul(tensor1, weight1) + bias1)
        return tf.matmul(layer1, weight2) + bias2
    layer1 = tf.nn.relu(tf.matmul(tensor1, average_class.average(weight1))
                        + average_class.average(bias1))
    return (tf.matmul(layer1, average_class.average(weight2))
            + average_class.average(bias2))
28
def get_weight(shape):
    """Create a trainable weight variable and register its L2 penalty.

    Args:
        shape: list/tuple giving the variable's shape.

    Returns:
        A tf.Variable initialized from a truncated normal (stddev=0.1).
    """
    # BUGFIX: the original passed tf.float32 as the second *positional*
    # argument of tf.Variable, which is `trainable`, not `dtype`. It only
    # worked by accident because tf.float32 is truthy (== default True).
    weight = tf.Variable(tf.truncated_normal(shape=shape, stddev=0.1),
                         dtype=tf.float32)
    # Collect this variable's L2 regularization term; train() sums the
    # 'losses' collection into the total loss.
    tf.add_to_collection('losses',
                         tf.contrib.layers.l2_regularizer(regularizer_rate)(weight))
    return weight
33
def get_bias(shape):
    """Return a zero-initialized bias variable of the given shape."""
    initial = tf.zeros(shape)
    return tf.Variable(initial)
36
def train(mnist):
    """Build the two-layer MNIST network and run the training loop.

    Trains with SGD + exponentially decayed learning rate, maintains
    moving-average shadow copies of the weights, and evaluates accuracy
    with the averaged model.

    Args:
        mnist: datasets object with .train/.validation/.test splits, as
            returned by input_data.read_data_sets.
    """
    # Placeholders for flattened images and one-hot labels.
    train_x = tf.placeholder(tf.float32, shape=[None, input_nodes], name='train_x')
    train_y = tf.placeholder(tf.float32, shape=[None, output_nodes], name='train_y')

    weight1 = get_weight([input_nodes, layer1_nodes])
    bias1 = get_bias([layer1_nodes])
    weight2 = get_weight([layer1_nodes, output_nodes])
    bias2 = get_bias([output_nodes])

    # Forward pass with the raw (non-averaged) weights — used for the loss.
    results = inference(train_x, weight1, bias1, weight2, bias2, None)

    # Exponentially decayed learning rate; global_step is advanced by minimize().
    global_step = tf.Variable(0, trainable=False)
    learning_rate = tf.train.exponential_decay(
        learning_rate_base, global_step,
        mnist.train.num_examples / batch_size, learning_decay, staircase=True)

    # Loss = cross entropy + L2 penalties collected by get_weight().
    # sparse_* takes class indices, hence the argmax over the one-hot labels.
    ce = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=results, labels=tf.argmax(train_y, 1)))
    loss = ce + tf.add_n(tf.get_collection('losses'))
    tf.summary.scalar('lost', loss)

    optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Maintain shadow (moving-average) copies of all trainable variables and
    # group the shadow update with the optimizer step into one train op.
    ema = tf.train.ExponentialMovingAverage(moving_average__decay, global_step)
    maintain_average_op = ema.apply(tf.trainable_variables())
    with tf.control_dependencies([optimizer, maintain_average_op]):
        train_op = tf.no_op(name='train')

    # Accuracy measured with the moving-average weights.
    average_y = inference(train_x, weight1, bias1, weight2, bias2, ema)
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(train_y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

    with tf.Session() as sess:
        tf.global_variables_initializer().run()

        validate_feed = {train_x: mnist.validation.images,
                         train_y: mnist.validation.labels}
        test_feed = {train_x: mnist.test.images,
                     train_y: mnist.test.labels}

        # Summaries for TensorBoard.
        merged_summary_op = tf.summary.merge_all()
        summary_writer = tf.summary.FileWriter('./log/mnist_with_summaries',
                                               sess.graph)

        for i in range(train_step):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validate_feed)
                # FIX: message typo "aaverage" -> "average".
                print('After %d training steps,using average model is %g '
                      % (i, validate_acc))

            xt, yt = mnist.train.next_batch(batch_size)
            # FIX: run the train op and the summaries in a single graph
            # execution; the original issued two sess.run calls per step,
            # evaluating the forward pass twice.
            _, summary_str = sess.run([train_op, merged_summary_op],
                                      feed_dict={train_x: xt, train_y: yt})
            summary_writer.add_summary(summary_str, i)

        test_acc = sess.run(accuracy, feed_dict=test_feed)
        print('accuracy is %g' % (test_acc))
def main():
    """Load the MNIST dataset (downloading it if needed) and train."""
    dataset = input_data.read_data_sets('./MNIST_data', one_hot=True)
    train(dataset)


if __name__ == '__main__':
    main()