P10102 [GDKOI2023 提高组] 矩阵
题目描述
多次给定三个 \(n \times n\) 的矩阵 \(A, B, C\),你需要判断 \(A \times B\) 在模 \(998244353\) 意义下是否等于 \(C\)。
其中 \(×\) 为矩阵乘法,\(C_{i,j} = \sum_{k=1}^{n}A_{i,k}B_{k,j}\)。
本题读入量较大,建议使用快速读入。
输入格式
第 \(1\) 行输入一个正整数 \(T\),表示数据组数。
接下来包含 \(T\) 组数据,每组数据第一行为一个正整数 \(n\),表示矩阵大小。
接下来 \(n\) 行,每行 \(n\) 个整数,表示矩阵 \(A\)。
接下来 \(n\) 行,每行 \(n\) 个整数,表示矩阵 \(B\)。
接下来 \(n\) 行,每行 \(n\) 个整数,表示矩阵 \(C\)。
输出格式
输出 \(T\) 行 Yes 或 No,表示 \(A \times B\) 在模 \(998244353\) 意义下是否等于 \(C\)。
输入输出样例 #1
输入 #1
3
1
2
3
6
2
1 2
3 4
5 6
7 8
19 22
43 51
2
1111111 2222222
3333333 4444444
5555555 6666666
7777777 8888888
39625305 256038638
772687616 944903942
输出 #1
Yes
No
Yes
说明/提示
对于 20% 的数据,满足 \(\sum n ≤ 300\)。
对于另外 20% 的数据,满足 \(A_{i,j} \ne 0\) 的位置不超过 \(n\) 个。
对于 100% 的数据,满足 \(1 ≤ T, n ≤ 3000,\sum n ≤ 3000, 0 ≤ A_{i,j} , B_{i,j} , C_{i,j} < 998244353\)。
基本思路
关于矩阵乘法如何计算,这里不如引用蓝书的说法吧:
设 \(A\) 是 \(n*m\) 的矩阵, \(B\) 是 \(m*p\) 的矩阵,则 \(C=A*B\)是 \(n*p\) 的矩阵(我理解为取 \(A\) 的列,取 \(B\) 的行),并且∀\(i\) ∈ \([1,n]\) , ∀\(j\) ∈ \([1,p]\).
那么直接进入正题好了,如果我们直接强行计算 \(A*B\) 的话,程序复杂度会变成 \(O(n^3)\) ,显然是不能通过的。那么我们采取随机化算法,用随机数生成一个 \(n*1\) 的矩阵 \(R\) ,利用性质 \(A*B*R=C*R\) ,进行验证,这样复杂度就能被降到 \(O(n^2)\) ,而且不放心的话可以多验证几次。不过要注意的是需要利用交换律先计算 \(B*R\) ,而不是先计算 \(A*B\) ,不然依旧会退化到 \(O(n^3)\)。
代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const int N=3e3+10;
ll n;
void mul(ll a[N][N],ll b[N],ll c[N]){
for(int i=1;i<=n;i++){
c[i]=0;
for(int j=1;j<=n;j++){
c[i]+=a[i][j]*b[j];c[i]%=mod;
}
}
}
ll T,a[N][N],b[N][N],c[N][N],r[N],rec[N],ret[N],cnt[N],res[N];
bool f;
int main(){
// freopen("transform.in","r",stdin);
// freopen("transform.out","w",stdout);
ios::sync_with_stdio(false);
cin>>T;
while(T--){
cin>>n;
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
cin>>a[i][j];
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
cin>>b[i][j];
for(int i=1;i<=n;i++)
for(int j=1;j<=n;j++)
cin>>c[i][j];
f=true;
for(int ti=1;ti<=4;ti++){
if(!f) break;
for(int i=1;i<=n;i++)
r[i]=rand();
mul(b,r,ret);
mul(a,ret,rec);
mul(c,r,cnt);
for(int i=1;i<=n;i++)if(rec[i]!=cnt[i]){
f=false;
break;
}
}
if(f) cout<<"Yes"<<endl;
else cout<<"No"<<endl;
}
return 0;
}
这里再附赠一个容易爆零的点,请观察:
void mul(ll a[N][N],ll b[N],ll c[N]){
memset(c,0,sizeof(c));
for(int i=1;i<=n;i++){
for(int j=1;j<=n;j++){
c[i]+=a[i][j]*b[j];c[i]%=mod;
}
}
}
有什么问题吗?问题出在 \(memset\) 上,这个函数传数组默认都是传指针的,所以 \(sizeof(c)\) 得到的是指针的长度,也就是 \(8\)