小土堆pytorch学习—P24-优化器
优化器
是什么?
优化器根据反向传播的梯度对参数进行调整,达到调优(降低整体误差)的目的。
怎么用?
提供的优化器很多,参数也不尽相同。大部分前两个一致。model.parameters()&lr(learn rate)。
parameter作用:优化器知道模型结构是什么,可以调节的参数有哪些
lr-learning rate作用:学习速率
优化器中有step方法,是用损失梯度计算之后对参数进行更新。 调整之后再次循环。案例👇
for input , target in dataset:
optimizer.zero_grad()#梯度清零,对loss.backward()计算出的梯度,以防止在本次循环中造成影响。因为每次都是取一个batch_size数量的样本进行训练
output = model(input)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
案例完整代码👇
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader
dataset = torchvision.datasets.CIFAR10(root = "hymenoptera_data/val/CIFAR10", train = False , transform = torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset = dataset , batch_size =2)
class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.model = nn.Sequential(
nn.Conv2d(3,32,5,padding = 2),
nn.MaxPool2d(kernel_size = 2),
nn.Conv2d(32,32,5,padding = 2),
nn.MaxPool2d(kernel_size = 2),
nn.Conv2d(32 , 64 , 5 ,padding = 2),
nn.MaxPool2d(kernel_size = 2),
nn.Flatten(),
nn.Linear(in_features = 1024 ,
out_features = 64),
nn.Linear(in_features = 64 ,
out_features = 10)
)
def forward(self , input):
input = self.model(input)
return input
tudui = Tudui()
loss = nn.CrossEntropyLoss()
optim = torch.optim.SGD(params = tudui.parameters(),
#随机梯度下降算法,可以直接使用模型的parameters方法
lr=0.01
#学习速率设置为0.01,一开始可以设置较大的学习速率,之后设置较小的学习速率
)
i = 1
for epoch in range(20):
sum_loss = 0.0
#每一轮的所有损失的求和,一共设置循环了20轮
for data in dataloader:#该循环相当于只对数据进行了一轮的学习,需要循环多轮
imgs , targets = data
outputs = tudui(imgs)
result_loss = loss(outputs , targets )
#step1:调节模型中每个参数梯度调节为0
optim.zero_grad()
#step2:优化器需要每个参数的梯度
result_loss.backward()
#step3:对每个参数进行调优
optim.step()
# print(f"result_loss = {result_loss}")
sum_loss = sum_loss+result_loss
print(f"第{i}次的损失为:{sum_loss}")
i+=1
#调试过程👇
#运行到optim.zero_grad()==》控制台中的tudui==》
#Protected Attributes==》_modules==》model==》
#Proteceted Attributes ==》_modules==》'0'Conv2d
#==》weight
#看到grad就是None
效果是什么?

这个地方训练到10次损失值就反弹回来一直上升了,这说明了什么问题?

浙公网安备 33010602011771号