牛客题解 | 使用正规方程的线性回归

题目

题目链接

线性回归是一类回归问题,其目标是通过找到一组参数,使得输入数据和输出数据之间的线性关系尽可能地接近。其数学表达式为:

\[y = X \times w \]

其中,\(X\) 是输入矩阵,\(w\) 是回归系数,\(y\) 是输出矩阵。
而正规方程是一种求解线性回归问题的方法,它通过求解矩阵的逆来得到回归系数。其具体步骤如下:

1. 初始化矩阵

  • 创建一个与输入矩阵 \(X\) 和输出矩阵 \(y\) 相关的矩阵 \(A\)
  • 数学表达式为:

\[A = X^T \times X \]

2. 求解回归系数

  • 通过求解矩阵 \(A\) 的逆来得到回归系数。
  • 数学表达式为:

\[w = A^{-1} \times X^T \times y \]

3. 返回回归系数

  • 将计算得到的回归系数返回。

标准代码如下

def linear_regression_normal_equation(X: list[list[float]], y: list[float]) -> list[float]:
    X = np.array(X)
    # 将y转换为列向量
    y = np.array(y).reshape(-1, 1)
    X_transpose = X.T
    # 计算正规方程的解
    theta = np.linalg.inv(X_transpose.dot(X)).dot(X_transpose).dot(y)
    theta = np.round(theta, 4).flatten().tolist()
    return theta
posted @ 2025-03-12 12:45  wangxiaoxiao  阅读(39)  评论(0)    收藏  举报