矩阵乘法笔记

前言

首先,本人并不想讲那些非常复杂的定义、证明,所以我会用我自己对矩阵的理解,尽可能的简单但是正确的讲出来。

第一:矩阵是什么

首先,我们可以把矩阵理解为一个二维数组。比如一个 \(n\times m\) 的矩阵为。

\[\begin{bmatrix}a_{11}&a_{12}&a_{13}&\cdots&a_{1m}\\a_{21}&a_{22}&a_{23}&\cdots&a_{2m}\\\vdots&\vdots&\vdots&\ddots&\vdots\\a_{n1}&a_{n2}&a_{n3}&\cdots&a_{nm}\end{bmatrix} \]

第二:矩阵的运算

矩阵的加减法非常简单,如下图所示。

\[\begin{bmatrix}a_{11}&a_{12}\\a_{21}&a_{22}\end{bmatrix}\pm \begin{bmatrix}b_{11}&b_{12}\\b_{21}&b_{22}\end{bmatrix}=\begin{bmatrix}a_{11} \pm b_{11}&a_{12} \pm b_{12}\\a_{21} \pm b_{21}&a_{22} \pm b_{22}\end{bmatrix}\]

而矩阵乘法就比较有趣,可以一句话解释,就是用第一个矩阵的第 \(i\) 行乘上第二个矩阵的第 \(j\) 列(每个数对应相乘,再求和)就是他们结果数组的第 \(i\) 行第 \(j\) 列的数。

不理解没有关系,设矩阵 \(C=A \times B\) 用公式写出来就是。

\[C_{i,j}=\sum_{k=1}^{n}A_{i,k}\times B_{k,j} \]

而用代码写出来就是。

for(int i=1;i<=n;i++)
	for(int j=1;j<=n;j++)
		for(int k=1;k<=n;k++)
			res.a[i][j]=(res.a[i][j]=(a[i][k]+b.a[k][j]));

这里有一个小注意事项,就是代码中的 \(n\) 是两个矩阵的行和列的 \(\max\)。因为在实际代码中,特别是矩阵快速幂里,两个矩阵的行大部分情况相同。

第三:矩阵快速幂

首先,先解释一下什么是快速幂。快速幂就是利用一个数在二进制下 \(1\) 的个数来达到以 \(\log\) 的时间复杂度来进行快速计算。

代码如下。

int qpow(int x,int y){
	int sum=1;
	while(y!=0){
		if(y&1)sum=sum*x;
		x=x*x;
		y>>=1;
	}
	return sum;
}

而矩阵快速幂就是在快速幂的基础上把数改成矩阵。
可以看一道例题来理解。P1962

题目就是要求斐波那契数列的第 \(n\) 项,但是 \(n<2^{63}\),所以直接用循环会超时。那么我们就可以用矩阵快速幂,观察一下可以发现,可以把斐波那契数列的递推式写成矩阵形式。

\[\begin{bmatrix}F_{n-1}&F_{n-2}\end{bmatrix}\times\begin{bmatrix}1&1\\1&0\end{bmatrix}=\begin{bmatrix}F_{n}&F_{n-1}\end{bmatrix} \]

这个很好理解,拿左边矩阵的第一行乘上右边的第一列,就会等于结果数组的第一个数,结果数组的第二个数同理。解释一下。

\[\begin{bmatrix}F_{n-1}&F_{n-2}\end{bmatrix}\times\begin{bmatrix}1\\1\end{bmatrix}=F_{n-1}+F_{n-2}=F_n \]

\[\begin{bmatrix}F_{n-1}&F_{n-2}\end{bmatrix}\times\begin{bmatrix}1\\0\end{bmatrix}=F_{n-1} \]

而这个变换就做到了每次乘上一个相同的矩阵就可以算出下一个数。那么求第 \(n\) 项的式子就变成了。

\[\begin{bmatrix}1&1\end{bmatrix}\times\begin{bmatrix}1&1\\1&0\end{bmatrix}^{n-2}=\begin{bmatrix}F_{n}&F_{n-1}\end{bmatrix} \]

这里要注意,因为矩阵满足结合律但不满足交换律,所以写代码时要注意不要写成。

\[\begin{bmatrix}1&1\\1&0\end{bmatrix}\times\begin{bmatrix}F_{n-1}&F_{n-2}\end{bmatrix} \]

