浅谈矩阵快速幂
\(update\ on\ 2022.9.28\) 发现自己以前写的一堆垃圾,这里做一些说明:
这篇博客是19年几乎算是初学算法时写的,当时非常 naive,理解不够透彻。
矩阵只是计算的一种方式,而计算矩阵的方法并不只是快速幂,还可以用数据结构维护等。
常用的领域是 \(dp\),以及臭名昭著的ddp,不过由于一些矩阵优化线性 \(dp\) 的转移是相同的,因此用快速幂可以迅速计算,显著特点是巨大的转移次数。当然如果转移式与其他信息有关就得上线段树之类的了还有你甚至可以矩阵套矩阵。
当然由于懒得再写一篇了所以这篇博客只讲矩阵与快速幂的结合优化 \(dp\) 以及一个与图论的联系。
在谈矩阵快速幂前,我们先说说什么是矩阵。
关于矩阵
顾名思义,矩阵就是把一些数填入一个矩形形式如下:
所以呢?
所以我们再来介绍一下矩阵的乘法。
矩阵乘法
矩阵乘法就是将两个矩阵相乘,可是我们要怎么运算呢?百度搜了搜从书上得知,矩阵乘法有如下定义:
设 \(A,B\) 是两个矩阵,令 \(C=A\times B\),那么:
-
\(A\) 的列数必须和 \(B\) 的行数相等;
-
设 \(A\) 是 \(n\times r\) 的矩阵,\(B\) 是 \(r\times m\) 的矩阵;那么 \(C\) 是一个 \(n×m\) 的矩阵;
-
\(C_{i,j}=A_{i,1}\times B_{1,j}+A_{i,2}\times B_{2,j}+A_{i,3}\times B_{3,j}+…+A_{i,n}\times B_{n,j}=\sum\limits^{n}_{k=1}A_{i,k}\times B_{k,j}\)
可以简单的理解为 \(C_{i,j}=\) \(A\)的第 \(i\) 行与 \(B\) 的第 \(j\) 列依次相乘的和。
例如:
(其实你也可以自己定义矩阵乘法,比如 \(\min+\))
由此我们可以写出代码:
struct matrix{
int n,m,g[N][N];//n为行数,m为列数,g是矩阵,N为矩阵大小
matrix operator*(const matrix&b)const{
matrix c;c.n=a.n;c.m=b.m;
memset(c.g,0,sizeof(c.g));//初始化
for(int i=1;i<=n;i++)
for(int k=1;k<=m;k++)
for(int j=1;j<=b.m;j++)
c.g[i][j]=c.g[i][j]+g[i][k]*b.g[k][j];
return c;
}
}
matrix multiply(matrix a,matrix b){
matrix c;c.n=a.n;c.m=b.m;
memset(c.g,0,sizeof(c.g));
for(int i=1;i<=a.n;i++)
for(int k=1;k<=a.m;k++)
for(int j=1;j<=b.m;j++)
c.g[i][j]=c.g[i][j]+a.g[i][k]*b.g[k][j];
return c;
}
那么矩阵乘法满不满足运算律呢?
有的,矩阵乘法满足结合律,证明如下:
设 \(n\) 阶矩阵为:
\(A=(a_{i,j})\)
\(B=(b_{i,j})\)
\(C=(c_{i,j})\)
\(A\times{B}=(d_{i,j})\)
\(B\times{C}=(e_{i,j})\)
\((A\times{B})\times{C}=(f_{i,j})\)
$A\times{(B\times{C})}=(g_{i,j}) $
由矩阵的乘法得
\(d_{i,j}=a_{i,1}\times{b_{1,j}}+a_{i,2}\times{b_{2,j}}+...+a_{i,n}\times{b_{n,j}}\)
\(e_{i,j}=b_{i,1}\times{c_{1,j}}+b_{i,2}\times{c_{2,j}}+...+b_{i,n}\times{c_{n,j}}\)
\(f_{i,j}=d_{i,1}\times{c_{1,j}}+d_{i,2}\times{c_{2,j}}+...+d_{i,n}\times{c_{n,j}}\)
\(g_{i,j}=a_{i,1}\times{e_{1,j}}+a_{i,2}\times{e_{2,j}}+...+a_{i,n}\times{e_{n,j}}\)
故对任意\(i,j=1,2,...,n\)有,
$f_{i,j}=d_{i,1}\times{c_{1,j}}+d_{i,2}\times{c_{2,j}}+...+d_{i,n}\times{c_{n,j}} $
\(=(a_{i,1}\times{b_{1,1}}+a_{i,2}\times{b_{2,1}}+...+a_{i,n}\times{b_{n,1}})\times{c_{1,j}}+(a_{i,1}\times{b_{1,1}}+a_{i,2}\times{b_{2,1}}+...+a_{i,n}\times{b_{n,1}})\times{c_{2,j}}\)
$+...+(a_{i,1}\times{b_{1,n}}+a_{i,2}\times{b_{2,n}+...+a_{i,n}}\times{b_{n,n}})\times{c_{n,j}} $
\(=a_{i,1}\times{(b_{1,1}\times{c_{1,j}}+b_{1,2}\times{c_{2,j}}+...+b_{1,n}\times{c_{n,j}})}+a_{i,2}\times{(b_{2,1}\times{c_{1,j}}+b_{2,2}\times{c_{2,j}}+...+b_{2,n}\times{c_{n,j}})}\)
$+...+a_{i,n}\times{(b_{n,1}\times{c_{1,j}}+b_{n,2}\times{c_{2,j}}+...+b_{n,n}\times{c_{n,j}})} $
$=a_{i,1}\times{e_{1,j}}+a_{i,2}\times{e_{2,j}}+...+a_{i,n}\times{e_{n,j}}=g_{i,j} $
故\((A\times{B})\times{C}=A\times{(B\times{C})}.\)
但显然是不满足交换律的。
证明来自互联网:https://www.cnblogs.com/Jakson/articles/4557558.html
那这又有什么用呢?让我们步入正题。
矩阵快速幂
对于一个\(n×n\)的矩阵,我们称其为方阵,方阵可以进行幂运算,定义为\(C=A^n\)。显然,对于任意一个矩阵,只有它是方阵才能进行幂运算(不然n和m都不相等啊)。
又因为矩阵乘法满足结合律,所以我们可以用快速幂来进行矩阵的幂运算。
矩阵乘法的好处在于它能将有用的状态储存在一个矩阵中,并通过一次乘法得到一次 \(dp\) 的值( \(dp\) 状态转移方程必须是线性的,虽然有些题目中状态转移方程不是线性的,但是我们可以通过一定的转化将其转化为线性方程)
一般的形式有行向量左乘矩阵和列向量右乘矩阵,例如
实现时一般在向量中补 \(0\) 变为方阵。最后要乘上一个初始矩阵,由于没有交换律,所以要注意乘的顺序。
为什么必须线性?因为在快速幂过程中转移是批量计算的,无法中途得到之前的某一个值。
结合快速幂:
matrix operator^(int k){
matrix ans,x=*this;
memset(ans.g,0,sizeof(ans.g));
for(int i=1;i<=N;i++)ans.g[i][i]=1;
for(;k;k>>=1,x=x*x)if(k&1)ans=ans*x;
return ans;
}
matrix power(matrix a,int b){
matrix ans;
memset(ans.g,0,sizeof(ans.g));
for(int i=1;i<=N;i++)ans.g[i][i]=1;
while(b){
if(b&1)ans=multiply(ans,a);
a=multiply(a,a);
b>>=1;
}
return ans;
}
关于单位矩阵 \(I\):单位矩阵相当于矩阵乘法中的 \(1\),在每次快速幂之前需初始化 \(ans\) 矩阵为 \(I\)。
因为使用快速幂,所以矩阵快速幂的时间复杂度基本为\(O(N^3\times{log(n)})\),其中 \(N\) 为矩阵大小
下面讲解一些例题:
斐波那契数列
矩阵快速幂模板题。
因为
\(f_{i}=1\times{f_{i-1}}+1\times{f_{i-2}}\)
\(f_{i-1}=1\times{f_{i-1}}+0\times{f_{i-2}}\)
所以得到转移矩阵:
转移为:
用列向量
乘转移矩阵 \(n\) 次得到 \(G_{1,1}=f_{n+1},G_{2,1}=f_{n}\)。输出 \(G_{2,1}\) 即可。
代码如下:
#include<bits/stdc++.h>
#define int long long
#define inf 0x3f3f3f3f
using namespace std;
int read(){
int w=0,h=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')h=-h;ch=getchar();}
while(ch>='0'&&ch<='9'){w=w*10+ch-'0';ch=getchar();}
return w*h;
}
const int mod=1e9+7;
int n,q;
struct matrix{
int g[105][105];
matrix operator*(const matrix&b)const{
matrix c;
for(int i=1;i<=2;i++)
for(int j=1;j<=2;j++){
c.g[i][j]=0;
for(int k=1;k<=2;k++){
c.g[i][j]=(c.g[i][j]+g[i][k]*b.g[k][j]%mod)%mod;
}
}
return c;
}
matrix operator^(int k){
matrix ans,x=*this;memset(ans.g,0,sizeof(ans.g));
for(int i=1;i<=2;i++)ans.g[i][i]=1;
for(;k;k>>=1,x=x*x)if(k&1)ans=ans*x;
return ans;
}
}f;
signed main(){
q=read();
f.g[1][1]=1;f.g[1][2]=1;
f.g[2][1]=1;f.g[2][2]=0;
f=f^q;
printf("%lld\n",f.g[2][1]%mod);
return 0;
}
[NOI2012]随机数生成器
矩阵快速幂进阶的运用。题目中给出了状态转移方程,根据转移方程列出转移矩阵:
初始化 \(res\) 矩阵: \(res.G_{1,1}=X_0,res.G_{2,2}=1\);由于我们要推的是 \(X_n\),第一项为 \(X_0\),所以不应该初始化为单位矩阵(这里谬误:初始化的 \(res\) 相当于初始矩阵已经乘了 \(I\))
发现模数 \(m\) 有点大,用龟速乘防止爆 \(long\ long\)(什么,你不知道龟速乘?)
#include<bits/stdc++.h>
#pragma GCC optimize(2)
#define int long long
using namespace std
int read(){
int w=0,h=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')h=-h;ch=getchar();}
while(ch>='0'&&ch<='9'){w=w*10+ch-'0';ch=getchar();}
return w*h;
}
struct node{
int g[5][5];
}f,res,ans;
int m,a,c,x0,n,mod;
void init(){
res.g[1][1]=x0%m;res.g[1][2]=1;
f.g[1][1]=a%m;f.g[1][2]=0;
f.g[2][1]=c%m;f.g[2][2]=1;
}
int cnt(int x,int y){
int s=0;
while(y){
if(y&1)s=(s+x)%m;
x=x*2%m;
y>>=1;
}
return s%m;
}
node multiple(node x,node y){
memset(ans.g,0,sizeof(ans.g));
for(int i=1;i<=2;i++)
for(int j=1;j<=2;j++)
for(int k=1;k<=2;k++)
ans.g[i][j]=(ans.g[i][j]+cnt(x.g[i][k],y.g[k][j]))%m;
return ans;
}
void Fast(int n){
while(n){
if(n&1)res=multiple(res,f);
f=multiple(f,f);
n>>=1;
}
}
signed main(){
m=read();a=read();c=read();x0=read();n=read();mod=read();
init();
Fast(n);
printf("%lld",res.g[1][1]%mod);
return 0;
}
Isaac
这是我刚上初中时教练给我们出的烤试题,然后我当然是不会的,所以就一直放在收藏里,然后某天我点开了这道题,发现居然是矩阵乘法的标签,于是就谔谔开心的做起来,然后我当然还是不会了的,于是就A掉了这题。
看到这题,第一反应当然是图论,但是看到 \(k\) 的要求和数据范围,立刻就懵圈了知道是矩阵乘法,可是图论和矩阵乘法有什么关系呢?\(QwQ\),于是想了想。
图论和矩阵乘法到底有什么关系呢?请看下文分解:
我们知道:\(C_{i,j}=\sum\limits_{k=1}^{n} A_{i,k}×B_{k,j}\)
而在图论的路径计数中,我们怎样计算从 \(i\) 到 \(j\) 的路径条数呢?与多源最短路算法Floyd相似,我们枚举中转点,根据乘法原理和加法原理,我们得出如下代码:
for(int i=1;i<=N;i++)
for(int j=1;j<=N;j++)
for(int k=1;k<=N;k++)
c[i][j]=c[i][j]+c[i][k]*c[k][j];
与矩阵乘法鲸人的相似!!!
那么,方阵的幂又表示什么呢?
当图 \(G\) 的指数为一时,显然就是一步从 \(i\) 到 \(j\) 的方案数,我们考虑\(G^2\)。
\(G^2\) 时,我们枚举了一个中转点,于是原来的路径断成了两部分,而这正是走两步从 \(i\) 到 \(j\) 的方案数!这时,我们就解决了在 \(k\) 时刻恰好到达终点的问题。
但是这题还有一个难点:珂爱的小怪怪。但只要我们抓住这题的关键:数据范围,就能迎刃而解了。
发现珂爱的小怪怪的游走规律不超过4,而1,2,3,4的最小公倍数仅有12,于是我们考虑分情况讨论。我们把每种情况图中怪物将会到达的点标记下来,在计算答案时判断怪物是否在该点并且血量是否足够通过当前边,快速幂前将12种情况乘在一起达到优化效果,注意矩阵乘法没有交换律,所以第0种情况最后乘。由于题目询问最小值,套上二分即可。
(u1s1,以撒挺好玩的)
新鲜出炉的代码:
#include<bits/stdc++.h>
#define int long long
using namespace std;
int read(){
int w=0,h=1;char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')h=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){w=w*10+ch-'0';ch=getchar();}
return w*h;
}
const int mod=10000;
int n,m,St,En,K,NaCly_Fish;
struct node{
int g[55][55];
node operator * (node b){
node c;memset(c.g,0,sizeof(c.g));
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++){
for(int k=1;k<=n;k++){
c.g[i][j]|=g[i][k]&b.g[k][j];//二分只需要判断能否到达,所以采用较快的位运算
}
}
}
return c;
}
node operator ^ (int x){
node ans,t=*this;memset(ans.g,0,sizeof(ans.g));
for(int i=1;i<=n;i++)ans.g[i][i]=1;
for(int i=x;i;i>>=1,t=t*t)if(i&1)ans=ans*t;
return ans;
}
}a[12],ans,res,check[13],mp;
signed main(){
n=read();m=read();St=read();En=read();K=read();
for(int i=1;i<=m;i++){
int u=read(),v=read(),val=read();
mp.g[u][v]=mp.g[v][u]=val;
}
NaCly_Fish=read();//知道为什么是NaCly_Fish吗?因为我做了沼泽愕鱼,沼泽愕鱼里食人鱼数量是NFish……
for(int i=0;i<12;i++)check[i]=mp;
for(int i=1;i<=NaCly_Fish;i++){
int T=read();
for(int j=0;j<T;j++){
int cur=read();
for(int k=j;k<12;k+=T)
for(int l=1;l<=n;l++)
check[k].g[l][cur]=0;//标记在一个循环(轮回)中怪物会到达的点
}
}
int l=0,r=3e9;
while(l+1<r){
int mid=(l+r)>>1;
for(int i=0;i<12;i++)a[i]=check[i];
for(int i=0;i<12;i++){
for(int j=1;j<=n;j++){
for(int k=1;k<=n;k++){
if(a[i].g[j][k])a[i].g[j][k]=a[i].g[j][k]<=mid;//判断在一个循环中第i个单位时间到达的点是否有怪且血量是否足够
}
}
}
memset(res.g,0,sizeof(res.g));
for(int i=1;i<=n;i++)res.g[i][i]=1;
for(int i=1;i<12;i++)res=res*a[i];res=res*a[0];
int t=K/12,step=K%12;//t为12种情况一起乘的次数,step为剩余要乘的情况数
memset(ans.g,0,sizeof(ans.g));
for(int i=1;i<=n;i++)ans.g[i][i]=1;
ans=ans*(res^t);
for(int i=1;i<=step;i++)ans=ans*a[i];
if(ans.g[St][En])r=mid;
else l=mid;
}
if(l>=(1<<31-3))cout<<"'IMP0SSBLE!!";
else cout<<r;
return 0;
}

浙公网安备 33010602011771号