牛客题解 | 实现求解线性方程组的共轭梯度法

题目

题目链接

共轭梯度法是一种求解线性方程组的迭代方法。具体步骤如下:

  1. 计算初始残差向量\(r\)

    \[r = b - Ax \]

  2. 计算初始搜索方向向量\(p\)

    \[p = r \]

  3. 迭代更新\(x\)\(r\)\(p\),直到满足收敛条件

    \[x = x + \alpha p \]

    \[r_{i+1} = r_i - \alpha (A p_i) \]

    \[bet = \frac{r_{i+1} \cdot r_{i+1}}{r_i \cdot r_i} \]

    \[p = r_{i+1} + bet p_i \]

共轭梯度法的关键在于使用正交搜索方向,确保每次迭代都能获得更多的信息,而不需要重复搜索。个中原理可以参考相关资料。

标准代码如下

import numpy as np
def conjugate_gradient(A: np.array, b: np.array, n: int, x0: np.array=None, tol=1e-8) -> np.array:
    # calculate initial residual vector
    x = np.zeros_like(b)
    r = residual(A, b, x) # residual vector
    rPlus1 = r
    p = r # search direction vector
    for i in range(n):
        # line search step value - this minimizes the error along the current search direction
        alp = alpha(A, r, p)
        # new x and r based on current p (the search direction vector)
        x = x + alp * p
        rPlus1 = r - alp * (A@p)
        # calculate beta - this ensures that all vectors are A-orthogonal to each other
        bet = beta(r, rPlus1)
        # update x and r
        # using a othogonal search direction ensures we get all the information we need in more direction and then don't have to search in that direction again
        p = rPlus1 + bet * p
        # update residual vector
        r = rPlus1
        # break if less than tolerance
        if np.linalg.norm(residual(A,b,x)) < tol:
            break

    return x
def residual(A: np.array, b: np.array, x: np.array) -> np.array:
    # calculate linear system residuals
    return b - A @ x
def alpha(A: np.array, r: np.array, p: np.array) -> float:

    # calculate step size
    alpha_num = np.dot(r, r)
    alpha_den = np.dot(p @ A, p)

    return alpha_num/alpha_den
def beta(r: np.array, r_plus1: np.array) -> float:

    # calculate direction scaling
    beta_num = np.dot(r_plus1, r_plus1)
    beta_den = np.dot(r, r)

    return beta_num/beta_den
if __name__ == "__main__":
    A = eval(input())
    b = eval(input())
    n = int(input())
    print(conjugate_gradient(A, b, n))
posted @ 2025-03-12 13:01  wangxiaoxiao  阅读(37)  评论(0)    收藏  举报