1 import numpy
2 #激活函数库
3 import scipy.special
4
5 import matplotlib.pyplot
6
# neural network class definition
class neutralNetwork:
    """Three-layer (input, hidden, output) feed-forward network.

    Both weight layers use the logistic sigmoid (``scipy.special.expit``)
    as activation and are trained one sample at a time by gradient
    descent on the output error.

    Attributes:
        inodes, hnodes, onodes: number of nodes in each layer.
        win: input-to-hidden weight matrix, shape ``(hnodes, inodes)``.
        who: hidden-to-output weight matrix, shape ``(onodes, hnodes)``.
        activation_function: callable applied element-wise to each layer.
        lr: learning rate used to scale every weight update.
    """

    def __init__(self, inputnodes, hiddennodes, outputnodes, learningrate):
        """Store the layer sizes and draw the initial random weights.

        Args:
            inputnodes: number of input nodes.
            hiddennodes: number of hidden nodes.
            outputnodes: number of output nodes.
            learningrate: step size applied to every weight update.
        """
        # Layer sizes.
        self.inodes = inputnodes
        self.hnodes = hiddennodes
        self.onodes = outputnodes

        # Initial weights are drawn from a zero-mean normal distribution.
        # NOTE(review): the standard deviation is derived from the size of
        # the *receiving* layer (hnodes / onodes); the common
        # 1/sqrt(fan-in) heuristic would use the sending layer instead --
        # confirm this is intended before changing it.
        self.win = numpy.random.normal(
            0.0, pow(self.hnodes, -0.5), (self.hnodes, self.inodes)
        )
        self.who = numpy.random.normal(
            0.0, pow(self.onodes, -0.5), (self.onodes, self.hnodes)
        )

        # Activation function: the logistic sigmoid.
        self.activation_function = scipy.special.expit

        # Learning rate.
        self.lr = learningrate

    def _forward(self, inputs_list):
        """Run one forward pass; shared by ``train`` and ``query``.

        Returns:
            A tuple ``(inputs, hidden_outputs, final_outputs)`` where
            ``inputs`` is ``inputs_list`` converted to a 2-D array and
            transposed (a column vector for a flat input sequence).
        """
        inputs = numpy.array(inputs_list, ndmin=2).T
        hidden_outputs = self.activation_function(numpy.dot(self.win, inputs))
        final_outputs = self.activation_function(
            numpy.dot(self.who, hidden_outputs)
        )
        return inputs, hidden_outputs, final_outputs

    def train(self, inputs_list, targets_list):
        """Train on one sample and update both weight matrices in place.

        Args:
            inputs_list: input values, one per input node.
            targets_list: desired output values, one per output node.
        """
        inputs, hidden_outputs, final_outputs = self._forward(inputs_list)
        targets = numpy.array(targets_list, ndmin=2).T

        output_errors = targets - final_outputs
        # Back-propagate the output error through the hidden-to-output
        # weights; this must use ``who`` as it was *before* the update below.
        hidden_errors = numpy.dot(self.who.T, output_errors)

        # error * sigmoid'(x), with sigmoid'(x) = out * (1 - out).
        output_delta = output_errors * final_outputs * (1.0 - final_outputs)
        hidden_delta = hidden_errors * hidden_outputs * (1.0 - hidden_outputs)

        self.who += self.lr * numpy.dot(output_delta, numpy.transpose(hidden_outputs))
        self.win += self.lr * numpy.dot(hidden_delta, numpy.transpose(inputs))

    def query(self, inputs_list):
        """Return the network's output for ``inputs_list``.

        For a flat sequence of ``inodes`` values the result is a 2-D array
        of shape ``(onodes, 1)``.
        """
        _, _, final_outputs = self._forward(inputs_list)
        return final_outputs
57
# --- Network dimensions and hyperparameters ---

# One input node per pixel of a 28*28 image.
input_nodes = 784
# Deliberately smaller than the input layer, which forces the network to
# summarise the main features of its input.
hidden_nodes = 100
# There are ten handwritten digits, hence ten output nodes.
output_nodes = 10

learning_rate = 0.3

# Train for 2 epochs (too many would overfit).
epoches = 2

n = neutralNetwork(
    inputnodes=input_nodes,
    hiddennodes=hidden_nodes,
    outputnodes=output_nodes,
    learningrate=learning_rate,
)
71
# Load the MNIST training set. Each CSV row is read as: the label first,
# followed by the pixel values.
# `with` guarantees the file is closed even if reading fails.
with open("mnist_train.csv", 'r') as training_data_file:
    training_data_list = training_data_file.readlines()

# Train the network on the whole training set, once per epoch.
for e in range(epoches):
    for record in training_data_list:
        # Skip blank lines (e.g. a trailing newline at the end of the file),
        # which would otherwise fail when parsing the label.
        if not record.strip():
            continue
        all_values = record.split(',')

        # Rescale the pixels from 0..255 to 0.01..1.00. Inputs must not be
        # exactly 0, since a zero input cannot contribute to weight updates.
        # numpy.asfarray was removed in NumPy 2.0; asarray(dtype=float) is
        # the documented replacement.
        inputs = (numpy.asarray(all_values[1:], dtype=float) / 255.0 * 0.99) + 0.01

        # Target vector: 0.01 everywhere, 0.99 at the label's index. Exact 0
        # and 1 are avoided because the sigmoid can never reach them, and
        # chasing them would drive the network into saturation.
        targets = numpy.zeros(output_nodes) + 0.01
        targets[int(all_values[0])] = 0.99
        n.train(inputs, targets)
91
# Evaluate the network on the MNIST test set.
# `with` guarantees the file is closed even if reading fails.
with open("mnist_test.csv", 'r') as test_data_file:
    test_data_list = test_data_file.readlines()

# One entry per test record: 1 if the network was right, 0 otherwise.
scorecard = []

for record in test_data_list:
    # Skip blank lines (e.g. a trailing newline at the end of the file),
    # which would otherwise fail when parsing the label.
    if not record.strip():
        continue
    all_values = record.split(',')
    correct_label = int(all_values[0])
    print(correct_label, "correct label")

    # Parse the pixel values once and reuse them for the image and the input.
    # numpy.asfarray was removed in NumPy 2.0; asarray(dtype=float) is the
    # documented replacement.
    pixels = numpy.asarray(all_values[1:], dtype=float)
    image_array = pixels.reshape((28, 28))
    # NOTE(review): this draws every test image onto the current axes and
    # never calls show(); kept to preserve behaviour, but consider removing
    # it or limiting it to a single sample.
    matplotlib.pyplot.imshow(image_array, cmap='Greys', interpolation='None')

    # Same rescaling as used for training: 0..255 -> 0.01..1.00.
    inputs = (pixels / 255.0 * 0.99) + 0.01

    outputs = n.query(inputs)

    # The index of the largest output is the network's answer.
    label = numpy.argmax(outputs)
    print(label, "network's answer:")

    scorecard.append(1 if label == correct_label else 0)

scorecard_array = numpy.asarray(scorecard, dtype=float)
# Guard against an empty test set, which would otherwise divide by zero.
if scorecard_array.size:
    print("performance=", scorecard_array.sum() / scorecard_array.size)
else:
    print("performance=", float("nan"))