torch基础学习三
In [1]:
# eg:线性回归
import numpy as np
import torch
import torch.nn as nn
class LinearRegressionModel(nn.Module):
"""
// 线性回归模型
// 其实线性回归就是一个不加激活函数的连接层
"""
def __init__(self, input_dim, putput_dim):
super(LinearRegressionModel, self).__init__()
self.linear = nn.Linear(input_dim, putput_dim)
def forward(self, x):
out = self.linear(x)
return out
# 构造一组数据X和其对应的标签Y
x_values = [i for i in range(11)]
x_train = np.array(x_values, dtype=np.float32)
# 转换成矩阵
x_train = x_train.reshape(-1, 1)
y_values = [2*i + 1 for i in x_train]
y_train = np.array(y_values, dtype=np.float32)
input_dim = 1
output_dim = 1
model = LinearRegressionModel(input_dim, output_dim)
# model
# 指定好参数和损失函数
# 指定好参数和损失函数
epochs = 1000
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()
# 训练函数
for epoch in range(epochs):
epoch += 1
# 注意转行成tensor
inputs = torch.from_numpy(x_train)
lables = torch.from_numpy(y_train)
# 梯度清零每次迭代
optimizer.zero_grad()
# 向前传播
outputs = model(inputs)
# 计算损失率
loss =criterion(outputs, lables)
# 反向传播
loss.backward()
# 更新权重
optimizer.step()
if epoch % 50 == 0:
print("迭代次数:",epoch ,"损失率:", loss.item())
迭代次数: 50 损失率: 0.06982214003801346 迭代次数: 100 损失率: 0.03982393443584442 迭代次数: 150 损失率: 0.02271413989365101 迭代次数: 200 损失率: 0.012955266050994396 迭代次数: 250 损失率: 0.007389168255031109 迭代次数: 300 损失率: 0.004214527551084757 迭代次数: 350 损失率: 0.002403814345598221 迭代次数: 400 损失率: 0.0013710411731153727 迭代次数: 450 损失率: 0.000782000133767724 迭代次数: 500 损失率: 0.00044601960689760745 迭代次数: 550 损失率: 0.0002543936425354332 迭代次数: 600 损失率: 0.0001450968557037413 迭代次数: 650 损失率: 8.275517757283524e-05 迭代次数: 700 损失率: 4.7198649554047734e-05 迭代次数: 750 损失率: 2.6921472453977913e-05 迭代次数: 800 损失率: 1.5357030861196108e-05 迭代次数: 850 损失率: 8.758789590501692e-06 迭代次数: 900 损失率: 4.9950144784816075e-06 迭代次数: 950 损失率: 2.8487954750744393e-06 迭代次数: 1000 损失率: 1.625527033866092e-06
In [2]:
# 测试模型预测结果
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
predicted
Out[2]:
array([[ 0.99762845],
[ 2.99797 ],
[ 4.998312 ],
[ 6.9986534 ],
[ 8.998995 ],
[10.999336 ],
[12.999679 ],
[15.00002 ],
[17.000362 ],
[19.000704 ],
[21.001045 ]], dtype=float32)
浙公网安备 33010602011771号