【小白学算法】矩阵快速幂超详细解析+例题[HDU - 2802]
用于高效解决线性递推问题
前言
在算法竞赛和实际编程中,我们经常遇到需要计算矩阵的高次幂的问题。如果直接用朴素的矩阵乘法来计算,时间复杂度会达到O(n³ × k),其中n是矩阵的维度,k是幂次。当k非常大时(比如\(10^9\)),这样的时间复杂度是无法接受的。矩阵快速幂就是为了解决这个问题而诞生的算法。
什么是矩阵快速幂?
矩阵快速幂就是快速幂算法再矩阵运算中的应用。它基于这样的思想:
\(A^8\) = \((( A^2 )^2)^2\)
\(A^9\)=\(A × A^8\)
通过将指数进行二进制分解,我们可以在O(log k)次矩阵乘法内计算出\(A^k\)。
快速幂基本原理
以计算a^n为例:
- 
如果n是偶数:\(a^n\) = \((a^{n/2})^2\)
 - 
如果n是奇数:\(a^n\) = a ×$ a^{n-1}$
 
这个过程可以用递归或迭代实现。
矩阵快速幂的实现步骤
矩阵定义
struct matrix {
    int mat[6][6];
    void init() {
        memset(mat, 0, sizeof(mat));
    }
};
矩阵乘法
由于幂的过程中,我们要实现一个矩阵的乘法,所以首先要将矩阵乘法写出;
matrix mul(matrix a, matrix b) {   //return a*b
    matrix c;
    c.init();
    for (int i = 0; i < 6; i++) {
        for (int j = 0; j < 6; j++) {
            for (int k = 0; k < 6; k++) {
                c.mat[i][j] += ((a.mat[i][k] % mod) * (b.mat[k][j] % mod)) % mod;
                c.mat[i][j] %= mod;
            }
        }
    }
    return c;
}
矩阵快速幂
matrix fast_pow(matrix A, int n) {   //return A^n%mod
    matrix B;
    B.init();
    for (int i = 0; i < 6; i++) {   //单位矩阵
        B.mat[i][i] = 1;
    }
    while (n) {
        if (n & 1) {
            B = mul(B, A);
        }
        A = mul(A, A);
        n >>= 1;
    }
    return B;
}
例题HDU - 2802
本题就是根据递推公式,求其中的某一项即F(n),那么可以用矩阵快速幂来解决此题(本题还可以打表找出规律求解,这里只介绍矩阵快速幂方法^^)
在写之前先求出转移矩阵是此题的关键
递推式化简:
\(F(N)=F(N−2)+N^3−(N−1)^3\)
展开:
\(N^3−(N−1)^3=3*N^2−3*N+16\)
所以:
\(F(N)=F(N−2)+3*N^2−3*N+1\)
状态向量构造:
我们需要同时存$ F(N)\(,以及和 N 相关的多项式项(\)N_2\(, N, 常数)。 因为转移里有\) F(N−2)$,所以状态至少要带上 \(F(N)\),\(F(N−1)\)。
定义:
转移矩阵:
根据递推式:
\(F(N)=F(N−2)+3*N^2−3*N+1\)
所以:
- 
\(F(N)\) 依赖 \(F(N−2)\),而$ F(N−2) $就是上一步向量里的第二维。
 - 
多项式部分直接线性组合 \(N_2\),\(N\),1。
 
然后要写出矩阵,把 \(S(N)\) 表达为 \(M⋅S(N−1)\)。
完整代码
#include <iostream>
#include <cstring>
#define int long long
using namespace std;
const int mod = 2009;
struct matrix {
    int mat[6][6];
    void init() {
        memset(mat, 0, sizeof(mat));
    }
};
matrix mul(matrix a, matrix b) {   //return a*b
    matrix c;
    c.init();
    for (int i = 0; i < 6; i++) {
        for (int j = 0; j < 6; j++) {
            for (int k = 0; k < 6; k++) {
                c.mat[i][j] += ((a.mat[i][k] % mod) * (b.mat[k][j] % mod)) % mod;
                c.mat[i][j] %= mod;
            }
        }
    }
    return c;
}
matrix fast_pow(matrix A, int n) {   //return A^n%mod
    matrix B;
    B.init();
    for (int i = 0; i < 6; i++) {   //单位矩阵
        B.mat[i][i] = 1;
    }
    while (n) {
        if (n & 1) {
            B = mul(B, A);
        }
        A = mul(A, A);
        n >>= 1;
    }
    return B;
}
signed main() {
    int N;
    while (cin >> N && N) {
        if (N == 1) {
            cout << 1 % mod << "\n";
            continue;
        }
        if (N == 2) {
            cout << 7 % mod << "\n";
            continue;
        }
        // 状态向量: [F(n), F(n-1), n^2, n, 1, dummy]
        // 6维里最后一个可以闲置
        // 我们要求 F(N),利用矩阵快速幂推到目标
        // 转移矩阵 M: S(k) -> S(k+1)
        matrix M;
        M.init();
        // F(k+1) = F(k-1) + 3(k+1)^2 - 3(k+1) + 1
        // => 依赖 F(k-1) 以及 (k+1)^2, (k+1), 1
        // 这里直接手写对应项
        // [F(k+1)] = [0 1 3 -3 1 0] * S(k)
        // [F(k)]   = [1 0 0 0 0 0] * S(k)
        // [ (k+1)^2 ] = 转移自 N^2, N, 1
        // [ (k+1) ]   = ...
        // [ 1 ]       = [0 0 0 0 1 0]
        M.mat[0][1] = 1;  // F(k-1)
        M.mat[0][2] = 3;  // +3*N^2
        M.mat[0][3] = 3; // -3*N
        M.mat[0][4] = 1;  // +1
        M.mat[1][0] = 1;  // F(k)
        // (k+1)^2 = N^2 + 2N + 1
        M.mat[2][2] = 1;
        M.mat[2][3] = 2;
        M.mat[2][4] = 1;
        // (k+1) = N + 1
        M.mat[3][3] = 1;
        M.mat[3][4] = 1;
        M.mat[4][4] = 1; // 常数项保持 1
        // 取模修正负数
        for (int i = 0; i < 6; i++) {
            for (int j = 0; j < 6; j++) {
                M.mat[i][j] = (M.mat[i][j] % mod + mod) % mod;
            }
        }
        // 初始状态 S(2)
        int S[6] = {7, 1, 4, 2, 1, 0}; // [F(2)=7, F(1)=1, 2^2=4, 2, 1, 0]
        // 快速幂:从 S(2) 推到 S(N)
        matrix P = fast_pow(M, N - 2);
        // 结果向量 = P * S
        int ans[6] = {0};
        for (int i = 0; i < 6; i++) {
            for (int j = 0; j < 6; j++) {
                ans[i] = (ans[i] + P.mat[i][j] * S[j]) % mod;
            }
        }
        cout << ans[0] % mod << "\n"; // F(N)
    }
    return 0;
}

                
            
浙公网安备 33010602011771号