测试

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms


def get_num_correct(preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()


class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()
        # 定义卷积层,其中参数包括输入通道,输出通道,以及核个数
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5)
        # 第一个卷积层的输入通道依赖于图片的色彩,灰色为1彩色为3
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5)

    # 定义线性层(全连接层),其中参数包括输入参数,输出参数
        self.fc1 = nn.Linear(in_features=12*4*4, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=60)
        self.out = nn.Linear(in_features=60, out_features=10)  # 输出的输出特征依赖于训练集中类的个数


# 注:卷积层的所有输入通道都依赖于上一层的输出通道,线性层的所有输入特征都依赖于上一层的输出特征
def forward(self, t):
    # (1) input layer
    t = t

    # (2) hidden conv layer
    t = self.conv1(t)
    t = F.relu(t)
    t = F.max_pool2d(t, kernel_size=2, stride=2)

    # (3) hidden conv layer
    t = self.conv2(t)
    t = F.relu(t)
    t = F.max_pool2d(t, kernel_size=2, stride=2)

    # (4) hidden linear layer
    t = t.reshape(-1, 12 * 4 * 4)  # 张量重塑
    t = self.fc1(t)
    t = F.relu(t)

    # (5) hidden linear layer
    t = self.fc2(t)
    t = F.relu(t)

    # (6) output layer
    t = self.out(t)
    return t


network = Network()

train_set = torchvision.datasets.FashionMNIST(
    root='./data/FashionMNIST',  # 设置下载路径
    train=True,
    download=True,
    transform=transforms.Compose([
        transforms.ToTensor
    ])
)

train_loader = torch.utils.data.DataLoader(train_set, batch_size=100)
optimizer = optim.Adam(network.parameters(), lr=0.01)

total_loss = 0
total_correct = 0
for batch in train_loader:
    images, labels = batch

    preds = network(images)
    loss = F.cross_entropy(preds, labels)

    optimizer.zero_grad()
    loss.backward()  # 更新梯度
    optimizer.step()  # 更新权重

    total_loss += loss.item()
    total_correct += get_num_correct(preds, labels)
print('当前准确率为:{}'.format(total_correct / len(train_set)))

posted @ 2021-10-09 20:28  WangSir_Code  阅读(35)  评论(0)    收藏  举报