1 import torch
2 from torch import nn
3 from torch.nn import functional as F
4 from torch import optim
5
6 import torchvision
7 from matplotlib import pyplot as plt
8
9 # 小工具
10
def plot_curve(data):
    """Plot a sequence of scalar values (e.g. per-step training loss) as a line.

    Args:
        data: a sized sequence of numbers; element ``i`` is drawn at x = i.
    """
    plt.figure()
    plt.plot(range(len(data)), data, color='blue')
    plt.legend(['value'], loc='upper right')
    plt.xlabel('step')
    # pyplot has no ``tlabel``; the y-axis label is set with ``ylabel``.
    plt.ylabel('value')
    plt.show()
18
def plot_image(img, label, name):
    """Show up to six images in a 2x3 grid, each titled with its label.

    Args:
        img: image batch indexed as ``img[i][0]`` to get a 2-D image;
            presumably shape (b, 1, 28, 28) as produced by the MNIST loaders
            in this script — TODO confirm for other callers.
        label: per-image labels; ``label[i].item()`` is shown in each title.
        name: prefix for each subplot title, e.g. ``'test'``.
    """
    plt.figure()
    # Never index past the end of a batch smaller than the 2x3 grid.
    for i in range(min(6, len(img))):
        plt.subplot(2, 3, i + 1)
        plt.tight_layout()
        # Invert Normalize((0.1307,), (0.3081,)) so the image is shown in its
        # original pixel scale.
        plt.imshow(img[i][0] * 0.3081 + 0.1307, cmap='gray', interpolation='none')
        plt.title("{}:{}".format(name, label[i].item()))
        # Hide ticks on both axes.
        plt.xticks([])
        plt.yticks([])
    plt.show()
30
def one_hot(label, depth=10):
    """Return a float one-hot encoding of integer class labels.

    Args:
        label: 1-D tensor (any integer dtype, any device) or sequence of class
            indices in ``[0, depth)``.
        depth: number of classes, i.e. the width of each output row.

    Returns:
        Float32 tensor of shape ``(len(label), depth)`` holding a single 1.0
        per row, on the same device as ``label``.
    """
    # torch.as_tensor accepts any integer dtype, any device, and plain lists.
    # The legacy torch.LongTensor(tensor) constructor only accepts CPU int64
    # tensors.
    idx = torch.as_tensor(label, dtype=torch.long).view(-1, 1)
    # Allocate on the label's device so scatter_ does not mix CPU/GPU tensors.
    out = torch.zeros(idx.size(0), depth, device=idx.device)
    out.scatter_(dim=1, index=idx, value=1)
    return out
36
# Number of images loaded per batch.
batch_size = 512

# Step 1: load the dataset.
# ToTensor converts each image to a float tensor. Normalize then standardizes
# it with (0.1307, 0.3081), the commonly quoted MNIST mean/std.
# One pipeline is shared by the train and test sets so they cannot drift apart.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                               transform=mnist_transform),
    batch_size=batch_size, shuffle=False)
57
58 # 网络创建
class Net(nn.Module):
    """Three-layer fully connected classifier: 784 -> 256 -> 64 -> 10.

    The output is raw scores (no softmax) for 10 classes.
    """

    def __init__(self):
        super().__init__()
        # Each layer computes x @ W.T + b.
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        """Map inputs to 10 class scores.

        Args:
            x: either already-flattened input whose last dim is 784
                (e.g. [b, 784] or a single [784] vector), or image-shaped
                input such as [b, 1, 28, 28].

        Returns:
            Scores with the same leading shape as the flattened input and a
            last dim of 10 (e.g. [b, 10]).
        """
        # Inputs whose last dim is already 784 pass through untouched.
        # Anything else, such as [b, 1, 28, 28], is flattened per sample.
        if x.shape[-1] != self.fc1.in_features:
            x = x.flatten(1)
        # h1 = relu(x w1 + b1)
        x = F.relu(self.fc1(x))
        # h2 = relu(h1 w2 + b2)
        x = F.relu(self.fc2(x))
        # h3 = h2 w3 + b3  (raw scores)
        x = self.fc3(x)
        return x
79
net = Net()
# Parameters being optimized: [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

# Loss value of every training step, collected for plotting.
train_loss = []
85
# Training
net.train()
for epoch in range(3):
    for batch_idx, (x, y) in enumerate(train_loader):
        # x: [b, 1, 28, 28], y: [b]
        # Flatten each image: [b, 1, 28, 28] --> [b, 784]
        x = x.view(x.size(0), 28 * 28)
        # Forward pass: --> [b, 10]
        out = net(x)
        # Encode targets to match the output: [b] --> [b, 10]
        y_onehot = one_hot(y)
        # Mean squared error between the raw scores and the one-hot targets.
        loss = F.mse_loss(out, y_onehot)

        # Clear gradients accumulated by the previous step.
        optimizer.zero_grad()
        # Back-propagate to compute gradients.
        loss.backward()
        # Update the weights: w' = w - lr * grad (SGD with momentum).
        optimizer.step()

        # Read the scalar once instead of calling .item() twice.
        loss_value = loss.item()
        train_loss.append(loss_value)

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss_value)

plot_curve(train_loss)

# At this point [w1, b1, w2, b2, w3, b3] hold the trained parameters.
115
116
# Measure accuracy on the test set.
net.eval()
total_correct = 0
# Inference only: skip building the autograd graph to save memory and time.
with torch.no_grad():
    for x, y in test_loader:
        x = x.view(x.size(0), 28 * 28)
        out = net(x)
        # out: [b, 10] --> pred: [b] (index of the highest score)
        pred = out.argmax(dim=1)
        correct = pred.eq(y).sum().float().item()
        total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print('test acc:', acc)
130
# Visual check: show the predictions for the first test batch.
x, y = next(iter(test_loader))
# Inference only: no gradients are needed for display.
with torch.no_grad():
    out = net(x.view(x.size(0), 28 * 28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
136
137
138
139
140
141