pytorch(04)简单的线性回归

线性回归

线性回归是分析一个变量与另外一个变量之间关系的方法
因变量:y 自变量:x 关系:线性
y = wx+b
分析:求解w,b
求解步骤:

  1. 确定模型,Model:y = wx+b
  2. 选择损失函数,MSE:

\[\frac{1}{m}\sum^{m}_{i=1}(y_i-\hat{y_i}) \]

  1. 求解梯度并更新w,b
    w = w - LR* w.grad
    b = b -LR* b.grad
import os
import torch
import matplotlib.pyplot as plt
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
lr = 0.0005
torch.manual_seed(4)
x = torch.randn(80, 1) * 20
y = 5 * x + (6 + 10*torch.randn(80, 1))

w = torch.randn((1), requires_grad= True)
b = torch.zeros((1), requires_grad= True)

for i in range(1000):
    wx = torch.mul(w,x)
    wxb = torch.add(wx, b)
    # wxb = torch.addcmul(b, x, w, value=1)
    loss1 =  (0.5 * (y - wxb) ** 2).mean()
    loss1.backward()
    b.data.sub_(lr * b.grad)
    w.data.sub_(lr * w.grad)

    b.grad.zero_()
    w.grad.zero_()
    if i % 1 == 0:
        plt.scatter(x.data.numpy(), y.data.numpy())
        plt.plot(x.data.numpy(), wxb.data.numpy(), 'r-')
        plt.title("the loss:{}".format(loss1.data.numpy()))
        plt.xlim(-50,50)
        plt.ylim(-250,250)
        plt.pause(0.4)
        plt.clf()

        if loss1.data.numpy() < 70:
            break

posted @ 2020-12-14 17:43  笨喵敲代码  阅读(95)  评论(0)    收藏  举报