【小白学算法】矩阵快速幂超详细解析+例题[HDU - 2802]
用于高效解决线性递推问题
前言
在算法竞赛和实际编程中,我们经常遇到需要计算矩阵的高次幂的问题。如果直接用朴素的矩阵乘法来计算,时间复杂度会达到O(n³ × k),其中n是矩阵的维度,k是幂次。当k非常大时(比如\(10^9\)),这样的时间复杂度是无法接受的。矩阵快速幂就是为了解决这个问题而诞生的算法。
什么是矩阵快速幂?
矩阵快速幂就是快速幂算法再矩阵运算中的应用。它基于这样的思想:
\(A^8\) = \((( A^2 )^2)^2\)
\(A^9\)=\(A × A^8\)
通过将指数进行二进制分解,我们可以在O(log k)次矩阵乘法内计算出\(A^k\)。
快速幂基本原理
以计算a^n为例:
-
如果n是偶数:\(a^n\) = \((a^{n/2})^2\)
-
如果n是奇数:\(a^n\) = a ×$ a^{n-1}$
这个过程可以用递归或迭代实现。
矩阵快速幂的实现步骤
矩阵定义
struct matrix {
int mat[6][6];
void init() {
memset(mat, 0, sizeof(mat));
}
};
矩阵乘法
由于幂的过程中,我们要实现一个矩阵的乘法,所以首先要将矩阵乘法写出;
matrix mul(matrix a, matrix b) { //return a*b
matrix c;
c.init();
for (int i = 0; i < 6; i++) {
for (int j = 0; j < 6; j++) {
for (int k = 0; k < 6; k++) {
c.mat[i][j] += ((a.mat[i][k] % mod) * (b.mat[k][j] % mod)) % mod;
c.mat[i][j] %= mod;
}
}
}
return c;
}
矩阵快速幂
matrix fast_pow(matrix A, int n) { //return A^n%mod
matrix B;
B.init();
for (int i = 0; i < 6; i++) { //单位矩阵
B.mat[i][i] = 1;
}
while (n) {
if (n & 1) {
B = mul(B, A);
}
A = mul(A, A);
n >>= 1;
}
return B;
}
例题HDU - 2802
本题就是根据递推公式,求其中的某一项即F(n),那么可以用矩阵快速幂来解决此题(本题还可以打表找出规律求解,这里只介绍矩阵快速幂方法^^)
在写之前先求出转移矩阵是此题的关键
递推式化简:
\(F(N)=F(N−2)+N^3−(N−1)^3\)
展开:
\(N^3−(N−1)^3=3*N^2−3*N+16\)
所以:
\(F(N)=F(N−2)+3*N^2−3*N+1\)
状态向量构造:
我们需要同时存$ F(N)\(,以及和 N 相关的多项式项(\)N_2\(, N, 常数)。 因为转移里有\) F(N−2)$,所以状态至少要带上 \(F(N)\),\(F(N−1)\)。
定义:
转移矩阵:
根据递推式:
\(F(N)=F(N−2)+3*N^2−3*N+1\)
所以:
-
\(F(N)\) 依赖 \(F(N−2)\),而$ F(N−2) $就是上一步向量里的第二维。
-
多项式部分直接线性组合 \(N_2\),\(N\),1。
然后要写出矩阵,把 \(S(N)\) 表达为 \(M⋅S(N−1)\)。
完整代码
#include <iostream>
#include <cstring>
#define int long long
using namespace std;
const int mod = 2009;
struct matrix {
int mat[6][6];
void init() {
memset(mat, 0, sizeof(mat));
}
};
matrix mul(matrix a, matrix b) { //return a*b
matrix c;
c.init();
for (int i = 0; i < 6; i++) {
for (int j = 0; j < 6; j++) {
for (int k = 0; k < 6; k++) {
c.mat[i][j] += ((a.mat[i][k] % mod) * (b.mat[k][j] % mod)) % mod;
c.mat[i][j] %= mod;
}
}
}
return c;
}
matrix fast_pow(matrix A, int n) { //return A^n%mod
matrix B;
B.init();
for (int i = 0; i < 6; i++) { //单位矩阵
B.mat[i][i] = 1;
}
while (n) {
if (n & 1) {
B = mul(B, A);
}
A = mul(A, A);
n >>= 1;
}
return B;
}
signed main() {
int N;
while (cin >> N && N) {
if (N == 1) {
cout << 1 % mod << "\n";
continue;
}
if (N == 2) {
cout << 7 % mod << "\n";
continue;
}
// 状态向量: [F(n), F(n-1), n^2, n, 1, dummy]
// 6维里最后一个可以闲置
// 我们要求 F(N),利用矩阵快速幂推到目标
// 转移矩阵 M: S(k) -> S(k+1)
matrix M;
M.init();
// F(k+1) = F(k-1) + 3(k+1)^2 - 3(k+1) + 1
// => 依赖 F(k-1) 以及 (k+1)^2, (k+1), 1
// 这里直接手写对应项
// [F(k+1)] = [0 1 3 -3 1 0] * S(k)
// [F(k)] = [1 0 0 0 0 0] * S(k)
// [ (k+1)^2 ] = 转移自 N^2, N, 1
// [ (k+1) ] = ...
// [ 1 ] = [0 0 0 0 1 0]
M.mat[0][1] = 1; // F(k-1)
M.mat[0][2] = 3; // +3*N^2
M.mat[0][3] = 3; // -3*N
M.mat[0][4] = 1; // +1
M.mat[1][0] = 1; // F(k)
// (k+1)^2 = N^2 + 2N + 1
M.mat[2][2] = 1;
M.mat[2][3] = 2;
M.mat[2][4] = 1;
// (k+1) = N + 1
M.mat[3][3] = 1;
M.mat[3][4] = 1;
M.mat[4][4] = 1; // 常数项保持 1
// 取模修正负数
for (int i = 0; i < 6; i++) {
for (int j = 0; j < 6; j++) {
M.mat[i][j] = (M.mat[i][j] % mod + mod) % mod;
}
}
// 初始状态 S(2)
int S[6] = {7, 1, 4, 2, 1, 0}; // [F(2)=7, F(1)=1, 2^2=4, 2, 1, 0]
// 快速幂:从 S(2) 推到 S(N)
matrix P = fast_pow(M, N - 2);
// 结果向量 = P * S
int ans[6] = {0};
for (int i = 0; i < 6; i++) {
for (int j = 0; j < 6; j++) {
ans[i] = (ans[i] + P.mat[i][j] * S[j]) % mod;
}
}
cout << ans[0] % mod << "\n"; // F(N)
}
return 0;
}

浙公网安备 33010602011771号