对形如AX=B的矩阵方程求解
首先看题
题目描述
\(求矩阵X使得AX=B。答案对998244353取模。\)
输入格式
\(第一行两个整数n和m,代表矩阵A的长宽\)
\(之后n行m列,表示矩阵A。\)
\(之后一行两个整数n和r,代表矩B的长宽。\)
\(之后n行r列,表示矩阵B。\)
\(1≤n,m≤400,1≤n\cdot m\cdot r≤6.4\times10^{7},保证矩阵A可逆\)
输出格式
\(m行r列,表示矩阵X。\)
首先想什么样的两个矩阵相乘会等于\(B\)。显然\(IB=B\)(废话),那么我们只要求出\(A^{-1}\),再乘上\(B\)就好了。
那么怎么求出\(A^{-1}\)呢?我们注意到以下式子:
\(\begin{bmatrix}
1&0&0&0&\cdots&0\\
0&0&1&0&\cdots&0\\
0&1&0&0&\cdots&0\\
0&0&0&1&\cdots&0\\
\vdots &\vdots&\vdots&\vdots&\ddots&\vdots\\
0&0&0&0&\cdots&1
\end{bmatrix}
\begin{bmatrix}
a_{11}&a_{12}&a_{13}&a_{14}&\cdots&a_{1m}\\
a_{21}&a_{22}&a_{23}&a_{24}&\cdots&a_{2m}\\
a_{31}&a_{32}&a_{33}&a_{34}&\cdots&a_{3m}\\
a_{41}&a_{42}&a_{43}&a_{44}&\cdots&a_{4m}\\
\vdots &\vdots&\vdots&\vdots&\ddots&\vdots\\
a_{n1}&a_{n2}&a_{n3}&a_{n4}&\cdots&a_{nm}\\
\end{bmatrix}=
\begin{bmatrix}
a_{11}&a_{12}&a_{13}&a_{14}&\cdots&a_{1m}\\
a_{31}&a_{32}&a_{33}&a_{34}&\cdots&a_{3m}\\
a_{21}&a_{22}&a_{23}&a_{24}&\cdots&a_{2m}\\
a_{41}&a_{42}&a_{43}&a_{44}&\cdots&a_{4m}\\
\vdots &\vdots&\vdots&\vdots&\ddots&\vdots\\
a_{n1}&a_{n2}&a_{n3}&a_{n4}&\cdots&a_{nm}\\
\end{bmatrix}\)
\(\begin{bmatrix} 1&0&0&0&\cdots&0\\ 0&k&0&0&\cdots&0\\ 0&0&1&0&\cdots&0\\ 0&0&0&1&\cdots&0\\ \vdots &\vdots&\vdots&\vdots&\ddots&\vdots\\ 0&0&0&0&\cdots&1 \end{bmatrix} \begin{bmatrix} a_{11}&a_{12}&a_{13}&a_{14}&\cdots&a_{1m}\\ a_{21}&a_{22}&a_{23}&a_{24}&\cdots&a_{2m}\\ a_{31}&a_{32}&a_{33}&a_{34}&\cdots&a_{3m}\\ a_{41}&a_{42}&a_{43}&a_{44}&\cdots&a_{4m}\\ \vdots &\vdots&\vdots&\vdots&\ddots&\vdots\\ a_{n1}&a_{n2}&a_{n3}&a_{n4}&\cdots&a_{nm}\\ \end{bmatrix}= \begin{bmatrix} a_{11}&a_{12}&a_{13}&a_{14}&\cdots&a_{1m}\\ ka_{21}&ka_{22}&ka_{23}&ka_{24}&\cdots&ka_{2m}\\ a_{31}&a_{32}&a_{33}&a_{34}&\cdots&a_{3m}\\ a_{41}&a_{42}&a_{43}&a_{44}&\cdots&a_{4m}\\ \vdots &\vdots&\vdots&\vdots&\ddots&\vdots\\ a_{n1}&a_{n2}&a_{n3}&a_{n4}&\cdots&a_{nm}\\ \end{bmatrix}\)
\(\begin{bmatrix}
1&0&0&0&\cdots&0\\
0&1&0&k&\cdots&0\\
0&0&1&0&\cdots&0\\
0&0&0&1&\cdots&0\\
\vdots &\vdots&\vdots&\vdots&\ddots&\vdots\\
0&0&0&0&\cdots&1
\end{bmatrix}
\begin{bmatrix}
a_{11}&a_{12}&a_{13}&a_{14}&\cdots&a_{1m}\\
a_{21}&a_{22}&a_{23}&a_{24}&\cdots&a_{2m}\\
a_{31}&a_{32}&a_{33}&a_{34}&\cdots&a_{3m}\\
a_{41}&a_{42}&a_{43}&a_{44}&\cdots&a_{4m}\\
\vdots &\vdots&\vdots&\vdots&\ddots&\vdots\\
a_{n1}&a_{n2}&a_{n3}&a_{n4}&\cdots&a_{nm}\\
\end{bmatrix}=
\begin{bmatrix}
a_{11}&a_{12}&a_{13}&a_{14}&\cdots&a_{1m}\\
a_{21}+ka_{41}&a_{22}+ka_{42}&a_{23}+ka_{43}&a_{24}+ka_{44}&\cdots&a_{2m}+ka_{4m}\\
a_{31}&a_{32}&a_{33}&a_{34}&\cdots&a_{3m}\\
a_{41}&a_{42}&a_{43}&a_{44}&\cdots&a_{4m}\\
\vdots &\vdots&\vdots&\vdots&\ddots&\vdots\\
a_{n1}&a_{n2}&a_{n3}&a_{n4}&\cdots&a_{nm}\\
\end{bmatrix}\)
由上三式可得,矩阵的初等行变换可以变为左乘一个方阵的矩阵乘法,那么当我们对一个方阵进行高斯消元变为单位矩阵的过程实际上就是乘上了它的逆,那么我们只需要把我们的操作记录下来就可以求得\(A\)的逆了(不用担心\(A\)不是方阵怎么办,不是方阵它根本就没有逆)。
那怎么记录呢?考虑对矩阵\(\left(A\mid I\right)\)进行高斯消元,由于单位矩阵的特性,最后我们可以得到\((I\mid A^{-1})\)。
最后再乘上B就好啦。
\(\cdots\)
但是矩阵乘法和高斯消元本身复杂度就很高,这道题这么做的话会超时,那能不能把常数降下去呢?
有的
回想一下我们干了什么,会发现我们实际上是计算了\(A^{-1}IB\)的值。诶,矩阵乘法是有结合律的啊,我们可以直接算\(A^{-1}B\)也就是对\((A\mid B)\)直接进行高斯消元就可以得到\((I\mid A^{-1}B)\)。
再附一个丑陋的代码参考一下吧。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const long long MOD=998244353;
long long n,m,r,tmp[410][410],a[64000010],h[410];
long long fm(long long bot,long long pow)
{
long long s=1;
for(;pow;pow>>=1,bot=bot*bot%MOD)
if(pow&1)
s=s*bot%MOD;
return s;
}
void read(long long &x)
{
char c=getchar();
long long f=1;
x=0;
while(c<'0'||'9'<c)
f=c=='-'?-1:1,c=getchar();
while('0'<=c&&c<='9')
x=(x<<1)+(x<<3)+(c&15),c=getchar();
x*=f;
}
void print(long long x)
{
if(x>=10)
print(x/10);
putchar(x%10+'0');
}
int main()
{
read(n),read(m);
for(long long i=1;i<=n;++i)
for(long long j=1;j<=m;++j)
read(tmp[i][j]),tmp[i][j]=(MOD+tmp[i][j])%MOD;
read(n),read(r);
for(long long i=1,k=1;i<=n;++i)
{
h[i]=k-1;
for(long long j=1;j<=m;++j,++k)
a[k]=tmp[i][j];
for(long long j=1;j<=r;++j,++k)
read(a[k]),a[k]=(MOD+a[k])%MOD;
}
r+=m;
for(long long i=1;i<=n;++i)
{
for(long long j=i;j<=n;++j)
if(a[h[j]+i])
{
if(j!=i)
swap(h[j],h[i]);
break;
}
long long div=fm(a[h[i]+i],MOD-2);
for(long long j=1;j<=r;++j)
a[h[i]+j]=a[h[i]+j]*div%MOD;
for(long long j=1;j<=n;++j)
{
if(j==i)
continue;
div=a[h[j]+i];
for(long long k=1;k<=r;++k)
a[h[j]+k]=(MOD+a[h[j]+k]-a[h[i]+k]*div%MOD)%MOD;
}
}
for(long long i=1;i<=n;++i)
{
for(long long j=m+1;j<=r;++j)
{
if(a[h[i]+j]<0)
putchar('-'),a[h[i]+j]=-a[h[i]+j];
print(a[h[i]+j]),putchar(' ');
}
putchar('\n');
}
return 0;
}