pytorch中学习率的调整方法

一、手动法

二、利用lr_scheduler()提供的集中衰减函数

2.1 利用lr_lambda函数

具体使用:

from torch.optim import SGD, lr_scheduler
import matplotlib.pyplot as plt
from torch.nn import Module, Sequential, Linear, CrossEntropyLoss


# 定义网络模型
class model(Module):
def __init__(self):
super(model, self).__init__()
self.fc = Sequential(
Linear(1,10)
)

def forward(self, input):
output = self.fc(input)
return output

# 初始化网络模型
Model = model()
# 定义损失函数
Loss = CrossEntropyLoss()
# 创建优化器
lr = 0.01
optimizer = SGD(Model.parameters(), lr=lr)
# 定义一个list保存学习率
lr_list = []

# 定义学习率与轮数关系的函数
lambda1 = lambda epoch:0.95 ** epoch # 学习率 = 0.95**(轮数)
scheduler = lr_scheduler.LambdaLR(optimizer,lr_lambda = lambda1)

for epoch in range(100):
print("epoch={}, lr={}".format(epoch, optimizer.state_dict()['param_groups'][0]['lr']))
scheduler.step()
lr_list.append(optimizer.state_dict()['param_groups'][0]['lr'])

plt.plot(range(100),lr_list,color = 'r',label = 'LambdaLR')
plt.ylabel('learning rate')
plt.xlabel('epoch')
plt.legend()
plt.show()

posted on 2023-02-20 21:37  陈酉西  阅读(57)  评论(0)    收藏  举报