学会用tensorflow搭建简单的神经网络 2-1 用matplotlib画出变化图

本篇文的代码是针对于学会用tensorflow搭建简单的神经网络 2那篇文章的一个改进,重点在于让读者学会用matplotlib画出随机生成的一个测试数据,同时也画出了预测曲线,让程序员更加清楚地知道本训练中计算机学习的效果。  在阅读本文之前,建议先阅读《学会用tensorflow搭建简单的神经网络 2》这篇文章。

代码中红色部分为新添代码

 1 #!/usr/bin/env python
 2 # _*_ coding: utf-8 _*_
 3 import tensorflow as tf
 4 import numpy as np
 5 import matplotlib.pyplot as plt
 6 #add_laye
 7 def add_layer(inputs, in_size, out_size, activation_function=None):
 8     # add one more layer and return the output of this layer
 9     Weights = tf.Variable(tf.random_normal([in_size, out_size]))
10     biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
11     Wx_plus_b = tf.matmul(inputs, Weights) + biases
12     if activation_function is None:
13         outputs = Wx_plus_b
14     else:
15         outputs = activation_function(Wx_plus_b)
16     return outputs
17 # Make up some real data
18 x_data = np.linspace(-1,1,300)[:, np.newaxis]
19 noise = np.random.normal(0, 0.05, x_data.shape)
20 y_data = np.square(x_data) - 0.5 + noise
21 # define placeholder for inputs to network
22 xs = tf.placeholder(tf.float32, [None, 1])
23 ys = tf.placeholder(tf.float32, [None, 1])
24 # add hidden layer
25 layer1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
26 # add output layer
27 prediction = add_layer(layer1, 10, 1, activation_function=None)
28 # the error between prediciton and real data
29 loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
30                      reduction_indices=[1]))
31 #Select the optimizer to minimize the loss
32 train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
33 # Important step---Initializes all variables
34 init=tf.initialize_all_variables()
35 # Sess.run ---it will start the operation
36 sess=tf.Session()
37 sess.run(init)
38 #Iterations of 201 times
39 fig=plt.figure()  #生成一个图片框
40 ax=fig.add_subplot(1,1,1)#生成连续性的图片框
41 ax.scatter(x_data,y_data) #将训练数据以点的形式运行出来
42 plt.ion()  #连续生成图片
43 plt.show() #打印图片
44 for i in range(1000):
45     sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
46     if i % 50 == 0:
47         # to see the step improvement
48         #print(sess.run(loss, feed_dict={xs: x_data, ys: y_data}))
49         try:
50             ax.lines.remove(lines[0])  # 因为线是在不断地生成,所以要把旧的线去掉,只显示新生成的,这样方便查看结果
51         except Exception:
52             pass
53         prediction_value=sess.run(prediction,feed_dict={xs: x_data, ys: y_data})  #run prediction
54         lines=ax.plot(x_data,prediction_value,'r-',lw=5)  #把prediction_value的数据plot上去,以红色的线,宽度为5
55         plt.pause(0.2) #每次新生成的线暂停0.2s

结果:截取其中一张图片

 

posted on 2017-08-08 10:07  可可洁儿  阅读(370)  评论(0)    收藏  举报

导航