矩阵乘法笔记
前言
首先,本人并不想讲那些非常复杂的定义、证明,所以我会用我自己对矩阵的理解,尽可能的简单但是正确的讲出来。
第一:矩阵是什么
首先,我们可以把矩阵理解为一个二维数组。比如一个 \(n\times m\) 的矩阵为。
第二:矩阵的运算
矩阵的加减法非常简单,如下图所示。
而矩阵乘法就比较有趣,可以一句话解释,就是用第一个矩阵的第 \(i\) 行乘上第二个矩阵的第 \(j\) 列(每个数对应相乘,再求和)就是他们结果数组的第 \(i\) 行第 \(j\) 列的数。
不理解没有关系,设矩阵 \(C=A \times B\) 用公式写出来就是。
而用代码写出来就是。
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
for(int k=1;k<=n;k++)
res.a[i][j]=(res.a[i][j]=(a[i][k]+b.a[k][j]));
这里有一个小注意事项,就是代码中的 \(n\) 是两个矩阵的行和列的 \(\max\)。因为在实际代码中,特别是矩阵快速幂里,两个矩阵的行大部分情况相同。
第三:矩阵快速幂
首先,先解释一下什么是快速幂。快速幂就是利用一个数在二进制下 \(1\) 的个数来达到以 \(\log\) 的时间复杂度来进行快速计算。
代码如下。
int qpow(int x,int y){
int sum=1;
while(y!=0){
if(y&1)sum=sum*x;
x=x*x;
y>>=1;
}
return sum;
}
而矩阵快速幂就是在快速幂的基础上把数改成矩阵。
可以看一道例题来理解。P1962
题目就是要求斐波那契数列的第 \(n\) 项,但是 \(n<2^{63}\),所以直接用循环会超时。那么我们就可以用矩阵快速幂,观察一下可以发现,可以把斐波那契数列的递推式写成矩阵形式。
这个很好理解,拿左边矩阵的第一行乘上右边的第一列,就会等于结果数组的第一个数,结果数组的第二个数同理。解释一下。
而这个变换就做到了每次乘上一个相同的矩阵就可以算出下一个数。那么求第 \(n\) 项的式子就变成了。
这里要注意,因为矩阵满足结合律但不满足交换律,所以写代码时要注意不要写成。
下面就是代码啦。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int M=10;
ll n,mod=1e9+7;
struct tt{
ll a[M][M];
tt(){memset(a,0,sizeof(a));}
tt operator*(const tt &b)const{
tt res;
for(int i=1;i<=2;i++)
for(int j=1;j<=2;j++)
for(int k=1;k<=2;k++)
res.a[i][j]=(res.a[i][j]+a[i][k]*b.a[k][j])%mod;
return res;
}
}ans,base;
void qpow(ll y){
while(y!=0){
if(y&1)ans=ans*base;
base=base*base;
y>>=1;
}
}
int main(){
base.a[1][1]=1,base.a[1][2]=1,base.a[2][1]=1;
ans.a[1][1]=1,ans.a[1][2]=1;
scanf("%lld",&n);
if(n<=2){
printf("1");
return 0;
}
qpow(n-2);
printf("%lld",ans.a[1][1]%mod);
return 0;
}
感觉这题太简单了?那就再来一道练练手。P5678
这题就是给出数列 \(\{a_n\}\) 和 \(\{b_n\}\) 以及 \(\{A_n\}\) 的递推关系, 试求出数列 \(\{A_n\}\) 第 \(N\) 项。
递推关系为。
其中,\(\otimes\) 表示与操作,\(\oplus\) 表示或操作。
容易得到,\(A_n\) 只与 \(b_i\) 和 \(A_i\) 的前 \(k\) 项有关。所以我们可以构建一下矩阵(假设 \(k=4\))。
在这个矩阵里,我们要把原先的结果矩阵第 \(i\) 行 \(j\) 列 \(=\) 第一个矩阵的第 \(i\) 行 \(\times\) 第二个矩阵的第 \(j\) 列再相加,把 \(\times\) 改为按位与,把相加改为按位或就可以了。
那么最终的答案如下(以 \(k=4\) 为例)。
所以代码也就很好打了,这里解释一下为什么第二的矩阵里面是 \(-1\) 而不是 \(1\),这是因为是要进行与运算,而一个数与上 \(-1\) 还等于他自己,所以用 \(-1\)。而 \(0\) 是因为任何数或 \(0\) 等于他自己。
代码。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int M=101,N=1010101;
const int mod=1e4;
ll n,m,s,t,k,num,w[N],T,a[N],b[N];
struct tt{
ll a[M][M];
tt(){memset(a,0,sizeof(a));}
tt operator*(const tt &b)const{
tt res;
for(int i=1;i<=k;i++)
for(int j=1;j<=k;j++)
for(int l=1;l<=k;l++)
res.a[i][j]=(res.a[i][j]|(a[i][l]&b.a[l][j]));
return res;
}
}ans,base,st;
void qpow(tt base,ll y){
while(y){
if(y&1)ans=ans*base;
base=base*base;
y>>=1;
}
}
int main(){
scanf("%lld%lld",&n,&k);
n++;
for(int i=1;i<=k;i++)scanf("%lld",&a[i]),ans.a[1][i]=a[i];
for(int i=1;i<=k;i++)scanf("%lld",&b[i]);
for(int i=1;i<=k;i++)base.a[i][k]=b[i],base.a[i+1][i]=-1;
if(n<=k){
printf("%lld",a[n]);
return 0;
}
qpow(base,n-k);
printf("%lld",ans.a[1][k]);
return 0;
}
第四:高斯消元
这个也是矩阵的一大应用,我在这里只介绍讲解高斯—约旦消元法。
首先对于一个三元一次方程组。
如果按照数学的方法,肯定是分别化 \(x,y,z\) 的系数为一,再进行消元。那么高斯消元的基本原理,你就懂了,就是数学方法,复杂度为 \(O(n^3)\)。
下面介绍一下高斯—约旦消元。
首先要构建矩阵。
scanf("%d",&n);
for(int i=1;i<=n;i++)for(int j=1;j<=n+1;j++)cin>>a[i][j];
每一行的第 \(n+1\) 项对应着这一个方程的答案。接下来,要消去第 \(i\) 个未知数,要先找到对应系数最大的一个方程,再交换,判断系数不为 \(0\),若最大系数都为 \(0\) 说明解不出这一个未知数。
for(int i=1;i<=n;i++){
m=i;
for(int j=i+1;j<=n;j++)if(abs(a[j][i])>abs(a[m][i]))m=j;
if(abs(a[m][i])<1e-7){
printf("No Solution");
return 0;
}
swap(a[i],a[m]);
接下来就是把这一个方程的第 \(i\) 个未知数的系数化为一。
double k=a[i][i];
for(int j=1;j<=n+1;j++)a[i][j]/=k;
然后再减去其他的方程达到消去这一个未知数的目的。
for(int j=1;j<=n;j++){
if(j==i)continue;
double div=a[j][i]/a[i][i];
for(int k=1;k<=n+1;k++)a[j][k]=a[j][k]-div*a[i][k];
}
最终化简完的矩阵长这样。
所以对于第 \(i\) 个未知量的解就是对应这一行的第 \(n+1\) 个数。
完整代码如下。
#include<bits/stdc++.h>
using namespace std;
const int N=1010;
int n,m;
double a[N][N];
void gauss(){
for(int i=1;i<=n;i++){
m=i;
for(int j=i+1;j<=n;j++)if(abs(a[j][i])>abs(a[m][i]))m=j;
if(abs(a[m][i])<1e-7){
printf("No Solution");
return 0;
}
for(int j=1;j<=n+1;j++)swap(a[i][j],a[m][j]);
double k=a[i][i];
for(int j=1;j<=n+1;j++)a[i][j]/=k;
for(int j=1;j<=n;j++){
if(j==i)continue;
double div=a[j][i]/a[i][i];
for(int k=1;k<=n+1;k++)a[j][k]=a[j][k]-div*a[i][k];
}
}
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++)for(int j=1;j<=n+1;j++)cin>>a[i][j];
gauss();
for(int i=1;i<=n;i++)printf("%0.2lf\n",a[i][n+1]);
return 0;
}
因为高斯消元的题的思路感觉没有特别经典的,所以就不进行例题讲解了,提供几道练手的题。P2447 P2973 P4035
第五:矩乘与图论
给定一个有向图,问从 \(S\) 点恰好走 \(k\) 步(允许重复经过边)到达 \(T\) 点的方案数 \(\bmod p\) 的值。
这类题有一个经典的做法,给定了有向图,可以得到该图的邻接矩阵 \(A\),在邻接矩阵 \(A\) 中,\(A(i,j)=1\) 当且仅当存在一条边 \(i->j\)。若 \(i->j\) 不存在直接相连接的边,则 \(A(i,j)=0\)。
令 \(C=A\times A\),那么 \(C(i,j)=\sum A(i,k)*A(k,j)\),实际上就等于从点 \(i\) 到点 \(j\) 恰好经过 \(2\) 条边的路径数(\(k\) 为中转点)。
类似地,令 \(C=A\times A\times A\)的第 \(i\) 行第 \(j\) 列就表示从 \(i\) 到 \(j\) 经过 \(3\) 条边的路径数。同理,如果要求经过 \(k\) 步的路径数,只需要采用快速幂运算求出 \(A^k\) 即可。通常我们会定义一个单位矩阵为 \(C\) 来进行计算,也就是只再对角线上的数为 \(1\) 的矩阵。
那么现在来看一道例题。P4159
题目大意
有向图有 \(n\) 个节点,节点从 \(1\) 至 \(n\) 编号,\(windy\) 从节点 \(1\) 出发,他必须恰好在 \(t\) 时刻到达节点 \(n\)。现在给出该有向图,你能告诉 \(windy\) 总共有多少种不同的路径吗?答案对 \(2009\) 取模。
思路
这题与模版不同点在于从 \(i->j\) 要花费 \(c_{ij}\) 的时间,但是这个很好解决,因为 \(c_{ij}<=9\),所以可以把每一条边拆成 \(c_{ij}\) 条边,花费都是 \(1\),那么是不是就与模板一样了。
code
#include<bits/stdc++.h>
using namespace std;
const int M=1010;
int n,m,t,A,B,sum,mod=2009,tot,x[M],y[M];
struct tt{
int a[M][M];
tt(){memset(a,0,sizeof(a));}
tt operator*(const tt &b)const{
tt res;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
for(int k=1;k<=n;k++)
res.a[i][j]=(res.a[i][j]+a[i][k]*b.a[k][j])%mod;
return res;
}
}ans,base;
void qpow(int k){
while(k!=0){
if(k&1)ans=ans*base;
base=base*base;
k>>=1;
}
}
int main(){
scanf("%d%d",&n,&m);
int nn=n;n*=9;
for(int i=1;i<=nn;i++)
for(int j=1;j<=8;j++)
base.a[9*(i-1)+j][9*(i-1)+j+1]=1;
for(int i=1;i<=nn;i++){
char s[1010];
scanf("%s",s+1);
for(int j=1;j<=nn;j++){
if(s[j]>'0')base.a[9*(i-1)+s[j]-'0'][9*(j-1)+1]=1;
}
}
for(int i=1;i<=nn;i++)ans.a[i][i]=1;
qpow(m);
printf("%d",ans.a[1][nn*9-8]);
return 0;
}
再来一道吧。P2579
题目大意
一个无向图,每一时刻有一些点不能到达,这些点的出现有三种周期 \(3,4,5\),求 \(s\) 到 \(t\) 的路径数。
思路
注意到周期很受限,而且 \(3,4,5\) 的最小公倍数为 \(12\),所以我们可以暴力枚举出从开始到经过了 \(12\) 个时间的矩阵 \(base_i\),然后以这 \(12\) 个图为一个周期,如果经过若干个周期后时间 \(k\) 还有剩余,就依次乘上 \(base_1\) 到 \(base_k\)。
code
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int M=101,N=1010101;
const int mod=1e4;
ll n,m,s,t,k,num,w[N],T;
struct tt{
ll a[M][M];
tt(){memset(a,0,sizeof(a));}
tt operator*(const tt &b)const{
tt res;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
for(int k=1;k<=n;k++)
res.a[i][j]=(res.a[i][j]+a[i][k]*b.a[k][j])%mod;
return res;
}
}ans,base[13],st;
tt qpow(tt base,ll y){
tt res=st;
while(y!=0){
if(y&1)res=res*base;
base=base*base;
y>>=1;
}
return res;
}
int main(){
scanf("%lld%lld%lld%lld%lld",&n,&m,&s,&t,&k);
s++,t++;
for(int i=1;i<=n;i++)ans.a[i][i]=st.a[i][i]=1;
for(int i=1;i<=m;i++){
ll x,y;
scanf("%lld%lld",&x,&y);
x++,y++;
for(int j=1;j<=12;j++)base[j].a[x][y]=base[j].a[y][x]=1;
}
scanf("%lld",&T);
while(T--){
scanf("%lld",&num);
for(int i=1;i<=num;i++)scanf("%lld",&w[i]),w[i]++;
for(int i=1;i<=12;i++)
for(int j=1;j<=n;j++)
base[i].a[j][w[i%num+1]]=0;
}
for(int i=1;i<=12;i++)ans=(ans*base[i]);
ans=qpow(ans,k/12);
for(int i=1;i<=k%12;i++)ans=(ans*base[i]);
printf("%lld",ans.a[s][t]);
return 0;
}
第六:结语
这可能还是一篇不完善的笔记,我后续还会继续改进。
希望大家通过我这篇并不是很严谨的笔记,能对矩阵有更深的认识。撒花。

浙公网安备 33010602011771号