快速矩阵幂
两矩阵相乘,朴素算法的复杂度是O(N^3)。如果求一次矩阵的M次幂,按朴素的写法就是O(N^3*M)。既然是求幂,不免想到快速幂取模的算法,这里有快速幂取模的介绍,a^b %m 的复杂度可以降到O(logb)。如果矩阵相乘是不是也可以实现O(N^3 * logM)的时间复杂度呢?答案是肯定的。
先定义矩阵数据结构:
struct Mat {
double mat[N][N];
};
O(N^3)实现一次矩阵乘法
Mat operator * (Mat a, Mat b) {
Mat c;
memset(c.mat, 0, sizeof(c.mat));
int i, j, k;
for(k = 0; k < n; ++k) {
for(i = 0; i < n; ++i) {
if(a.mat[i][k] <= 0) continue; //不要小看这里的剪枝,cpu运算乘法的效率并不是想像的那么理想(加法的运算效率高于乘法,比如Strassen矩阵乘法)
for(j = 0; j < n; ++j) {
if(b.mat[k][j] <= 0) continue; //剪枝
c.mat[i][j] += a.mat[i][k] * b.mat[k][j];
}
}
}
return c;
}
下面介绍一种特殊的矩阵:单位矩阵

很明显的可以推知,任何矩阵乘以单位矩阵,其值不改变。
有了前边的介绍,就可以实现矩阵的快速连乘了。
Mat operator ^ (Mat a, int k) {
Mat c;
int i, j;
for(i = 0; i < n; ++i)
for(j = 0; j < n; ++j)
c.mat[i][j] = (i == j); //初始化为单位矩阵
for(; k; k >>= 1) {
if(k&1) c = c*a;
a = a*a;
}
return c;
}
举个例子:
求第n个Fibonacci数模M的值。如果这个n非常大的话,普通的递推时间复杂度为O(n),这样的复杂度很有可能会挂掉。这里可以用矩阵做优化,复杂度可以降到O(logn * 2^3)
如图:

