Pytorch-基础入门之ANN
本部分介绍如何使用 PyTorch 实现 ANN(人工神经网络),这里的 ANN 具有三个隐含层。
本部分与上一篇逻辑斯蒂回归使用相同的数据集 MNIST。注意:下文训练代码用到的 num_epochs、train_loader、test_loader 未在本部分中定义,应是沿用上一篇中的数据准备代码。
第一部分:构造模型
# Import Libraries
import torch
import torch.nn as nn
from torch.autograd import Variable
# Create ANN Model
class ANNModel(nn.Module):
    """Feed-forward network with three hidden layers of equal width.

    Layer stack: Linear -> ReLU -> Linear -> Tanh -> Linear -> ELU -> Linear.
    The last linear layer returns raw scores; no softmax is applied here.

    Args:
        input_dim: Number of features in each input row.
        hidden_dim: Width shared by all three hidden layers.
        output_dim: Number of scores produced per input row.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        super(ANNModel, self).__init__()
        # Hidden layer 1: input_dim --> hidden_dim, followed by ReLU
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu1 = nn.ReLU()
        # Hidden layer 2: hidden_dim --> hidden_dim, followed by Tanh
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.tanh2 = nn.Tanh()
        # Hidden layer 3: hidden_dim --> hidden_dim, followed by ELU
        self.fc3 = nn.Linear(hidden_dim, hidden_dim)
        self.elu3 = nn.ELU()
        # Readout layer: hidden_dim --> output_dim
        self.fc4 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the four linear layers and return the readout.

        Args:
            x: Tensor whose last dimension has ``input_dim`` entries.

        Returns:
            Tensor whose last dimension has ``output_dim`` entries.
        """
        # Linear function 1 + non-linearity 1
        out = self.fc1(x)
        out = self.relu1(out)
        # Linear function 2 + non-linearity 2
        out = self.fc2(out)
        out = self.tanh2(out)
        # Linear function 3 + non-linearity 3
        out = self.fc3(out)
        out = self.elu3(out)
        # Linear function 4 (readout)
        out = self.fc4(out)
        return out
# Network dimensions
input_dim = 28 * 28  # matches the view(-1, 28*28) flattening in the training loop
hidden_dim = 150  # hyperparameter to be tuned; 150 is an arbitrary starting value
output_dim = 10

# Build the network from the dimensions above
model = ANNModel(input_dim, hidden_dim, output_dim)

# Loss: cross entropy
error = nn.CrossEntropyLoss()

# Optimizer: plain SGD over all model parameters
learning_rate = 0.02
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
第二部分:训练模型
# ANN model training
# NOTE(review): num_epochs, train_loader and test_loader are not defined in
# this section; they must already be in scope — TODO confirm where they are set.
count = 0  # number of optimizer steps taken so far
loss_list = []  # training loss (float) recorded every 50 steps
iteration_list = []  # step number for each recorded entry
accuracy_list = []  # test accuracy in percent (float) recorded every 50 steps
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Flatten every image into a row of 28*28 values
        train = images.view(-1, 28 * 28)
        # Clear gradients left over from the previous step
        optimizer.zero_grad()
        # Forward propagation
        outputs = model(train)
        # Calculate softmax and cross entropy loss
        loss = error(outputs, labels)
        # Calculating gradients
        loss.backward()
        # Update parameters
        optimizer.step()
        count += 1
        if count % 50 == 0:
            # Calculate accuracy over the whole test loader
            correct = 0
            total = 0
            # No gradients are needed for evaluation, so skip graph building
            with torch.no_grad():
                # Separate names keep the training batch and its outputs intact
                for test_images, test_labels in test_loader:
                    test = test_images.view(-1, 28 * 28)
                    # Forward propagation
                    test_outputs = model(test)
                    # Predicted class = index of the maximum score
                    predicted = torch.max(test_outputs, 1)[1]
                    # Total number of labels
                    total += len(test_labels)
                    # .item() keeps the running count a plain Python int
                    correct += (predicted == test_labels).sum().item()
            accuracy = 100 * correct / float(total)
            # Store plain floats so the lists can be plotted directly
            loss_list.append(loss.item())
            iteration_list.append(count)
            accuracy_list.append(accuracy)
            # 500 is a multiple of 50, so accuracy is always fresh here
            if count % 500 == 0:
                # Print Loss
                print('Iteration: {} Loss: {} Accuracy: {} %'.format(count, loss.item(), accuracy))
结果:
Iteration: 500 Loss: 0.8311067223548889 Accuracy: 77 %
Iteration: 1000 Loss: 0.4767582416534424 Accuracy: 87 %
Iteration: 1500 Loss: 0.21807175874710083 Accuracy: 89 %
Iteration: 2000 Loss: 0.2915269732475281 Accuracy: 90 %
Iteration: 2500 Loss: 0.3073478937149048 Accuracy: 91 %
Iteration: 3000 Loss: 0.12328791618347168 Accuracy: 92 %
Iteration: 3500 Loss: 0.24098418653011322 Accuracy: 93 %
Iteration: 4000 Loss: 0.06471655517816544 Accuracy: 93 %
Iteration: 4500 Loss: 0.3368555009365082 Accuracy: 94 %
Iteration: 5000 Loss: 0.12026549130678177 Accuracy: 94 %
Iteration: 5500 Loss: 0.217212975025177 Accuracy: 94 %
Iteration: 6000 Loss: 0.20914879441261292 Accuracy: 94 %
Iteration: 6500 Loss: 0.10008767992258072 Accuracy: 95 %
Iteration: 7000 Loss: 0.13490895926952362 Accuracy: 95 %
Iteration: 7500 Loss: 0.11741413176059723 Accuracy: 95 %
Iteration: 8000 Loss: 0.17519493401050568 Accuracy: 95 %
Iteration: 8500 Loss: 0.06657659262418747 Accuracy: 95 %
Iteration: 9000 Loss: 0.05512683466076851 Accuracy: 95 %
Iteration: 9500 Loss: 0.02535334974527359 Accuracy: 96 %
第三部分:可视化展示
# Draw the loss curve and then the accuracy curve, one figure each.
# Each entry: (y-values, y-axis label, figure title, extra plot keyword args)
for curve, y_label, heading, plot_kwargs in (
    (loss_list, "Loss", "ANN: Loss vs Number of iteration", {}),
    (accuracy_list, "Accuracy", "ANN: Accuracy vs Number of iteration", {"color": "red"}),
):
    plt.plot(iteration_list, curve, **plot_kwargs)
    plt.xlabel("Number of iteration")
    plt.ylabel(y_label)
    plt.title(heading)
    plt.show()
结果:


浙公网安备 33010602011771号