下面就是代码啦。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int M=10;
ll n,mod=1e9+7;
struct tt{
	ll a[M][M];
	tt(){memset(a,0,sizeof(a));}
	tt operator*(const tt &b)const{
		tt res;
		for(int i=1;i<=2;i++)
			for(int j=1;j<=2;j++)
				for(int k=1;k<=2;k++)
					res.a[i][j]=(res.a[i][j]+a[i][k]*b.a[k][j])%mod;
		return res;
	}
}ans,base;
void qpow(ll y){
	while(y!=0){
		if(y&1)ans=ans*base;
		base=base*base;
		y>>=1;
	}
} 
int main(){
	base.a[1][1]=1,base.a[1][2]=1,base.a[2][1]=1;
	ans.a[1][1]=1,ans.a[1][2]=1;
	scanf("%lld",&n);
	if(n<=2){
		printf("1");
		return 0;
	}
	qpow(n-2);
	printf("%lld",ans.a[1][1]%mod);
	return 0;
}

感觉这题太简单了?那就再来一道练练手。P5678

这题就是给出数列 \(\{a_n\}\)\(\{b_n\}\) 以及 \(\{A_n\}\) 的递推关系, 试求出数列 \(\{A_n\}\)\(N\) 项。

递推关系为。

\[A_n=\begin{cases}a_n & 0 \le n < K \\ \bigoplus (A_{n-K+t} \otimes b_t) & n \ge K \end{cases} \]

其中,\(\otimes\) 表示与操作,\(\oplus\) 表示或操作。

容易得到,\(A_n\) 只与 \(b_i\)\(A_i\) 的前 \(k\) 项有关。所以我们可以构建一下矩阵(假设 \(k=4\))。

\[\begin{bmatrix}A_i&A_{i+1}&A_{i+2}&A_{i+3}\end{bmatrix}\times\begin{bmatrix}0&0&0&b_1\\-1&0&0&b_2\\0&-1&0&b_3\\0&0&-1&b_4\end{bmatrix}=\begin{bmatrix}A_{i+1}&A_{i+2}&A_{i+3}&A_{i+4}\end{bmatrix} \]

在这个矩阵里,我们要把原先的结果矩阵第 \(i\)\(j\)\(=\) 第一个矩阵的第 \(i\)\(\times\) 第二个矩阵的第 \(j\) 列再相加,把 \(\times\) 改为按位与,把相加改为按位或就可以了。

那么最终的答案如下(以 \(k=4\) 为例)。

\[\begin{bmatrix}a_1&a_2&a_3&a_4\end{bmatrix}\times\begin{bmatrix}0&0&0&b_1\\-1&0&0&b_2\\0&-1&0&b_3\\0&0&-1&b_4\end{bmatrix}^{n-4}=\begin{bmatrix}A_{n-3}&A_{n-2}&A_{n-1}&A_{n}\end{bmatrix} \]

所以代码也就很好打了,这里解释一下为什么第二的矩阵里面是 \(-1\) 而不是 \(1\),这是因为是要进行与运算,而一个数与上 \(-1\) 还等于他自己,所以用 \(-1\)。而 \(0\) 是因为任何数或 \(0\) 等于他自己。

代码。

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int M=101,N=1010101;
const int mod=1e4;
ll n,m,s,t,k,num,w[N],T,a[N],b[N];
struct tt{
	ll a[M][M];
	tt(){memset(a,0,sizeof(a));}
	tt operator*(const tt &b)const{
		tt res;
		for(int i=1;i<=k;i++)
			for(int j=1;j<=k;j++)
				for(int l=1;l<=k;l++)
					res.a[i][j]=(res.a[i][j]|(a[i][l]&b.a[l][j]));
		return res;
	}
}ans,base,st;
void qpow(tt base,ll y){
	while(y){
		if(y&1)ans=ans*base;
		base=base*base;
		y>>=1;
	}
}
int main(){
	scanf("%lld%lld",&n,&k);
	n++;
	for(int i=1;i<=k;i++)scanf("%lld",&a[i]),ans.a[1][i]=a[i];
	for(int i=1;i<=k;i++)scanf("%lld",&b[i]);
	for(int i=1;i<=k;i++)base.a[i][k]=b[i],base.a[i+1][i]=-1;
	if(n<=k){
		printf("%lld",a[n]);
		return 0;
	}
	qpow(base,n-k);
	printf("%lld",ans.a[1][k]);
	return 0;
}

第四:高斯消元

这个也是矩阵的一大应用,我在这里只介绍讲解高斯—约旦消元法。

首先对于一个三元一次方程组。

\[\begin{Bmatrix}3x&+&2y&+&z&=&10\\5x&+&y&+&6z&=&25\\2x&+&3y&+&4z&=&20\end{Bmatrix} \]

