学会用tensorflow搭建简单的神经网络 2-1 用matplotlib画出变化图
本篇文的代码是针对于学会用tensorflow搭建简单的神经网络 2那篇文章的一个改进,重点在于让读者学会用matplotlib画出随机生成的一个测试数据,同时也画出了预测曲线,让程序员更加清楚地知道本训练中计算机学习的效果。 在阅读本文之前,建议先阅读《学会用tensorflow搭建简单的神经网络 2》这篇文章。
代码中红色部分为新添代码
1 #!/usr/bin/env python 2 # _*_ coding: utf-8 _*_ 3 import tensorflow as tf 4 import numpy as np 5 import matplotlib.pyplot as plt 6 #add_laye 7 def add_layer(inputs, in_size, out_size, activation_function=None): 8 # add one more layer and return the output of this layer 9 Weights = tf.Variable(tf.random_normal([in_size, out_size])) 10 biases = tf.Variable(tf.zeros([1, out_size]) + 0.1) 11 Wx_plus_b = tf.matmul(inputs, Weights) + biases 12 if activation_function is None: 13 outputs = Wx_plus_b 14 else: 15 outputs = activation_function(Wx_plus_b) 16 return outputs 17 # Make up some real data 18 x_data = np.linspace(-1,1,300)[:, np.newaxis] 19 noise = np.random.normal(0, 0.05, x_data.shape) 20 y_data = np.square(x_data) - 0.5 + noise 21 # define placeholder for inputs to network 22 xs = tf.placeholder(tf.float32, [None, 1]) 23 ys = tf.placeholder(tf.float32, [None, 1]) 24 # add hidden layer 25 layer1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu) 26 # add output layer 27 prediction = add_layer(layer1, 10, 1, activation_function=None) 28 # the error between prediciton and real data 29 loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), 30 reduction_indices=[1])) 31 #Select the optimizer to minimize the loss 32 train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) 33 # Important step---Initializes all variables 34 init=tf.initialize_all_variables() 35 # Sess.run ---it will start the operation 36 sess=tf.Session() 37 sess.run(init) 38 #Iterations of 201 times 39 fig=plt.figure() #生成一个图片框 40 ax=fig.add_subplot(1,1,1)#生成连续性的图片框 41 ax.scatter(x_data,y_data) #将训练数据以点的形式运行出来 42 plt.ion() #连续生成图片 43 plt.show() #打印图片 44 for i in range(1000): 45 sess.run(train_step, feed_dict={xs: x_data, ys: y_data}) 46 if i % 50 == 0: 47 # to see the step improvement 48 #print(sess.run(loss, feed_dict={xs: x_data, ys: y_data})) 49 try: 50 ax.lines.remove(lines[0]) # 因为线是在不断地生成,所以要把旧的线去掉,只显示新生成的,这样方便查看结果 51 except Exception: 52 pass 53 prediction_value=sess.run(prediction,feed_dict={xs: x_data, ys: y_data}) #run prediction 54 lines=ax.plot(x_data,prediction_value,'r-',lw=5) #把prediction_value的数据plot上去,以红色的线,宽度为5 55 plt.pause(0.2) #每次新生成的线暂停0.2s
结果:截取其中一张图片

浙公网安备 33010602011771号