1 # 手写神经网络——mnist手写数字数据集
2 import numpy as np
3 # import torch
4 import torchvision
5 import torchvision.transforms as transforms
6 # from torch.utils.data import DataLoader
7 # import cv2
8
# Input layer: 784 nodes (one per pixel of a 28*28 image)
# Hidden layers: three hidden layers with 20 nodes each
# Output layer: 10 nodes (one per digit class)
class BP:
    """Fully connected 784-20-20-20-10 sigmoid network trained by backpropagation.

    Every layer is a plain matrix product followed by the logistic sigmoid;
    there are no bias terms. Weights start uniformly distributed in [-1, 1).
    """

    def __init__(self):
        # Activation placeholders sized for a 100-sample mini-batch. They are
        # replaced wholesale on every forward pass.
        self.input = np.zeros((100, 784))
        self.hidden_layer_1 = np.zeros((100, 20))
        self.hidden_layer_2 = np.zeros((100, 20))
        self.hidden_layer_3 = np.zeros((100, 20))
        self.output_layer = np.zeros((100, 10))

        def uniform(shape):
            # Stretch numpy's [0, 1) samples onto [-1, 1).
            return 2 * np.random.random(shape) - 1

        # Drawn in this order so the global RNG stream is consumed w1 -> w4.
        self.w1 = uniform((784, 20))
        self.w2 = uniform((20, 20))
        self.w3 = uniform((20, 20))
        self.w4 = uniform((20, 10))

        self.error = np.zeros(10)
        self.learning_rate = 0.1

    def sigmoid(self, x):
        """Element-wise logistic function 1 / (1 + e^-x)."""
        denominator = 1 + np.exp(-x)
        return 1 / denominator

    def sigmoid_deri(self, x):
        """Sigmoid derivative written in terms of the sigmoid's *output*.

        ``x`` is expected to already be ``sigmoid(z)``; returns s * (1 - s).
        """
        return (1 - x) * x

    def forward_prop(self, data, label):
        """Run one forward pass and store the error against ``label``.

        Args:
            data: network input; in this file either a (batch, 784) training
                batch or a single (784,) test vector.
            label: one-hot targets with the same leading shape as the output.

        Returns:
            The output-layer activations (also kept in ``self.output_layer``).
        """
        self.input = data
        h1 = self.sigmoid(np.dot(data, self.w1))
        h2 = self.sigmoid(np.dot(h1, self.w2))
        h3 = self.sigmoid(np.dot(h2, self.w3))
        out = self.sigmoid(np.dot(h3, self.w4))

        # Cache activations; backward_prop reads them.
        self.hidden_layer_1, self.hidden_layer_2, self.hidden_layer_3 = h1, h2, h3
        self.output_layer = out
        self.error = label - out
        return out

    def backward_prop(self):
        """Take one gradient step using the state cached by ``forward_prop``.

        All layer deltas are computed before any weight changes, then the four
        weight matrices are updated in place. Gradients are summed, not
        averaged, over the batch.
        """
        weights = [self.w1, self.w2, self.w3, self.w4]
        activations = [self.input, self.hidden_layer_1,
                       self.hidden_layer_2, self.hidden_layer_3]

        # error is (label - output), so *adding* the step below descends the
        # squared error.
        delta = self.error * self.sigmoid_deri(self.output_layer)
        deltas = [delta]
        # Walk back through w4, w3, w2 to obtain the hidden-layer deltas.
        for k in (3, 2, 1):
            delta = np.dot(delta, weights[k].T) * self.sigmoid_deri(activations[k])
            deltas.insert(0, delta)

        # deltas is now ordered [hidden_1, hidden_2, hidden_3, output],
        # aligned with (w1, input), (w2, h1), (w3, h2), (w4, h3).
        for w, a, d in zip(weights, activations, deltas):
            w += self.learning_rate * np.dot(a.T, d)
52
# Load the MNIST dataset through torchvision
def load_data(root='../../data/', download=False):
    """Load MNIST via torchvision as binarised, flattened numpy arrays.

    Args:
        root: directory that holds the MNIST files.
        download: forwarded to ``torchvision.datasets.MNIST``. The default
            ``False`` keeps the original behaviour of requiring local files.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)``:
        * X arrays are int64, one flattened image per row, with 1 wherever the
          pixel is non-zero and 0 elsewhere.
        * Y arrays are float one-hot matrices with 10 columns.
        * The training pair is shuffled with the global numpy RNG; the test
          pair keeps dataset order.
    """
    num_classes = 10
    datasets_train = torchvision.datasets.MNIST(
        root=root, train=True, transform=transforms.ToTensor(), download=download)
    datasets_test = torchvision.datasets.MNIST(
        root=root, train=False, transform=transforms.ToTensor(), download=download)

    # ``.data`` / ``.targets`` are read directly, so the transform is never
    # applied here; it would only run when samples are indexed one by one.
    X_train = datasets_train.data.numpy()
    X_test = datasets_test.data.numpy()
    # Flatten each image to a single row. The sample count comes from the data
    # itself rather than a hard-coded 60000 / 10000.
    X_train = X_train.reshape(len(X_train), -1)
    X_test = X_test.reshape(len(X_test), -1)

    # One-hot encode labels by picking rows of the identity matrix.
    real_train_y = np.eye(num_classes)[datasets_train.targets.numpy()]
    real_test_y = np.eye(num_classes)[datasets_test.targets.numpy()]

    # Shuffle training samples and labels with the same permutation.
    index = np.random.permutation(len(X_train))
    X_train = X_train[index]
    real_train_y = real_train_y[index]

    # Binarise: any non-zero pixel becomes 1.
    X_train = (X_train > 0).astype(np.int64)
    X_test = (X_test > 0).astype(np.int64)

    return X_train, real_train_y, X_test, real_test_y
87
88
def bp_network(batch_size=100, iterations=6000):
    """Train a fresh ``BP`` network on the MNIST training split and return it.

    Args:
        batch_size: samples per gradient step.
        iterations: number of mini-batch gradient steps (not full passes).
            With the defaults, 6000 steps of 100 samples cover the 60000
            training images 10 times.

    Raises:
        ValueError: if ``batch_size`` is not positive or exceeds the
            training-set size.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    nn = BP()
    X_train, Y_train, _, _ = load_data()
    # Number of full batches per pass, derived from the data instead of the
    # hard-coded 600. Any trailing partial batch is skipped.
    batches_per_pass = len(X_train) // batch_size
    if batches_per_pass == 0:
        raise ValueError("batch_size is larger than the training set")

    for step in range(iterations):
        start = (step % batches_per_pass) * batch_size
        end = start + batch_size
        nn.forward_prop(X_train[start:end], Y_train[start:end])
        nn.backward_prop()

    return nn
102
103
def bp_test():
    """Train a network with ``bp_network`` and print its MNIST test accuracy.

    Returns:
        The fraction of test samples classified correctly (also printed).
    """
    nn = bp_network()
    # Only the test split is needed here; the training split is discarded.
    _, _, X_test, Y_test = load_data()

    # One batched forward pass over the whole test set instead of one call
    # per sample.
    outputs = nn.forward_prop(X_test, Y_test)
    # argmax returns the first maximal index, the same tie-break as
    # list.index(max(...)).
    predicted = np.argmax(outputs, axis=1)
    correct = int(np.sum(Y_test[np.arange(len(Y_test)), predicted] == 1))

    accuracy = correct / len(Y_test)
    print('accuracy:', accuracy)
    return accuracy
117
118
if __name__ == '__main__':
    # Running this file as a script trains the network and prints the
    # test-set accuracy.
    bp_test()