1.导入包
import numpy as np
import theano
import theano.tensor as T
import matplotlib.pyplot as plt
2.定义Layer
class Layer(object):
    """A fully connected layer computing ``activation(inputs . W + b)``.

    Args:
        inputs: Symbolic expression fed into the layer (passed straight to
            ``T.dot``).
        in_size: Number of input features; the row count of ``W``.
        out_size: Number of output units; the column count of ``W`` and the
            length of ``b``.
        activation_function: Optional callable applied to the affine output.
            When ``None`` the layer stays linear.

    Attributes are set in ``__init__``: ``W`` and ``b`` (shared variables),
    ``Wx_plus_b`` (the affine expression), ``activation_function`` and
    ``outputs`` (the final symbolic output of the layer).
    """

    def __init__(self, inputs, in_size, out_size, activation_function=None):
        # Weights drawn from N(0, 1); biases start at 0.1.
        self.W = theano.shared(np.random.normal(0, 1, (in_size, out_size)))
        self.b = theano.shared(np.zeros((out_size, )) + 0.1)
        self.Wx_plus_b = T.dot(inputs, self.W) + self.b
        self.activation_function = activation_function
        if activation_function is None:
            self.outputs = self.Wx_plus_b
        else:
            self.outputs = self.activation_function(self.Wx_plus_b)
3.Make up some fake data虚拟数据
# Synthetic regression data: 300 evenly spaced x values in [-1, 1], stored
# as a column via np.newaxis, with targets y = x^2 - 0.5 plus Gaussian noise.
x_data = np.linspace(-1, 1, 300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise  # y = x^2 - 0.5
4.show the fake data可视化虚拟数据
# Scatter-plot the synthetic samples so the target curve can be inspected.
plt.scatter(x_data, y_data)
plt.show()
5.Determine the input dtypes
# Symbolic matrix placeholders for the network input and the target.
x = T.dmatrix("x")
y = T.dmatrix("y")
6.网络搭建
# Hidden layer: 1 input feature -> 10 units with ReLU activation.
l1 = Layer(x, 1, 10, T.nnet.relu)
# Output layer: 10 hidden units -> 1 output, no activation (linear).
l2 = Layer(l1.outputs, 10, 1, None)
7.计算Loss
# Compute the loss: mean squared difference between output and target.
cost = T.mean(T.square(l2.outputs - y))
# Compute the gradients of the cost with respect to each weight and bias.
gW1, gb1, gW2, gb2 = T.grad(cost, [l1.W, l1.b, l2.W, l2.b])
8.梯度下降
learning_rate = 0.05  # learning rate, kept below 1

# Compiled training step: takes (x, y), returns [cost], and on every call
# applies one gradient-descent update to each shared parameter.
train = theano.function(
    inputs=[x, y],
    outputs=[cost],
    updates=[
        (l1.W, l1.W - learning_rate * gW1),
        (l1.b, l1.b - learning_rate * gb1),
        (l2.W, l2.W - learning_rate * gW2),
        (l2.b, l2.b - learning_rate * gb2),
    ],
)
9.预测
# Compiled forward pass: feed x and get the network's prediction for y.
predict = theano.function(inputs=[x], outputs=l2.outputs)
10.训练
# Run 1000 training steps, printing the returned cost every 50 steps.
for i in range(1000):
    # training
    err = train(x_data, y_data)  # feed the full data set
    if i % 50 == 0:
        print(err)
11.可视化结果
# Plot the real data on a single-axes figure.
fig = plt.figure()
# The axes must exist before anything is drawn on them; the original called
# ax.scatter before ax was assigned, which raises NameError.
ax = fig.add_subplot(1, 1, 1)  # 1 row, 1 column, first subplot
ax.scatter(x_data, y_data)  # draw the samples as a scatter plot
# Turn interactive mode on before showing, so plt.show() does not block and
# the script can continue on to the training loop. The earlier blocking
# plt.show() that preceded plt.ion() is dropped for the same reason.
plt.ion()
plt.show()
# Train for 1000 steps and redraw the fitted curve every 50 steps.
lines = None  # the previously drawn prediction curve; nothing drawn yet
for i in range(1000):
    # training
    err = train(x_data, y_data)  # feed the full data set
    if i % 50 == 0:
        # To visualize the result and improvement, erase the previous curve
        # before drawing the new one. Artist.remove() is used because
        # ax.lines.remove(...) fails on current Matplotlib, and the old
        # blanket `except Exception: pass` hid that failure, leaving every
        # old curve on screen.
        if lines:
            lines[0].remove()
        prediction_value = predict(x_data)
        # plot the prediction: 'r-' is a solid red line, lw is the width
        lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
        plt.pause(.5)  # pause 0.5 s so the figure can refresh
浙公网安备 33010602011771号