// 侧边栏目录 // https://blog-static.cnblogs.com/files/douzujun/marvin.nav.my1502.css //目录导航 //生成目录索引列表

P10102 [GDKOI2023 提高组] 矩阵

题目描述

传送门

多次给定三个 \(n \times n\) 的矩阵 \(A, B, C\),你需要判断 \(A \times B\) 在模 \(998244353\) 意义下是否等于 \(C\)
其中 \(×\) 为矩阵乘法,\(C_{i,j} = \sum_{k=1}^{n}A_{i,k}B_{k,j}\)

本题读入量较大,建议使用快速读入。

输入格式

\(1\) 行输入一个正整数 \(T\),表示数据组数。

接下来包含 \(T\) 组数据,每组数据第一行为一个正整数 \(n\),表示矩阵大小。

接下来 \(n\) 行,每行 \(n\) 个整数,表示矩阵 \(A\)

接下来 \(n\) 行,每行 \(n\) 个整数,表示矩阵 \(B\)

接下来 \(n\) 行,每行 \(n\) 个整数,表示矩阵 \(C\)

输出格式

输出 \(T\) 行 Yes 或 No,表示 \(A \times B\) 在模 \(998244353\) 意义下是否等于 \(C\)

输入输出样例 #1

输入 #1

3
1
2
3
6
2
1 2
3 4
5 6
7 8
19 22
43 51
2
1111111 2222222
3333333 4444444
5555555 6666666
7777777 8888888
39625305 256038638
772687616 944903942

输出 #1

Yes
No
Yes

说明/提示

对于 20% 的数据,满足 \(\sum n ≤ 300\)

对于另外 20% 的数据,满足 \(A_{i,j} \ne 0\) 的位置不超过 \(n\) 个。

对于 100% 的数据,满足 \(1 ≤ T, n ≤ 3000,\sum n ≤ 3000, 0 ≤ A_{i,j} , B_{i,j} , C_{i,j} < 998244353\)

基本思路

参考题解

关于矩阵乘法如何计算,这里不如引用蓝书的说法吧:
\(A\)\(n*m\) 的矩阵, \(B\)\(m*p\) 的矩阵,则 \(C=A*B\)\(n*p\) 的矩阵(我理解为取 \(A\) 的列,取 \(B\) 的行),并且∀\(i\)\([1,n]\) , ∀\(j\)\([1,p]\).

那么直接进入正题好了,如果我们直接强行计算 \(A*B\) 的话,程序复杂度会变成 \(O(n^3)\) ,显然是不能通过的。那么我们采取随机化算法,用随机数生成一个 \(n*1\) 的矩阵 \(R\) ,利用性质 \(A*B*R=C*R\) ,进行验证,这样复杂度就能被降到 \(O(n^2)\) ,而且不放心的话可以多验证几次。不过要注意的是需要利用交换律先计算 \(B*R\) ,而不是先计算 \(A*B\) ,不然依旧会退化到 \(O(n^3)\)

代码

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353;
const int N=3e3+10;
ll n;
void mul(ll a[N][N],ll b[N],ll c[N]){
	for(int i=1;i<=n;i++){
		c[i]=0;
		for(int j=1;j<=n;j++){
			c[i]+=a[i][j]*b[j];c[i]%=mod;
		}
	}
}
ll T,a[N][N],b[N][N],c[N][N],r[N],rec[N],ret[N],cnt[N],res[N];
bool f;
int main(){
//	freopen("transform.in","r",stdin);
//	freopen("transform.out","w",stdout);
	ios::sync_with_stdio(false);
	cin>>T;
	while(T--){
		cin>>n;
		for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
			cin>>a[i][j];
		for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
			cin>>b[i][j];
		for(int i=1;i<=n;i++)
		for(int j=1;j<=n;j++)
			cin>>c[i][j];
		f=true;
		for(int ti=1;ti<=4;ti++){
			if(!f) break;
			for(int i=1;i<=n;i++)
				r[i]=rand();
			mul(b,r,ret);
			mul(a,ret,rec);
			mul(c,r,cnt);
			for(int i=1;i<=n;i++)if(rec[i]!=cnt[i]){
				f=false;
				break;
			}
		}
		if(f) cout<<"Yes"<<endl;
		else cout<<"No"<<endl;
	}
	return 0;
}

这里再附赠一个容易爆零的点,请观察:

void mul(ll a[N][N],ll b[N],ll c[N]){
	memset(c,0,sizeof(c));
	for(int i=1;i<=n;i++){
		for(int j=1;j<=n;j++){
			c[i]+=a[i][j]*b[j];c[i]%=mod;
		}
	}
}

有什么问题吗?问题出在 \(memset\) 上,这个函数传数组默认都是传指针的,所以 \(sizeof(c)\) 得到的是指针的长度,也就是 \(8\)

posted @ 2025-08-07 17:57  SSL_LMZ  阅读(40)  评论(0)    收藏  举报