自建神经网络与结果可视化
来自莫烦Python(Morvan Zhou)的TensorFlow教学视频,主要内容为添加神经层和训练结果可视化。
YouTube地址:https://www.youtube.com/watch?v=nhn8B0pM9ls&list=PLXO45tsB95cKI5AIlf5TxxFPzb-0zeVZ8&index=16
# Build a tiny one-hidden-layer network with TensorFlow 1.x, fit it to
# y = x**2 - 0.5 + noise on [-1, 1], and redraw the prediction every 50 steps.
import os

# Set before importing TensorFlow so it also covers messages logged while the
# library is loading.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf


def add_layer(inputs, in_size, out_size, activation_function=None):
    """Add a fully connected layer to the default graph and return its output.

    Args:
        inputs: 2-D tensor or array whose last dimension is ``in_size``
            (required by the matmul with an ``(in_size, out_size)`` matrix).
            It is cast to float32, so float64 NumPy arrays are accepted too.
        in_size: number of input features.
        out_size: number of units in this layer.
        activation_function: callable applied to ``inputs @ W + b``, or None
            for a purely linear layer.

    Returns:
        Tensor with ``out_size`` columns.
    """
    # Weight matrix drawn from a standard normal distribution.
    weights = tf.Variable(tf.random_normal([in_size, out_size]))
    # Bias row vector initialised to 0.1.
    biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
    wx_plus_b = tf.matmul(tf.cast(inputs, tf.float32), weights) + biases
    if activation_function is None:
        return wx_plus_b
    return activation_function(wx_plus_b)


# When the script is re-run inside the same interactive interpreter
# (IPython/Jupyter), a session from the previous run may still be bound to
# `sess`; close it before starting a new one.
if 'sess' in locals() and sess is not None:
    print('Close interactive session')
    sess.close()

x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise

xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])

# Build the network on the placeholder (not on the constant x_data) so the
# values supplied through feed_dict are what the network actually consumes.
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
prediction = add_layer(l1, 10, 1, activation_function=None)

# Squared error summed per sample (axis 1), then averaged over samples.
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction), axis=1))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)

init = tf.global_variables_initializer()

fig = plt.figure()
ax = fig.add_subplot(1, 1, 1)  # 1x1 grid, first subplot
ax.scatter(x_data, y_data)
plt.ion()  # interactive mode: plt.show() returns instead of blocking
plt.show()

# The context manager closes the session even if training is interrupted.
with tf.Session() as sess:
    sess.run(init)
    fit_line = None
    for i in range(1000):
        sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
        if i % 50 == 0:
            # Replace the previous prediction curve rather than stacking curves.
            if fit_line is not None:
                fit_line.remove()
            prediction_value = sess.run(prediction, feed_dict={xs: x_data})
            fit_line, = ax.plot(x_data, prediction_value, 'r-', lw=5)
            plt.pause(0.1)
结果显示:

需要特别注意的是其中的三行代码:
# Close a session left over from a previous run in the same interpreter
# (e.g. when re-running the script in IPython/Jupyter). The name tested must be
# the variable the script actually binds its session to, `sess`; testing
# `session` never matches and turns this check into a no-op.
if 'sess' in locals() and sess is not None:
    print('Close interactive session')
    sess.close()
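加入这三行代码是因为运行时出现了“InternalError (see above for traceback): Blas GEMM launch failed : a.shape=(300, 1), b.shape=(1, 10), m=300, n=10, k=1”的错误。Stack Overflow上给出的解答大意是:session已在其他进程中存在,所以要先关闭已有的session。需要注意的是,这段检查只能关闭当前Python解释器中遗留的session(例如在IPython/Jupyter中重复运行脚本时),并且检查的变量名必须与脚本中实际保存session的变量名(`sess`)一致。被其他进程占用的session,只能在该进程中关闭,或者结束该进程来释放。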

浙公网安备 33010602011771号