如果按照数学的方法,肯定是分别化 \(x,y,z\) 的系数为一,再进行消元。那么高斯消元的基本原理,你就懂了,就是数学方法,复杂度为 \(O(n^3)\)

下面介绍一下高斯—约旦消元。
首先要构建矩阵。

scanf("%d",&n);
for(int i=1;i<=n;i++)for(int j=1;j<=n+1;j++)cin>>a[i][j];

每一行的第 \(n+1\) 项对应着这一个方程的答案。接下来,要消去第 \(i\) 个未知数,要先找到对应系数最大的一个方程,再交换,判断系数不为 \(0\),若最大系数都为 \(0\) 说明解不出这一个未知数。

for(int i=1;i<=n;i++){
	m=i;
	for(int j=i+1;j<=n;j++)if(abs(a[j][i])>abs(a[m][i]))m=j;
	if(abs(a[m][i])<1e-7){
		printf("No Solution");
		return 0;
	}
	swap(a[i],a[m]);

接下来就是把这一个方程的第 \(i\) 个未知数的系数化为一。

double k=a[i][i];
for(int j=1;j<=n+1;j++)a[i][j]/=k;

然后再减去其他的方程达到消去这一个未知数的目的。

for(int j=1;j<=n;j++){
	if(j==i)continue;
	double div=a[j][i]/a[i][i];
	for(int k=1;k<=n+1;k++)a[j][k]=a[j][k]-div*a[i][k];
}

最终化简完的矩阵长这样。

\[\begin{bmatrix}1&0&0&\cdots&0&ans_1\\0&1&0&\cdots&0&ans_2\\0&0&1&\cdots&0&ans_3\\\vdots&\vdots&\vdots&\ddots&\vdots&\vdots\\0&0&0&\cdots&1&ans_n\end{bmatrix} \]

所以对于第 \(i\) 个未知量的解就是对应这一行的第 \(n+1\) 个数。

完整代码如下。

#include<bits/stdc++.h>
using namespace std;
const int N=1010;
int n,m;
double a[N][N];
void gauss(){
	for(int i=1;i<=n;i++){
		m=i;
		for(int j=i+1;j<=n;j++)if(abs(a[j][i])>abs(a[m][i]))m=j;
		if(abs(a[m][i])<1e-7){
			printf("No Solution");
			return 0;
		}
		for(int j=1;j<=n+1;j++)swap(a[i][j],a[m][j]);
		double k=a[i][i];
		for(int j=1;j<=n+1;j++)a[i][j]/=k;
		for(int j=1;j<=n;j++){
			if(j==i)continue;
			double div=a[j][i]/a[i][i];
			for(int k=1;k<=n+1;k++)a[j][k]=a[j][k]-div*a[i][k];
		}
	}	
}
int main(){
	scanf("%d",&n);
	for(int i=1;i<=n;i++)for(int j=1;j<=n+1;j++)cin>>a[i][j];
	gauss();
	for(int i=1;i<=n;i++)printf("%0.2lf\n",a[i][n+1]);
	return 0;
}

因为高斯消元的题的思路感觉没有特别经典的,所以就不进行例题讲解了,提供几道练手的题。P2447 P2973 P4035

第五:矩乘与图论

给定一个有向图,问从 \(S\) 点恰好走 \(k\) 步(允许重复经过边)到达 \(T\) 点的方案数 \(\bmod p\) 的值。

这类题有一个经典的做法,给定了有向图,可以得到该图的邻接矩阵 \(A\),在邻接矩阵 \(A\) 中,\(A(i,j)=1\) 当且仅当存在一条边 \(i->j\)。若 \(i->j\) 不存在直接相连接的边,则 \(A(i,j)=0\)
\(C=A\times A\),那么 \(C(i,j)=\sum A(i,k)*A(k,j)\),实际上就等于从点 \(i\) 到点 \(j\) 恰好经过 \(2\) 条边的路径数(\(k\) 为中转点)。
类似地,令 \(C=A\times A\times A\)的第 \(i\) 行第 \(j\) 列就表示从 \(i\)\(j\) 经过 \(3\) 条边的路径数。同理,如果要求经过 \(k\) 步的路径数,只需要采用快速幂运算求出 \(A^k\) 即可。通常我们会定义一个单位矩阵为 \(C\) 来进行计算,也就是只再对角线上的数为 \(1\) 的矩阵。

\[\begin{bmatrix}1&0&0\\0&1&0\\0&0&1\end{bmatrix} \]

那么现在来看一道例题。P4159


题目大意

有向图有 \(n\) 个节点,节点从 \(1\)\(n\) 编号,\(windy\) 从节点 \(1\) 出发,他必须恰好在 \(t\) 时刻到达节点 \(n\)。现在给出该有向图,你能告诉 \(windy\) 总共有多少种不同的路径吗?答案对 \(2009\) 取模。

思路

这题与模版不同点在于从 \(i->j\) 要花费 \(c_{ij}\) 的时间,但是这个很好解决,因为 \(c_{ij}<=9\),所以可以把每一条边拆成 \(c_{ij}\) 条边,花费都是 \(1\),那么是不是就与模板一样了。

code

#include<bits/stdc++.h>
using namespace std;
const int M=1010;
int n,m,t,A,B,sum,mod=2009,tot,x[M],y[M];
struct tt{
	int a[M][M];
	tt(){memset(a,0,sizeof(a));}
	tt operator*(const tt &b)const{
		tt res;
		for(int i=1;i<=n;i++)
			for(int j=1;j<=n;j++)
				for(int k=1;k<=n;k++)
					res.a[i][j]=(res.a[i][j]+a[i][k]*b.a[k][j])%mod; 
		return res;
	}
}ans,base;
void qpow(int k){
	while(k!=0){
		if(k&1)ans=ans*base;
		base=base*base;
		k>>=1;
	}
}
int main(){
	scanf("%d%d",&n,&m);
	int nn=n;n*=9;
	for(int i=1;i<=nn;i++)
		for(int j=1;j<=8;j++)
			base.a[9*(i-1)+j][9*(i-1)+j+1]=1;
	for(int i=1;i<=nn;i++){
		char s[1010];
		scanf("%s",s+1);
		for(int j=1;j<=nn;j++){
			if(s[j]>'0')base.a[9*(i-1)+s[j]-'0'][9*(j-1)+1]=1;
		}
	}
	for(int i=1;i<=nn;i++)ans.a[i][i]=1;
	qpow(m);
	printf("%d",ans.a[1][nn*9-8]);
	return 0;
}

再来一道吧。P2579

题目大意

一个无向图,每一时刻有一些点不能到达,这些点的出现有三种周期 \(3,4,5\),求 \(s\)\(t\) 的路径数。

思路

注意到周期很受限,而且 \(3,4,5\) 的最小公倍数为 \(12\),所以我们可以暴力枚举出从开始到经过了 \(12\) 个时间的矩阵 \(base_i\),然后以这 \(12\) 个图为一个周期,如果经过若干个周期后时间 \(k\) 还有剩余,就依次乘上 \(base_1\)\(base_k\)

code

#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int M=101,N=1010101;
const int mod=1e4;
ll n,m,s,t,k,num,w[N],T;
struct tt{
	ll a[M][M];
	tt(){memset(a,0,sizeof(a));}
	tt operator*(const tt &b)const{
		tt res;
		for(int i=1;i<=n;i++)
			for(int j=1;j<=n;j++)
				for(int k=1;k<=n;k++)
					res.a[i][j]=(res.a[i][j]+a[i][k]*b.a[k][j])%mod;
		return res;
	}
}ans,base[13],st;
tt qpow(tt base,ll y){
	tt res=st;
	while(y!=0){
		if(y&1)res=res*base;
		base=base*base;
		y>>=1;
	}
	return res;
}
int main(){
	scanf("%lld%lld%lld%lld%lld",&n,&m,&s,&t,&k);
	s++,t++;
	for(int i=1;i<=n;i++)ans.a[i][i]=st.a[i][i]=1;
	for(int i=1;i<=m;i++){
		ll x,y;
		scanf("%lld%lld",&x,&y);
		x++,y++;
		for(int j=1;j<=12;j++)base[j].a[x][y]=base[j].a[y][x]=1;
	}
	scanf("%lld",&T);
	while(T--){
		scanf("%lld",&num);
		for(int i=1;i<=num;i++)scanf("%lld",&w[i]),w[i]++;
		for(int i=1;i<=12;i++)
			for(int j=1;j<=n;j++)
				base[i].a[j][w[i%num+1]]=0;
	}
	for(int i=1;i<=12;i++)ans=(ans*base[i]);
	ans=qpow(ans,k/12);
	for(int i=1;i<=k%12;i++)ans=(ans*base[i]);
	printf("%lld",ans.a[s][t]);
	return 0;
}

第六:结语

这可能还是一篇不完善的笔记,我后续还会继续改进。
希望大家通过我这篇并不是很严谨的笔记,能对矩阵有更深的认识。撒花。

posted @ 2025-09-10 21:40  一班的hoko  阅读(17)  评论(0)    收藏  举报