AGC057E RowCol/ColRow Sort 【观察,组合计数】

考虑排序网络的 \(\texttt{01}\) 原理,合法当且仅当对每个 \(k\in[0,8]\),对 \([A_{i,j}\le k]\) 做操作都得到 \([B_{i,j}\le k]\)

现在就对 \(\texttt{01}\) 矩阵排序,注意到,考虑每行之和 \(r_i\),对行排序不改变,对列排序就是对 \(r_i\) 降序排序,每列之和 \(c_j\) 同理。且结果只与 \(r_i,c_j\) 有关,所以是充分的。

设排列 \(p_k\) 使得 \(r_{p_k(i)}\ge r_{p_k(i+1)}\)\(q_k\) 同理,则条件即为 \(A_{i,j}\le k\iff B_{p_k(i),q_k(j)}\le k\),那么这堆排列就唯一确定了 \(A_{i,j}\),当然这会算重,最后除以一堆阶乘即可。

不关注 \(A\),条件就是 \(B_{p_k(i),q_k(j)}\le k\implies B_{p_{k+1}(i),q_{k+1}(j)}\le k+1\),这只跟 \(p_{k+1}\circ p_k^{-1}\)\(q_{k+1}\circ q_k^{-1}\) 有关,不如看成 \(B_{i,j}\le k\implies B_{p_k(i),q_k(j)}\le k+1\)

考察 \([B_{i,j}\le k]\) 的杨图结构,设 \(a_i=\sum_j[B_{i,j}\le k]\)\(b_j=\sum_i[B_{i,j}\le k+1]\),条件即为 \(p_k(i)\le b_{\max\{q_k(1),\cdots,q_k(a_i)\}}\),不等式右侧递增,所以即为枚举 \(q_k\),计算 \(\prod_{i=1}^n(b_{\max\{q_k(1),\cdots,q_k(a_i)\}}-i+1)\) 之和,其中 \(b_0=n\)。从小到大枚举 \(a_i\),进行一个 dp 就可以了。

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 1503, mod = 998244353;
int ksm(int a, int b){
	int res = 1;
	for(;b;b >>= 1, a = (LL)a * a % mod)
		if(b & 1) res = (LL)res * a % mod;
	return res;
}
int fac[N], rfac[N];
void init(int m){
	*fac = 1;
	for(int i = 1;i <= m;++ i) fac[i] = (LL)fac[i - 1] * i % mod;
	rfac[m] = ksm(fac[m], mod - 2);
	for(int i = m;i;-- i) rfac[i - 1] = (LL)rfac[i] * i % mod;
}
int n, m, ans = 1, a[10][N], b[10][N], cnt[N], f[N][N];
void solve(int k){
	memset(f, 0, sizeof(f));
	int t = n;
	for(;t && !a[k][t];-- t) ans = ans * (n - t + 1ll) % mod;
	for(int i = 1;i <= m;++ i) f[1][i] = 1;
	for(;t && a[k][t] == 1;-- t)
		for(int i = 1;i <= m;++ i)
			f[1][i] = f[1][i] * max(b[k + 1][i] - t + 1ll, 0ll) % mod;
	for(int i = 2;i <= m;++ i){
		int tmp = f[i - 1][i - 1];
		for(int j = i;j <= m;++ j){
			tmp += f[i - 1][j]; if(tmp >= mod) tmp -= mod;
			f[i][j] = (tmp + LL(j - i) * f[i - 1][j]) % mod;
		}
		for(;t && a[k][t] == i;-- t)
			for(int j = i;j <= m;++ j)
				f[i][j] = f[i][j] * max(b[k + 1][j] - t + 1ll, 0ll) % mod;
	}
	ans = (LL)ans * f[m][m] % mod;
	memset(cnt, 0, (m + 1) << 2);
	for(int i = 1;i <= n;++ i) ++ cnt[a[k][i]];
	for(int i = 0;i <= m;++ i) ans = (LL)ans * rfac[cnt[i]] % mod;
	memset(cnt, 0, (n + 1) << 2);
	for(int i = 1;i <= m;++ i) ++ cnt[b[k][i]];
	for(int i = 0;i <= n;++ i) ans = (LL)ans * rfac[cnt[i]] % mod;
}
int main(){
	ios::sync_with_stdio(0);
	cin >> n >> m; init(N - 1);
	for(int i = 1;i <= n;++ i)
		for(int j = 1, x;j <= m;++ j){
			cin >> x; ++ a[x][i]; ++ b[x][j];
		}
	for(int i = 1;i <= 9;++ i){
		for(int j = 1;j <= n;++ j) a[i][j] += a[i - 1][j];
		for(int j = 1;j <= m;++ j) b[i][j] += b[i - 1][j];
	}
	for(int i = 0;i <= 8;++ i) solve(i);
	printf("%d\n", ans);
}
posted @ 2022-06-22 20:21  mizu164  阅读(179)  评论(0编辑  收藏  举报