题目链接

https://atcoder.jp/contests/agc039/tasks/agc039_f

题解

又是很简单的F题我不会。。。
考虑先给每行每列钦定一个最小值\(a_i,b_j\),并假设每行每列的最小值是这个数,且每行每列只需要放\(\ge\)这个数的数即可,那么这种情况的价值是\(\prod^n_{i=1}\prod^m_{j=1}\min(a_i,b_j)\), 方案数是\(\prod^n_{i=1}\prod^m_{j=1}(n+1-\max(a_i,b_j))\)
然后我们需要把最小值的限制容斥掉,也就是枚举若干行若干列容斥掉(限制\(+1\)同时系数乘以\(-1\))。
这样的话直接暴力DP就可以解决。设\(f[k][i][j]\)表示当前用\([1,k]\)中的数填满了\(i\)\(j\)列。转移可以直接枚举不被容斥的行数、不被容斥的列数、容斥的行数、容斥的列数,乘上贡献系数,得到了一个多项式时间复杂度的算法。
但是我们发现这样转移显然很浪费,我们可以把四个变量同时枚举改成分四个阶段依次枚举,这样转移时间复杂度降到了\(O(n)\).(注意因为要保证从小到大填数,所以必须先枚举不被容斥再枚举被容斥)
不过这题还挺卡常的……需要\(O(n^3)\)预处理一下转移系数,详见代码
时间复杂度\(O(n^4)\)
orz myh

代码

#include<bits/stdc++.h>
#define llong long long
#define mkpr make_pair
#define riterator reverse_iterator
using namespace std;

inline int read()
{
	int x = 0,f = 1; char ch = getchar();
	for(;!isdigit(ch);ch=getchar()) {if(ch=='-') f = -1;}
	for(; isdigit(ch);ch=getchar()) {x = x*10+ch-48;}
	return x*f;
}

const int N = 100;
int P;
llong pw[N+3][N*N+3];
llong comb[N+3][N+3];
llong f[2][N+3][N+3];
llong trans[N+3][N+3];
int n,m,p;

llong quickpow(llong x,llong y)
{
	llong cur = x,ret = 1ll;
	for(int i=0; y; i++)
	{
		if(y&(1ll<<i)) {y-=(1ll<<i); ret = ret*cur%P;}
		cur = cur*cur%P;
	}
	return ret;
}

void initmath()
{
	for(int i=0; i<=N; i++)
	{
		pw[i][0] = 1ll; for(int j=1; j<=N*N; j++) pw[i][j] = pw[i][j-1]*i%P;
	}
	comb[0][0] = 1ll;
	for(int i=1; i<=N; i++)
	{
		comb[i][0] = comb[i][i] = 1ll;
		for(int j=1; j<i; j++) comb[i][j] = (comb[i-1][j]+comb[i-1][j-1])%P;
	}
}

llong updsum(llong &x,llong y) {x = x+y>=P?x+y-P:x+y;}

int main()
{
	scanf("%d%d%d%lld",&n,&m,&p,&P);
	initmath();
	int curk = 0; f[0][0][0] = 1ll;
	for(int k=1; k<=p; k++)
	{
		curk^=1; memset(f[curk],0,sizeof(f[curk]));
		for(int j=0; j<=m; j++) for(int ii=0; ii<=n; ii++) trans[j][ii] = pw[k][ii*(m-j)]%P*pw[p-k+1][ii*j]%P;
		for(int i=0; i<=n; i++)
		{
			for(int j=0; j<=m; j++)
			{
				llong x = f[curk^1][i][j]; if(!x) continue;
				for(int ii=0; ii+i<=n; ii++)
				{
					updsum(f[curk][i+ii][j],x*comb[i+ii][i]%P*trans[j][ii]%P);
				}
			}
		}
		curk^=1; memset(f[curk],0,sizeof(f[curk]));
		for(int i=0; i<=n; i++) for(int jj=0; jj<=m; jj++) trans[i][jj] = pw[k][jj*(n-i)]%P*pw[p-k+1][jj*i]%P;
		for(int i=0; i<=n; i++)
		{
			for(int j=0; j<=m; j++)
			{
				llong x = f[curk^1][i][j]; if(!x) continue;
				for(int jj=0; jj+j<=m; jj++)
				{
					updsum(f[curk][i][j+jj],x*comb[j+jj][j]%P*trans[i][jj]%P);
				}
			}
		}
		curk^=1; memset(f[curk],0,sizeof(f[curk]));
		for(int j=0; j<=m; j++) for(int ii=0; ii<=n; ii++) trans[j][ii] = pw[k][ii*(m-j)]%P*pw[p-k][ii*j]%P;
		for(int i=0; i<=n; i++)
		{
			for(int j=0; j<=m; j++)
			{
				llong x = f[curk^1][i][j]; if(!x) continue;
				for(int ii=0; ii+i<=n; ii++)
				{
					llong y = x*comb[i+ii][i]%P*trans[j][ii]%P;
					updsum(f[curk][i+ii][j],ii&1?P-y:y);
				}
			}
		}
		curk^=1; memset(f[curk],0,sizeof(f[curk]));
		for(int i=0; i<=n; i++) for(int jj=0; jj<=m; jj++) trans[i][jj] = pw[k][jj*(n-i)]%P*pw[p-k][i*jj]%P;
		for(int i=0; i<=n; i++)
		{
			for(int j=0; j<=m; j++)
			{
				llong x = f[curk^1][i][j]; if(!x) continue;
				for(int jj=0; jj+j<=m; jj++)
				{
					llong y = x*comb[j+jj][j]%P*trans[i][jj]%P;
					updsum(f[curk][i][j+jj],jj&1?P-y:y);
				}
			}
		}
	}
	printf("%lld\n",f[curk][n][m]);
	return 0;
}