A = F(n - 1), B = F(N - 2),这样使构造矩阵
的n次幂乘以初始矩阵
得到的结果就是
。
因为是2*2的据称,所以一次相乘的时间复杂度是O(2^3),总的复杂度是O(logn * 2^3 + 2*2*1)。
下面给出一种比较基础的类型的矩阵快速幂:
f(n)= a*f(n-1)+b*f(n-2)型
下面两题适合作为此种矩阵快速幂的模板来使用
http://poj.org/problem?id=3070 (纯模板题,直接用)
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
struct Mat
{
int mat[2][2];
};
Mat d;
int n,mod;
Mat mul(Mat a,Mat b)
{
Mat c;
memset(c.mat,0,sizeof(c.mat));
for(int i=0;i<n;++i)
{
for(int k=0;k<n;++k)
{
if(a.mat[i][k])
for(int j=0;j<n;++j)
{
c.mat[i][j]+=a.mat[i][k]*b.mat[k][j];
if(c.mat[i][j]>=mod) c.mat[i][j]%=mod;
}
}
}
return c;
}
Mat expo(Mat p,int k)
{
if(k==1) return p;
Mat e;
memset(e.mat,0,sizeof(e.mat));
for(int i=0;i<n;++i)
e.mat[i][i]=1;
if(k==0) return e;
while(k)
{
if(k&1) e=mul(p,e);
p=mul(p,p);
k>>=1;
}
return e;
}
int main()
{
n=2;
mod=10000;
d.mat[1][1]=0;
d.mat[0][1]=d.mat[1][0]=d.mat[0][0]=1;
int k;
while(cin>>k)
{
if(k==-1) break;
Mat res=expo(d,k);
int ans=res.mat[0][1]%mod;
cout<<ans<<endl;
}
return 0;
}
链接:http://codeforces.com/contest/450/problem/B (纯模板题的变形题)
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
struct Mat
{
int mat[2][2];
};
Mat d;
int n,mod;
Mat mul(Mat a,Mat b)
{
Mat c;
memset(c.mat,0,sizeof(c.mat));
for(int i=0;i<n;++i)
{
for(int k=0;k<n;++k)
{
if(a.mat[i][k])
for(int j=0;j<n;++j)
{
c.mat[i][j]+=a.mat[i][k]*b.mat[k][j];
if(c.mat[i][j]>=mod) c.mat[i][j]%=mod;
}
}
}
return c;
}
Mat expo(Mat p,int k)
{
if(k==1) return p;
Mat e;
memset(e.mat,0,sizeof(e.mat));
for(int i=0;i<n;++i)
e.mat[i][i]=1;
if(k==0) return e;
while(k)
{
if(k&1) e=mul(p,e);
p=mul(p,p);
k>>=1;
}
return e;
}
int main()
{
n=2;
mod=10000;
d.mat[1][1]=0;
d.mat[0][1]=d.mat[1][0]=d.mat[0][0]=1;
int k;
while(cin>>k)
{
if(k==-1) break;
Mat res=expo(d,k);
int ans=res.mat[0][1]%mod;
cout<<ans<<endl;
}
return 0;
}
S = A + A^2 + A^3 + … + A^k类型
链接:http://poj.org/problem?id=3233
给定三个参数n、k、m,n为矩阵的行数和列数,k表示最高次幂,m用于取模。
对于给定的矩阵A,要求输出A^1+A^2+……+A^k的结果矩阵。
求A^i可以使用二分快速幂,这个是足够快的了。
但k最大可以达到10^9,因此虽然题目只有一组数据,但直接一次循环也必然超时。
这里的求和可以采用二分的思想:
对于S=A^1+A^2+……+A^k
若k是偶数,则S=(1+A^(k/2))(A^1+A^2+……+A^(k/2))
若k是奇数,则S=(1+A^(k/2))(A^1+A^2+……+A^(k/2))+A^k
以上的k/2指的是程序中的除法,即舍弃小数的除法。
采用这种二分思想,可以大大减少时间复杂度,因此可以满足题目的要求。
应当注意的是这里要求的结果矩阵是每个元素模m之后的矩阵,可以在运算过程中可能超过m的时候判断一下,对m取模。
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int maxn=31;
struct Mat
{
int mat[maxn][maxn];
};
Mat d;
int n,m;
Mat mul(Mat a,Mat b)
{
Mat c;
memset(c.mat,0,sizeof(c.mat));
for(int i=0;i<n;++i)
{
for(int k=0;k<n;++k)
{
if(a.mat[i][k])
for(int j=0;j<n;++j)
{
c.mat[i][j]+=a.mat[i][k]*b.mat[k][j];
if(c.mat[i][j]>=m) c.mat[i][j]%=m;
}
}
}
return c;
}
Mat expo(Mat p,int k)
{
if(k==1) return p;
Mat e;
memset(e.mat,0,sizeof(e.mat));
for(int i=0;i<n;++i)
e.mat[i][i]=1;
while(k)
{
if(k&1) e=mul(p,e);
p=mul(p,p);
k>>=1;
}
return e;
}
Mat sum(Mat p,int k)
{
for(int i=0;i<n;++i)
{
for(int j=0;j<n;++j)
{
if(p.mat[i][j]>=m) p.mat[i][j]%=m;
}
}
if(k==1) return p;
Mat m1=expo(p,k/2);
for(int i=0;i<n;++i)
m1.mat[i][i]+=1;
Mat m2=sum(p,k/2);
Mat m3=mul(m1,m2);
if(k&1)
{
Mat temp=expo(p,k);
for(int i=0;i<n;++i)
{
for(int j=0;j<n;++j)
{
m3.mat[i][j]+=temp.mat[i][j];
if(m3.mat[i][j]>=m) m3.mat[i][j]%=m;
}
}
}
return m3;
}
int main()
{
int k;
while(cin>>n>>k>>m)
{
for(int i=0;i<n;++i)
for(int j=0;j<n;++j)
scanf("%d",&d.mat[i][j]);
Mat arry=sum(d,k);
for(int i=0;i<n;++i)
{
for(int j=0;j<n;++j)
{
if(j!=0) printf(" ");
printf("%d",arry.mat[i][j]);
}
printf("\n");
}
}
return 0;
}
矩阵变换类型
这种题目可以用矩阵快速幂,从而实现矩阵的多次变换
链接:http://poj.org/problem?id=3735
【题意】:有n只猫咪,开始时每只猫咪有花生0颗,现有一组操作,由下面三个中的k个操作组成:
1. g i 给i只猫咪一颗花生米
2. e i 让第i只猫咪吃掉它拥有的所有花生米
3. s i j 将猫咪i与猫咪j的拥有的花生米交换
现将上述一组操作做m次后,问每只猫咪有多少颗花生?
【题解】:m达到10^9,显然不能直接算。
因为k个操作给出之后就是固定的,所以想到用矩阵,矩阵快速幂可以把时间复杂度降到O(logm)。问题转化为如何构造转置矩阵?
说下我的思路,观察以上三种操作,发现第二,三种操作比较容易处理,重点落在第一种操作上。
有一个很好的办法就是添加一个辅助,使初始矩阵变为一个n+1元组,编号为0到n,下面以3个猫为例:
定义初始矩阵A = [1 0 0 0],0号元素固定为1,1~n分别为对应的猫所拥有的花生数。
对于第一种操作g i,我们在单位矩阵基础上使Mat[0][i]变为1,例如g 1:
1 1 0 0
0 1 0 0
0 0 1 0
0 0 0 1,显然[1 0 0 0]*Mat = [1 1 0 0]
对于第二种操作e i,我们在单位矩阵基础使Mat[i][i] = 0,例如e 2:
1 0 0 0
0 1 0 0
0 0 0 0
0 0 0 1, 显然[1 2 3 4]*Mat = [1 2 0 4]
对于第三种操作s i j,我们在单位矩阵基础上使第i列与第j互换,例如s 1 2:
1 0 0 0
0 0 0 1
0 0 1 0
0 1 0 0,显然[1 2 0 4]*Mat = [1 4 0 2]
现在,对于每一个操作我们都可以得到一个转置矩阵,把k个操作的矩阵相乘我们可以得到一个新的转置矩阵T。
A * T 表示我们经过一组操作,类似我们可以得到经过m组操作的矩阵为 A * T ^ m,最终矩阵的[0][1~n]即为答案。
上述的做法比较直观,但是实现过于麻烦,因为要构造k个不同矩阵。
有没有别的方法可以直接构造转置矩阵T?答案是肯定的。
我们还是以单位矩阵为基础:
对于第一种操作g i,我们使Mat[0][i] = Mat[0][i] + 1;
对于第二种操作e i,我们使矩阵的第i列清零;
对于第三种操作s i j,我们使第i列与第j列互换。
这样实现的话,我们始终在处理一个矩阵,免去构造k个矩阵的麻烦。
至此,构造转置矩阵T就完成了,接下来只需用矩阵快速幂求出 A * T ^ m即可,还有一个注意的地方,该题需要用到long long。
具体实现可以看下面的代码。
个人采用的是第二种方法
#include <iostream>
#include <cstring>
#include <cstdio>
#define LL long long
using namespace std;
struct met{
LL at[105][105];
};
met ret,d;
LL n,m,k;
met mul(met a,met b)
{
memset(ret.at,0,sizeof(ret.at));
for(int i=0;i<=n;++i)
{
for(int k=0;k<=n;++k)
{
if(a.at[i][k])
{
for(int j=0;j<=n;++j)
{
ret.at[i][j]+=a.at[i][k]*b.at[k][j];
}
}
}
}
return ret;
}
met expo(met a,LL k)
{
if(k==1) return a;
met e;
memset(e.at,0,sizeof(e.at));
for(int i=0;i<=n;++i){e.at[i][i]=1;}
if(k==0)return e;
while(k)
{
if(k&1)e=mul(e,a);
k>>=1;
a=mul(a,a);
}
return e;
}
int main()
{
while(~scanf("%lld%lld%lld",&n,&m,&k))
{
LL a,b;
char ch[5];
if(!n&&!k&&!m)break;
memset(d.at,0,sizeof(d.at));
for(int i=0;i<=n;++i)
{d.at[i][i]=1;}
while(k--)
{
scanf("%s",ch);
if(ch[0]=='g')
{
scanf("%lld",&a);
d.at[0][a]++;
}
else if(ch[0]=='e')
{
scanf("%lld",&a);
for(int i=0;i<=n;++i)
{
d.at[i][a]=0;
}
}
else {
scanf("%lld%lld",&a,&b);
for(int i=0;i<=n;++i)
{
LL t=d.at[i][a];
d.at[i][a]=d.at[i][b];
d.at[i][b]=t;
}
}
}
met ans=expo(d,m);
printf("%lld",ans.at[0][1]);
for(int i=2;i<=n;++i)
{
printf(" %lld",ans.at[0][i]);
}
printf("\n");
}
return 0;
}


浙公网安备 33010602011771号