[模板]矩阵快速幂(斐波那契数列)
首先回顾矩阵乘法的定义
\[c_{ij}=\sum_{i=1}^{k}a_{ik} \times b_{kj}
\]
显然,对\(F_{n}=F_{n-1}+F_{n-2}\)这样的柿子,我们可以用待定系数法求得递推矩阵:
设\(f_{n}=\begin{bmatrix}
F_{n} & F_{n-1}
\end{bmatrix}(n > 1)\),解\(f_{n}=f_{n-1}
\begin{bmatrix}
a & b \\
c & d
\end{bmatrix}\)
得
\[\begin{bmatrix}
a & b \\
c & d
\end{bmatrix} =
\begin{bmatrix}
1 & 1 \\
1 & 0
\end{bmatrix}
\]
于是
\[f_{n} = f_{2} \times\begin{bmatrix}
1 & 1\\
1 &0
\end{bmatrix}^{n-2}
\]
利用矩阵快速幂即可\(O(\log_{}{n})\)求解\(F_{n}\)
我们定义一个结构体来进行矩阵的表示,初始化,乘法运算等
struct Matrix {
ll a[maxn][maxn];
ll r, c;
Matrix() {
memset(a, 0, sizeof(a));//构造时进行初始化
}
void print() const { // for debug
if (r < 0 || c < 0) return;
for (int i = 1; i <= r; i++) {
for (int j = 1; j <= c; j++) {
printf("%lld ", a[i][j]);
}
puts("");
}
}
Matrix operator * (const Matrix b) {
Matrix res;
if (c != b.r) {
puts("error");
exit(0);
}
res.r = r; res.c = b.c;
//print(); b.print();
for (int i = 1; i <= r; i++)
for (int j = 1; j <= c; j++)
for (int k = 1; k <= c; k++)
res.a[i][j] = (res.a[i][j] + a[i][k] * b.a[k][j]) % mod;
//res.print();
return res;
}
};
然后快速幂自然也是很简单的啦
Matrix matrix_pow(Matrix m, ll n) {
Matrix res;
res.r = res.c = m.r;
for (int i = 1; i <= res.r; i++) res.a[i][i] = 1;//单位矩阵
while (n) {
if (n & 1) res = res * m;
m = m * m;
n >>= 1;
}
return res;
}
附上完整代码
点击查看代码
#include <iostream>
#include <cstring>
#define ll long long
#define mod 1000000007
#define maxn 3
using namespace std;
struct Matrix {
ll a[maxn][maxn];
ll r, c;
Matrix() {
memset(a, 0, sizeof(a));
}
void print() const { // for debug
if (r < 0 || c < 0) return;
for (int i = 1; i <= r; i++) {
for (int j = 1; j <= c; j++) {
printf("%lld ", a[i][j]);
}
puts("");
}
}
Matrix operator * (const Matrix b) {
Matrix res;
if (c != b.r) {
puts("error");
exit(0);
}
res.r = r; res.c = b.c;
//print(); b.print();
for (int i = 1; i <= r; i++)
for (int j = 1; j <= c; j++)
for (int k = 1; k <= c; k++)
res.a[i][j] = (res.a[i][j] + a[i][k] * b.a[k][j]) % mod;
//res.print();
return res;
}
};
Matrix matrix_pow(Matrix m, ll n) {
Matrix res;
res.r = res.c = m.r;
for (int i = 1; i <= res.r; i++) res.a[i][i] = 1;
while (n) {
if (n & 1) res = res * m;
m = m * m;
n >>= 1;
}
return res;
}
ll n;
Matrix f, base, I;
void init() {
f.r = 1; f.c = 2;
f.a[1][1] = f.a[1][2] = 1;
base.r = base.c = 2;
base.a[1][1] = base.a[1][2] = base.a[2][1] = 1;
base.a[2][2] = 0;
}
int main() {
while (cin >> n) {
init();
//(f * matrix_pow(base, n - 2)).print();
if (n > 2) cout << (f * matrix_pow(base, n - 2)).a[1][1] << endl;
else cout << 1 << endl;
}
return 0;
}

浙公网安备 33010602011771号