矩阵快速幂
一、前置知识
(1)快速幂
定义
快速幂,二进制取幂(Binary Exponentiation,也称平方法),是一个在 \(O(log n)\) 的时间内计算 \(a^n\) 的小技巧,而暴力的计算需要 \(O(n)\) 的时间。
这个技巧也常常用在非计算的场景,因为它可以应用在任何具有结合律的运算中。其中显然的是它可以应用于模意义下取幂、矩阵幂等运算,我们接下来会讨论。
解释
计算 \(a\) 的 \(n\) 次方表示将 \(n\) 个 \(a\) 乘在一起:\(a^n = \underbrace{a \times a \times \cdots \times a}_{n \text{ 个 } a}\)。然而当 \(a,n\) 太大的时候,这种方法就不太适用了。不过我们知道:\(a^{b+c} = a^b \cdot a^c\);\(a^{2b} = a^b \cdot a^b = (a^b)^2\)。二进制取幂的想法是,我们将取幂的任务按照指数的 二进制表示 来分割成更小的任务。
过程
首先我们将 \(n\) 表示为 2 进制,举一个例子:
因为 \(n\) 有 \(\lfloor \log_2 n \rfloor + 1\) 个二进制位,因此当我们知道了 \(a^1, a^2, a^4, a^8, \dots, a^{2^{\lfloor \log_2 n \rfloor}}\) 后,我们只用计算 \(O(log n)\) 次乘法就可以计算出 \(a^n\)。
于是我们只需要知道一个快速的方法来计算上述的 \(3\) 的 \(2^k\) 次幂的序列。这个问题很简单,因为序列中(除第一个)任意一个元素就是其前一个元素的平方。举一个例子:
因此为了计算 (3^{13}),我们只需要将对应二进制位为 1 的整系数幂乘起来就行了:
参考代码
点击查看代码
int qmi(int a, int b) {
int res = 1; // 初始化结果为 1
while (b) { // 当指数 b 不为 0 时继续循环
if (b & 1) // 判断当前 b 的二进制最低位是否为 1(等同于 b % 2 == 1)
res = res * a; // 若为 1,则将当前的 a 乘入结果中
a = a * a; // 每次循环让 a 平方,依次得到 a^1, a^2, a^4, a^8...
b >>= 1; // 将 b 右移一位(等同于 b /= 2),去掉二进制的最低位
}
return res; // 返回最终结果
}
参考例题
(2)矩阵乘法
定义
由 \(n \times m\) 个数 \(a_{i,j}\) 排成的 \(n\) 行 \(m\) 列的数表称为 \(n\) 行 \(m\) 列的矩阵,简称 \(n \times m\) 矩阵。记作:
矩阵的乘法满足以下运算律:
- 结合律:\((AB)C = A(BC)\)
- 左分配律:\((A + B)C = AC + BC\)
- 右分配律:\(C(A + B) = CA + CB\)
- 矩阵乘法不满足交换律,即 $A\times B $ != $ B\times A$
过程
矩阵相乘只有在第一个矩阵的列数和第二个矩阵的行数相同时才有意义。
设 \(A\) 为 \(P \times M\) 的矩阵,\(B\) 为 \(M \times Q\) 的矩阵,设矩阵 \(C\) 为矩阵 \(A\) 与 \(B\) 的乘积,
其中矩阵 \(C\) 中的第 \(i\) 行第 \(j\) 列元素可以表示为:
在矩阵乘法中,结果 \(C\) 矩阵的第 \(i\) 行第 \(j\) 列的数,就是由矩阵 \(A\) 第 \(i\) 行 \(M\) 个数与矩阵 \(B\) 第 \(j\) 列 \(M\) 个数分别 相乘再相加 得到的。口诀为 左行右列。
参考代码
点击查看代码
const int p=100,q=100,m=100; // 注意创建数组时不能用变量制定大小,需要常量
int a[p][m],b[m][q],c[p][q];
void mul(){
for(int i=1;i<=p;i++){ // 第一个矩阵的行数为 p
for(int j=1;j<=q;j++){ // 第二个矩阵的列数为 q
for(int k=1;k<=m;k++){ // 第一个矩阵的列数和第二个矩阵的行数为 m
c[i][j]+=a[i][k]*b[k][j]; // 依据矩阵乘法的原理展开
}
}
}
}
参考例题
二、矩阵快速幂
原理
快速幂的技巧适用于所有具有结合律的运算,而矩阵乘法就具有结合律,所以当然可以将快速幂推广到矩阵乘法。
方法
方法也很简单,就是将原来快速幂的乘法替换成矩阵乘法,同时要把原来的res初始值做一下改变,变为单位矩阵。
单位矩阵(identity matrix)指的是在矩阵的乘法中,一种如同数的乘法中1的作用的特殊方阵。根据单位矩阵的特点,任何矩阵与单位矩阵相乘都等于本身。
特征是:从左上角到右下角的对角线(称为主对角线)上的元素均为1。除此以外全都为0。
参考代码
点击查看代码
const int n=100,m=100; // 注意创建数组时不能用变量制定大小,需要常量
// 矩阵 a 与矩阵 b 相乘,最后将结果记录在 c 中
void mul(int a[n][m],int b[m][m],int res[m][m]){
int tmp[m][m]={0}; // 使用中间数组,防止改变 a 或 b
for(int i=1;i<=n;i++){
for(int j=1;j<=m;j++){
for(int k=1;k<=m;k++){
tmp[i][j]+=a[i][k]*b[k][j];
}
}
}
memcpy(res,tmp,sizeof tmp);
}
void qmi(int a[n][m],int k,int res[n][m]){ // 求矩阵的 k 次方
// 构建单位矩阵
int tmp[m][m]={0};
for(int i=1;i<=m;i++) tmp[i][i]=1;
while(k){
if(k&1) mul(tmp,a,tmp);
mul(a,a,a);
k>>=1;
}
memcpy(res,tmp,sizeof tmp);
}
参考例题
三、应用
矩阵加速递推
斐波那契数列(Fibonacci Sequence)大家应该都非常的熟悉了。在斐波那契数列当中,\(F_1 = F_2 = 1\),\(F_i = F_{i - 1} + F_{i - 2} (i \geq 3)\)。
如果有一道题目让你求斐波那契数列第 \(n\) 项的值,最简单的方法莫过于直接递推了。但是如果 \(n\) 的范围达到了 \(10^{18}\) 级别,递推就不行了,稳 TLE。考虑矩阵加速递推。
设 \(Fib(n)\) 表示一个 \(1 \times 2\) 的矩阵 \(\begin{bmatrix} F_n & F_{n - 1} \end{bmatrix}\)。我们希望根据 \(Fib(n - 1) = \begin{bmatrix} F_{n - 1} & F_{n - 2} \end{bmatrix}\) 推出 \(Fib(n)\)。
试推导一个矩阵 \(\text{base}\),使 \(Fib(n - 1) \times \text{base} = Fib(n)\),即 \(\begin{bmatrix} F_{n - 1} & F_{n - 2} \end{bmatrix} \times \text{base} = \begin{bmatrix} F_n & F_{n - 1} \end{bmatrix}\)。
怎么推呢?因为 \(F_n = F_{n - 1} + F_{n - 2}\),所以 \(\text{base}\) 矩阵第一列应该是 \(\begin{bmatrix} 1 \\ 1 \end{bmatrix}\),这样在进行矩阵乘法运算的时候才能令 \(F_{n - 1}\) 与 \(F_{n - 2}\) 相加,从而得出 \(F_n\)。同理,为了得出 \(F_{n - 1}\),矩阵 \(\text{base}\) 的第二列应该为 \(\begin{bmatrix} 1 \\ 0 \end{bmatrix}\)。
综上所述:\(\text{base} = \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}\),原式化为 \(\begin{bmatrix} F_{n - 1} & F_{n - 2} \end{bmatrix} \times \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix} = \begin{bmatrix} F_n & F_{n - 1} \end{bmatrix}\)
转化为代码,应该怎么求呢?
定义初始矩阵 \(\text{ans} = \begin{bmatrix} F_2 & F_1 \end{bmatrix} = \begin{bmatrix} 1 & 1 \end{bmatrix}\),\(\text{base} = \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}\)。那么,\(F_n\) 就等于 \(\text{ans} \times \text{base}^{n - 2}\) 这个矩阵的第一行第一列元素,也就是 \(\begin{bmatrix} 1 & 1 \end{bmatrix} \times \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}^{n - 2}\) 的第一行第一列元素。
注意,矩阵乘法不满足交换律,所以一定不能写成 \(\begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}^{n - 2} \times \begin{bmatrix} 1 & 1 \end{bmatrix}\) 的第一行第一列元素。另外,对于 \(n \leq 2\) 的情况,直接输出 \(1\) 即可,不需要执行矩阵快速幂。
为什么要乘上 \(\text{base}\) 矩阵的 \(n - 2\) 次方而不是 \(n\) 次方呢?因为 \(F_1, F_2\) 是不需要进行矩阵乘法就能求的。也就是说,如果只进行一次乘法,就已经求出 \(F_3\) 了。如果还不是很理解为什么幂是 \(n - 2\),建议手算一下。
参考代码
点击查看代码
// 矩阵 a 与矩阵 b 相乘,最后将结果记录在 c 中
void mul(int a[][3],int b[][3],int c[][3]){
int tmp[3][3]={0}; // 使用中间数组,防止改变 a 或 b
for(int i=1;i<=2;i++){
for(int j=1;j<=2;j++){
for(int k=1;k<=2;k++){
tmp[i][j]+=a[i][k]*b[k][j];
}
}
}
memcpy(c,tmp,sizeof tmp);
}
void qmi(int a[][3],int k,int res[][3]){ // 求矩阵的 k 次方
// 构建单位矩阵
int tmp[3][3]={0};
for(int i=1;i<=2;i++) tmp[i][i]=1;
while(k){
if(k&1) mul(tmp,a,tmp);
mul(a,a,a);
k>>=1;
}
memcpy(res,tmp,sizeof tmp);
}
void solve(){
int n;
cin>>n;
if(n<=2) cout<<1<<endl;
else{
int ans[2][3]={0},base[3][3]={0};
ans[1][1]=a[1][2]=1;
base[1][1]=base[1][2]=base[2][1]=1;
qmi(base,n-2,base);
mul(ans,base,ans);
cout<<ans[1][1]<<endl;
}
}
参考例题
以下题目都是利用矩阵加速递推,本质上是一样的。只是矩阵表示不一样而已。
斐波那契数列
矩阵表示
参考代码
点击查看代码
#include <bits/stdc++.h>
#define int long long
#define pb push_back
#define fi first
#define se second
#define endl "\n"
using namespace std;
const int N=1e6+10,M=2010,mod=1e9+7,INF=0x3f3f3f3f;
// 矩阵 a 与矩阵 b 相乘,最后将结果记录在 c 中
void mul(int a[][3],int b[][3],int c[][3]){
int tmp[3][3]={0}; // 使用中间数组,防止改变 a 或 b
for(int i=1;i<=2;i++){
for(int j=1;j<=2;j++){
for(int k=1;k<=2;k++){
tmp[i][j]+=a[i][k]*b[k][j];
tmp[i][j]%=mod;
}
}
}
memcpy(c,tmp,sizeof tmp);
}
void qmi(int a[][3],int k,int res[][3]){ // 求矩阵的 k 次方
// 构建单位矩阵
int tmp[3][3]={0};
for(int i=1;i<=2;i++) tmp[i][i]=1;
while(k){
if(k&1) mul(tmp,a,tmp);
mul(a,a,a);
k>>=1;
}
memcpy(res,tmp,sizeof tmp);
}
void solve(){
int n;
cin>>n;
if(n<=2) cout<<1<<endl;
else{
int a[2][3]={0},base[3][3]={0};
a[1][1]=a[1][2]=1;
base[1][1]=base[1][2]=base[2][1]=1;
qmi(base,n-2,base);
mul(a,base,a);
cout<<a[1][1]<<endl;
}
}
signed main(){
ios::sync_with_stdio(0);cin.tie(0),cout.tie(0);
int _=1;
// cin>>_;
while(_--) solve();
return 0;
}
矩阵表示
参考代码
点击查看代码
#include <bits/stdc++.h>
#define int long long
#define pb push_back
#define fi first
#define se second
#define endl "\n"
using namespace std;
const int N=1e6+10,M=2010,mod=1e9+7,INF=0x3f3f3f3f;
// 矩阵 a 与矩阵 b 相乘,最后将结果记录在 c 中
void mul(int a[][4],int b[][4],int c[][4]){
int tmp[4][4]={0}; // 使用中间数组,防止改变 a 或 b
for(int i=1;i<=3;i++){
for(int j=1;j<=3;j++){
for(int k=1;k<=3;k++){
tmp[i][j]+=a[i][k]*b[k][j];
tmp[i][j]%=mod;
}
}
}
memcpy(c,tmp,sizeof tmp);
}
void qmi(int a[][4],int k,int res[][4]){ // 求矩阵的 k 次方
// 构建单位矩阵
int tmp[4][4]={0};
for(int i=1;i<=2;i++) tmp[i][i]=1;
while(k){
if(k&1) mul(tmp,a,tmp);
mul(a,a,a);
k>>=1;
}
memcpy(res,tmp,sizeof tmp);
}
void solve(){
int n;
cin>>n;
if(n<=3) cout<<1<<endl;
else{
int a[2][4]={0},base[4][4]={0};
a[1][1]=a[1][2]=a[1][3]=1;
base[1][1]=base[1][2]=base[2][3]=base[3][1]=1;
qmi(base,n-2,base);
mul(a,base,a);
cout<<a[1][1]<<endl;
}
}
signed main(){
ios::sync_with_stdio(0);cin.tie(0),cout.tie(0);
int _=1;
cin>>_;
while(_--) solve();
return 0;
}
矩阵表示
参考代码
点击查看代码
#include <bits/stdc++.h>
#define int long long
#define pb push_back
#define fi first
#define se second
#define endl "\n"
using namespace std;
const int N=1e6+10,M=110,mod=1e9+7,INF=0x3f3f3f3f;
int n,m;
int a[M][M],base[M][M];
// 矩阵 a 与矩阵 b 相乘,最后将结果记录在 c 中
void mul(int a[][M],int b[][M],int c[][M]){
int tmp[M][M]={0}; // 使用中间数组,防止改变 a 或 b
for(int i=1;i<=m;i++){
for(int j=1;j<=m;j++){
for(int k=1;k<=m;k++){
tmp[i][j]+=a[i][k]*b[k][j];
tmp[i][j]%=mod;
}
}
}
memcpy(c,tmp,sizeof tmp);
}
void qmi(int a[][M],int k,int res[][M]){ // 求矩阵的 k 次方
// 构建单位矩阵
int tmp[M][M]={0};
for(int i=1;i<=m;i++) tmp[i][i]=1;
while(k){
if(k&1) mul(tmp,a,tmp);
mul(a,a,a);
k>>=1;
}
memcpy(res,tmp,sizeof tmp);
}
void solve(){
cin>>m>>n;
for(int i=1;i<=m;i++) cin>>a[1][i];
for(int i=1;i<=m;i++){
for(int j=i;j<=m;j++){
base[i][j]=1;
}
}
qmi(base,n,base);
mul(a,base,a);
for(int i=1;i<=m;i++) cout<<a[1][i]<<" ";
cout<<endl;
}
signed main(){
ios::sync_with_stdio(0);cin.tie(0),cout.tie(0);
int _=1;
// cin>>_;
while(_--) solve();
return 0;
}
矩阵表示
参考代码
点击查看代码
#include <bits/stdc++.h>
#define int long long
#define pb push_back
#define fi first
#define se second
#define endl "\n"
using namespace std;
const int N=1e6+10,M=110,mod=1e9+7,INF=0x3f3f3f3f;
int n,m;
int a[M][M],base[M][M];
// 矩阵 a 与矩阵 b 相乘,最后将结果记录在 c 中
void mul(int a[][M],int b[][M],int c[][M]){
int tmp[M][M]={0}; // 使用中间数组,防止改变 a 或 b
for(int i=1;i<=2;i++){
for(int j=1;j<=2;j++){
for(int k=1;k<=2;k++){
tmp[i][j]+=a[i][k]*b[k][j];
tmp[i][j]%=m;
}
}
}
memcpy(c,tmp,sizeof tmp);
}
void qmi(int a[][M],int k,int res[][M]){ // 求矩阵的 k 次方
// 构建单位矩阵
int tmp[M][M]={0};
for(int i=1;i<=2;i++) tmp[i][i]=1;
while(k){
if(k&1) mul(tmp,a,tmp);
mul(a,a,a);
k>>=1;
}
memcpy(res,tmp,sizeof tmp);
}
void solve(){
int p,q,a1,a2;
cin>>p>>q>>a1>>a2>>n>>m;
if(n==1){
cout<<a1%m<<endl;
return;
}else if(n==2){
cout<<a2%m<<endl;
return;
}
a[1][1]=a2,a[1][2]=a1;
base[1][1]=p,base[1][2]=1,base[2][1]=q;
qmi(base,n-2,base);
mul(a,base,a);
cout<<a[1][1]<<endl;
}
signed main(){
ios::sync_with_stdio(0);cin.tie(0),cout.tie(0);
int _=1;
// cin>>_;
while(_--) solve();
return 0;
}
引用申明
本文章引用部分OI Wiki的相关内容,若有侵权,可以联系删除。

浙公网安备 33010602011771号