【小白学算法】矩阵快速幂超详细解析+例题[HDU - 2802]

用于高效解决线性递推问题

前言

在算法竞赛和实际编程中,我们经常遇到需要计算矩阵的高次幂的问题。如果直接用朴素的矩阵乘法来计算,时间复杂度会达到O(n³ × k),其中n是矩阵的维度,k是幂次。当k非常大时(比如\(10^9\)),这样的时间复杂度是无法接受的。矩阵快速幂就是为了解决这个问题而诞生的算法。

什么是矩阵快速幂?

矩阵快速幂就是快速幂算法再矩阵运算中的应用。它基于这样的思想:

\(A^8\) = \((( A^2 )^2)^2\)

\(A^9\)=\(A × A^8\)

通过将指数进行二进制分解,我们可以在O(log k)次矩阵乘法内计算出\(A^k\)

快速幂基本原理

以计算a^n为例:

  • 如果n是偶数:\(a^n\) = \((a^{n/2})^2\)

  • 如果n是奇数:\(a^n\) = a ×$ a^{n-1}$

这个过程可以用递归或迭代实现。

矩阵快速幂的实现步骤

矩阵定义

struct matrix {
    int mat[6][6];
    void init() {
        memset(mat, 0, sizeof(mat));
    }
};

矩阵乘法

由于幂的过程中,我们要实现一个矩阵的乘法,所以首先要将矩阵乘法写出;

matrix mul(matrix a, matrix b) {   //return a*b
    matrix c;
    c.init();
    for (int i = 0; i < 6; i++) {
        for (int j = 0; j < 6; j++) {
            for (int k = 0; k < 6; k++) {
                c.mat[i][j] += ((a.mat[i][k] % mod) * (b.mat[k][j] % mod)) % mod;
                c.mat[i][j] %= mod;
            }
        }
    }
    return c;
}

矩阵快速幂

matrix fast_pow(matrix A, int n) {   //return A^n%mod
    matrix B;
    B.init();
    for (int i = 0; i < 6; i++) {   //单位矩阵
        B.mat[i][i] = 1;
    }
    while (n) {
        if (n & 1) {
            B = mul(B, A);
        }
        A = mul(A, A);
        n >>= 1;
    }
    return B;
}

例题HDU - 2802

本题就是根据递推公式,求其中的某一项即F(n),那么可以用矩阵快速幂来解决此题(本题还可以打表找出规律求解,这里只介绍矩阵快速幂方法^^)

在写之前先求出转移矩阵是此题的关键

递推式化简:

\(F(N)=F(N−2)+N^3−(N−1)^3\)

展开:

\(N^3−(N−1)^3=3*N^2−3*N+16\)

所以:

\(F(N)=F(N−2)+3*N^2−3*N+1\)

状态向量构造:

我们需要同时存$ F(N)\(,以及和 N 相关的多项式项(\)N_2\(, N, 常数)。 因为转移里有\) F(N−2)$,所以状态至少要带上 \(F(N)\),\(F(N−1)\)

定义:

\[S(n) = \begin{bmatrix} F(n) \\ F(n-1) \\ n^2 \\ n \\ 1 \end{bmatrix}\]

转移矩阵:

根据递推式:

\(F(N)=F(N−2)+3*N^2−3*N+1\)

所以:

  • \(F(N)\) 依赖 \(F(N−2)\),而$ F(N−2) $就是上一步向量里的第二维。

  • 多项式部分直接线性组合 \(N_2\),\(N\),1。

然后要写出矩阵,把 \(S(N)\) 表达为 \(M⋅S(N−1)\)

完整代码

#include <iostream>
#include <cstring>
#define int long long
using namespace std;

const int mod = 2009;

struct matrix {
    int mat[6][6];
    void init() {
        memset(mat, 0, sizeof(mat));
    }
};

matrix mul(matrix a, matrix b) {   //return a*b
    matrix c;
    c.init();
    for (int i = 0; i < 6; i++) {
        for (int j = 0; j < 6; j++) {
            for (int k = 0; k < 6; k++) {
                c.mat[i][j] += ((a.mat[i][k] % mod) * (b.mat[k][j] % mod)) % mod;
                c.mat[i][j] %= mod;
            }
        }
    }
    return c;
}

matrix fast_pow(matrix A, int n) {   //return A^n%mod
    matrix B;
    B.init();
    for (int i = 0; i < 6; i++) {   //单位矩阵
        B.mat[i][i] = 1;
    }
    while (n) {
        if (n & 1) {
            B = mul(B, A);
        }
        A = mul(A, A);
        n >>= 1;
    }
    return B;
}

signed main() {
    int N;
    while (cin >> N && N) {
        if (N == 1) {
            cout << 1 % mod << "\n";
            continue;
        }
        if (N == 2) {
            cout << 7 % mod << "\n";
            continue;
        }

        // 状态向量: [F(n), F(n-1), n^2, n, 1, dummy]
        // 6维里最后一个可以闲置
        // 我们要求 F(N),利用矩阵快速幂推到目标

        // 转移矩阵 M: S(k) -> S(k+1)
        matrix M;
        M.init();

        // F(k+1) = F(k-1) + 3(k+1)^2 - 3(k+1) + 1
        // => 依赖 F(k-1) 以及 (k+1)^2, (k+1), 1
        // 这里直接手写对应项
        // [F(k+1)] = [0 1 3 -3 1 0] * S(k)
        // [F(k)]   = [1 0 0 0 0 0] * S(k)
        // [ (k+1)^2 ] = 转移自 N^2, N, 1
        // [ (k+1) ]   = ...
        // [ 1 ]       = [0 0 0 0 1 0]

        M.mat[0][1] = 1;  // F(k-1)
        M.mat[0][2] = 3;  // +3*N^2
        M.mat[0][3] = 3; // -3*N
        M.mat[0][4] = 1;  // +1

        M.mat[1][0] = 1;  // F(k)

        // (k+1)^2 = N^2 + 2N + 1
        M.mat[2][2] = 1;
        M.mat[2][3] = 2;
        M.mat[2][4] = 1;

        // (k+1) = N + 1
        M.mat[3][3] = 1;
        M.mat[3][4] = 1;

        M.mat[4][4] = 1; // 常数项保持 1

        // 取模修正负数
        for (int i = 0; i < 6; i++) {
            for (int j = 0; j < 6; j++) {
                M.mat[i][j] = (M.mat[i][j] % mod + mod) % mod;
            }
        }

        // 初始状态 S(2)
        int S[6] = {7, 1, 4, 2, 1, 0}; // [F(2)=7, F(1)=1, 2^2=4, 2, 1, 0]

        // 快速幂:从 S(2) 推到 S(N)
        matrix P = fast_pow(M, N - 2);

        // 结果向量 = P * S
        int ans[6] = {0};
        for (int i = 0; i < 6; i++) {
            for (int j = 0; j < 6; j++) {
                ans[i] = (ans[i] + P.mat[i][j] * S[j]) % mod;
            }
        }

        cout << ans[0] % mod << "\n"; // F(N)
    }
    return 0;
}
posted @ 2025-09-21 17:56  芝士青瓜不拿铁  阅读(42)  评论(0)    收藏  举报