1-线性模型

学习参考视频是B站刘二大人:https://www.bilibili.com/video/BV1Y7411d7Ys/?p=1&vd_source=7251363c91f375c2f19fef8bb1beab83

1、一般监督学习的步骤:准备数据集 -> 选择模型 -> 训练 -> 应用
2、数据集划分为:训练测、验证集、测试集
3、进行模型选择,一般先选择线性模型,不行换其他的
4、进行模型的训练,选择不同的模型权重值,计算loss(单个样本),然后计算平均平方误差MSE,选择MSE最低的权重值
5、绘图除了用matplotlib,也可以用pytorch的visdom

点击查看代码
# y = w * x
import numpy as np
import matplotlib.pyplot as plt

x_data = [1.0, 2.0, 3.0, 4.0]
y_data = [2.0, 4.0, 6.0, 8.0]

def forward(x):
    return w * x

def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) * (y_pred - y)

w_list = []
mse_list = []

for w in np.arange(0.0, 4.1, 0.1): # 穷举
    print('w=', w)
    l_sum = 0
    for x, y in zip(x_data, y_data):
        y_pred_val = forward(x)
        loss_val = loss(x, y)
        l_sum += loss_val
        print('\t', x, y, y_pred_val, loss_val)
    print('MSE=', l_sum/4)
    w_list.append(w)
    mse_list.append(l_sum/4)

plt.plot(w_list, mse_list)
plt.xlabel('w')
plt.ylabel('loss')
plt.show()
点击查看代码
# y = w * x + b
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

x_data = [1.0, 2.0, 3.0]
y_data = [3.0, 5.0, 7.0]

def forward(x):
    return w * x + b

def loss(x, y):
    y_pred = forward(x)
    return (y_pred - y) ** 2

w_list = []
b_list = []
mse_list = np.zeros((40, 40), dtype=float)

for i, w in enumerate(np.arange(0.0, 4.0, 0.1)):
    for j, b in enumerate(np.arange(0.0, 4.0, 0.1)):
        print('w={}, b={}'.format(w, b))
        l_sum = 0
        for x_val, y_val in zip(x_data, y_data):
            y_pred = forward(x_val)
            loss_val = loss(x_val, y_val)
            l_sum += loss_val
            print('\t', x_val, y_val, y_pred, loss_val)
        print('MSE={}'.format(l_sum/3))
        if w == 0:
            b_list.append(b) # 保证只记录一遍b的值
        mse_list[i][j] = l_sum/3
    w_list.append(w)

X, Y = np.meshgrid(np.array(w_list), np.array(b_list))
Z = np.array(mse_list)
print(X.shape)
print(Y.shape)
print(Z.shape)

# 创建一个新的图形
fig = plt.figure()
# 添加一个3d子图
ax = fig.add_subplot(projection='3d')

ax.plot_surface(X, Y, Z, rstride=1, cstride=2, cmap='viridis')
plt.show()


posted @ 2024-08-07 20:58  不是孩子了  阅读(19)  评论(0)    收藏  举报