牛客题解 | 使用正规方程的线性回归
题目
线性回归是一类回归问题,其目标是通过找到一组参数,使得输入数据和输出数据之间的线性关系尽可能地接近。其数学表达式为:
\[y = X \times w
\]
其中,\(X\) 是输入矩阵,\(w\) 是回归系数,\(y\) 是输出矩阵。
而正规方程是一种求解线性回归问题的方法,它通过求解矩阵的逆来得到回归系数。其具体步骤如下:
1. 初始化矩阵
- 创建一个与输入矩阵 \(X\) 和输出矩阵 \(y\) 相关的矩阵 \(A\)。
- 数学表达式为:
\[A = X^T \times X
\]
2. 求解回归系数
- 通过求解矩阵 \(A\) 的逆来得到回归系数。
- 数学表达式为:
\[w = A^{-1} \times X^T \times y
\]
3. 返回回归系数
- 将计算得到的回归系数返回。
标准代码如下
def linear_regression_normal_equation(X: list[list[float]], y: list[float]) -> list[float]:
X = np.array(X)
# 将y转换为列向量
y = np.array(y).reshape(-1, 1)
X_transpose = X.T
# 计算正规方程的解
theta = np.linalg.inv(X_transpose.dot(X)).dot(X_transpose).dot(y)
theta = np.round(theta, 4).flatten().tolist()
return theta

浙公网安备 33010602011771号