PyTorch learning notes: training a simple classifier on CIFAR10

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model1test import *
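# model1test is not shown in the post; it is assumed here to define a small CNN
# class named `mymodel` for 3x32x32 CIFAR10 inputs. A minimal sketch of what
# such a file could contain (an assumption, not the author's actual model):
#
#     from torch import nn
#
#     class mymodel(nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.model = nn.Sequential(
#                 nn.Conv2d(3, 32, 5, padding=2), nn.MaxPool2d(2),
#                 nn.Conv2d(32, 32, 5, padding=2), nn.MaxPool2d(2),
#                 nn.Conv2d(32, 64, 5, padding=2), nn.MaxPool2d(2),
#                 nn.Flatten(),
#                 nn.Linear(64 * 4 * 4, 64),
#                 nn.Linear(64, 10),
#             )
#
#         def forward(self, x):
#             return self.model(x)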

train_data = torchvision.datasets.CIFAR10('./datasets', train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data = torchvision.datasets.CIFAR10('./datasets', train=False, transform=torchvision.transforms.ToTensor(), download=True)

print('Training set size: {}'.format(len(train_data)))
print('Test set size: {}'.format(len(test_data)))

traindata_loader = DataLoader(train_data, batch_size=64, shuffle=True)
testdata_loader = DataLoader(test_data, batch_size=64)
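# Quick sanity check: each CIFAR10 batch is an image tensor of shape
# (64, 3, 32, 32) plus integer class labels of shape (64,). sample_imgs and
# sample_targets below are throwaway names used only for this check.
sample_imgs, sample_targets = next(iter(traindata_loader))
print(sample_imgs.shape, sample_targets.shape)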

# Create the model; run on the GPU when one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mymodel1 = mymodel()
mymodel1 = mymodel1.to(device)

# Loss function
loss_fn = nn.CrossEntropyLoss()
loss_fn = loss_fn.to(device)

# Optimizer
learning_rate = 0.001
optimizer = torch.optim.SGD(mymodel1.parameters(), lr=learning_rate)
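# A common alternative optimizer (not used in this post) would be Adam, e.g.
# optimizer = torch.optim.Adam(mymodel1.parameters(), lr=learning_rate)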


# Training setup
total_train_step = 0  # number of optimizer steps taken
total_test_step = 0   # number of completed evaluation passes
writer = SummaryWriter('first-train')

# Number of epochs
epoch = 10
for i in range(epoch):
    print('Epoch {}'.format(i + 1))
    # Training phase
    mymodel1.train()
    for data in traindata_loader:
        imgs, targets = data
        imgs = imgs.to(device)
        targets = targets.to(device)

        # Forward pass through the model
        outputs = mymodel1(imgs)

        # Compute the loss
        loss = loss_fn(outputs, targets)

        # Backpropagate and update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Total number of training steps so far
        total_train_step += 1
        if total_train_step % 100 == 0:
            print('Step {}, loss {:.4f}'.format(total_train_step, loss.item()))
            writer.add_scalar('train', loss.item(), total_train_step)

    # Evaluation phase: no gradients are needed here
    mymodel1.eval()
    total_test_loss = 0
    with torch.no_grad():
        for data in testdata_loader:
            imgs, targets = data
            imgs = imgs.to(device)
            targets = targets.to(device)
            outputs = mymodel1(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
    print('Total loss on the test set: {:.4f}'.format(total_test_loss))

    writer.add_scalar('test', total_test_loss, total_test_step)
    total_test_step += 1
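    # Accuracy is not tracked in the original post. A minimal sketch, assuming a
    # counter total_correct that is reset to 0 each epoch and updated inside the
    # evaluation loop above:
    #     preds = outputs.argmax(dim=1)
    #     total_correct += (preds == targets).sum().item()
    # After the loop: print('Test accuracy: {}'.format(total_correct / len(test_data)))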

    # Save a checkpoint after every epoch
    torch.save(mymodel1, 'number{}.pth'.format(i))
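    # Note: torch.save(model_object, path) pickles the whole module, which ties
    # the checkpoint to the model1test source file. Saving only the weights is
    # the more portable convention:
    #     torch.save(mymodel1.state_dict(), 'number{}.pth'.format(i))
    # restored later with mymodel1.load_state_dict(torch.load(path)).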



writer.close()
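
The curves logged under 'train' and 'test' can be viewed by pointing TensorBoard at the log directory created by SummaryWriter above:

tensorboard --logdir=first-train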