小土堆pytorch笔记

I 验证网络结构是否有误

  1. 初始化一个符合网络的输入数据

    input = torch.ones((64, 3, 32, 32))

  2. 将输入数据传进网络,看是否报错

    print(network(input).shape)

II 修改已知网络(比如vgg16)

vgg16_false = torchvision.models.vgg16(weights=None)   
vgg16_true = torchvision.models.vgg16(weights=VGG16_Weights.DEFAULT)  
# 添加模块  
vgg16_false.classifier.add_module(name="add_linear", module=nn.Linear(in_features=1000, out_features=10))  
# 修改vgg的classifier的第7个模块  
vgg16_true.classifier[6] = nn.Linear(4096, 10)

III 模型保存及对应加载方法

# 保存模型方法1:模型结构+模型参数。  
def save1(model, filename):  
    torch.save(model, filename)


# 对应使用save1()保存的模型的加载方法1。注意:要让加载的模型可被该方法访问。
def load1(filename):
    return torch.load(filename)


# 保存模型方法2:保存模型参数(官方推荐)
def save2(model, filename):
    torch.save(model.state_dict(), filename)


# 对应使用save2()保存的模型的加载方法2
def load2(model, filename):
    state_dict = torch.load(filename)
    model.load_state_dict(state_dict)
    return model

IV 训练模型步骤

0. 定义训练的设备

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

1. 设置数据集

1.1 定义数据集

train_dataset = torchvision.datasets.CIFAR10(root="../dataset", train=True, transform=torchvision.transforms.ToTensor(),
                                             download=True)
test_dataset = torchvision.datasets.CIFAR10(root="../dataset", train=False, transform=torchvision.transforms.ToTensor(),
                                            download=True)

1.2 定义数据集相关参数

train_dataset_length = len(train_dataset)
test_dataset_length = len(test_dataset)

1.3 利用DataLoader加载数据集

train_dataloader = DataLoader(train_dataset, batch_size=64)
test_dataloader = DataLoader(test_dataset, batch_size=64)

2. 设置模型

2.1 创建模型

tudui = Tudui()
tudui.to(device) # 可以不写成tudui = tudui.to(device)

2.2 定义训练相关参数

epoch = 10  # 训练轮数
train_total_step = 0  # 训练总次数
test_total_step = 0  # 测试总次数

2.3 设置tensorboard

writer = SummaryWriter(log_dir="../logs-train")

3. 定义损失函数

loss_fn = nn.CrossEntropyLoss()
loss_fn.to(device)

4. 定义优化器

learning_rate = 1e-2  # 学习率
optim = torch.optim.SGD(params=tudui.parameters(), lr=learning_rate)

5. 训练模型

for i in range(epoch):
    # 5.1 开始训练
    tudui.train() # 对有drop等层有效
    for inputs, targets in train_dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = tudui(inputs)
        # 计算损失值
        loss = loss_fn(outputs, targets)
        # 使用优化器优化模型
        optim.zero_grad()
        loss.backward()
        optim.step()
        # 记录训练情况
        train_total_step += 1
        if train_total_step % 100 == 0:
            writer.add_scalar(tag="time", scalar_value=end_time-start_time, global_step=train_total_step)
            print("训练次数:{}, loss:{}".format(train_total_step, loss.item()))
            writer.add_scalar(tag="train_loss", scalar_value=loss.item(), global_step=train_total_step)

    # 5.2 测试网络
    tudui.eval()
    test_total_accuracy = 0
    test_total_loss = 0
    with torch.no_grad():
        for inputs, targets in test_dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = tudui(inputs)
            loss = loss_fn(outputs, targets)
            test_total_loss += loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            test_total_accuracy += accuracy

    print("整体测试集loss:{}, accuracy:{}".format(test_total_loss, test_total_accuracy / test_dataset_length))
    writer.add_scalar(tag="test_loss", scalar_value=test_total_loss, global_step=i)
    writer.add_scalar(tag="test_accuracy", scalar_value=test_total_accuracy / test_dataset_length, global_step=i)

    torch.save(tudui, "tudui_{}.pth".format(i))
    print("模型文件:tudui_{}.pth已保存".format(i))

writer.close()

V 完整测试模型步骤

1. 导入图片

image = Image.open("../imgs/airplane.png")

1.1 调整图片通道数

由于png图片有4通道(多了一个透明度通道),故须转为RGB图片的三通道,从而适应jpg、png等各种类型的图片
image = image.convert("RGB")

1.2 调整图片尺寸大小和数据类型

trans = transforms.Compose([transforms.Resize((32, 32)),
                    transforms.ToTensor()])
image = trans(image)

1.3 调整图片维数

增加一个batch size维度
image = torch.reshape(image, (1, 3, 32, 32))

2. 加载训练好的模型权重

modle = torch.load("xx.pth")

3. 测试模型并打印输出

modle.eval()
with torch.no_grad():
    output = modle(image)
# 打印出每行最大的列索引值
print(output.argmax(1))
posted @ 2023-02-01 09:57  Kurie  阅读(120)  评论(0编辑  收藏  举报