[互测题目20200325]排列的和

题目

  设\(a,b\)分别为\(1\sim n\)的排列。
  求有多少个排列对\((a,b)\)满足\(\sum_{i=1}^n\max\{a_i,b_i\}\ge m\)
  两个排列对\((a,b)\)\((c,d)\)不同当且仅当存在一个\(i\),使得\(a_i\not=c_i\)或者\(b_i\not=d_i\)
  数据范围为\(n\le 50, 0\le m\le 10^9\)

分析

  首先发现\(\sum_{i=1}^n \max\{a_i,b_i\}\ge m\)有一个比较宽松的上界\(n^2\),因此我们只需要考虑\(m\le n^2\)的情况。
  动用一个套路——我们假设\(a_i=i\),那么我们就可以单纯枚举\(b\),找出这种情况下的方案数,再乘上\(n!\)即可。
  事实上我们需要做的就是计算有多少个排列(或者可以称为 " 置换 " )\(p\)满足\(\sum_{i=1}^n\max\{p_i,i\}\ge m\)。而计算这样的排列可以理解为下标和值的对应。
  因此可以考虑如下的 DP :
  \(f(i,j,k)\):前\(i\)个数中,分别有\(j\)个下标和值还没有对应上,已经对应的和为\(k\)的方案数。
  考虑转移,分 3 种情况:
  1.什么也不干,方案数为\(f(i-1,j-1,k)\)
  2.下标\(i\)与一个值配对,或者值\(i\)与一个下标配对。这样会有\(2j+1\)种情况(下标\(i\)与值\(i\)配对当然只算一次),因此方案数为\((2j+1)f(i-1,j,k-i)\)
  3.下标\(i\)与一个值配对,且值\(i\)与一个下标配对。注意到这样的话会一次减少一个未配对的下标和未配对的值,所以在进行配对前分别有\(j+1\)个下标和值未配对,因此情况为\((j+1)^2\),方案数为\((j+1)^2f(i-1,j+1,k-2i)\)
  最后统计\(k\)\([m,n^2]\)中的方案总数,并且不要忘了乘上\(n!\)

代码

#include <cstdio>

const int mod = 998244353;
const int MAXN = 55, MAXS = MAXN * MAXN;

template<typename _T>
void read( _T &x )
{
	x = 0;char s = getchar();int f = 1;
	while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
	while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
	x *= f;
}

template<typename _T>
void write( _T x )
{
	if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
	if( 9 < x ){ write( x / 10 ); }
	putchar( x % 10 + '0' );
}

template<typename _T>
_T MAX( const _T a, const _T b )
{
	return a > b ? a : b;
}

int f[MAXN][MAXN][MAXS];
int N, M;

int main()
{
	read( N ), read( M );
	f[0][0][0] = 1;
	for( int i = 1 ; i <= N ; i ++ )
		for( int j = 0 ; j <= i ; j ++ )
			for( int k = 0 ; k <= N * N ; k ++ )
			{
				if( j ) ( f[i][j][k] += f[i - 1][j - 1][k] ) %= mod;
				if( k >= i ) ( f[i][j][k] += 1ll * ( 2 * j + 1 ) % mod * f[i - 1][j][k - i] % mod ) %= mod;
				if( k >= 2 * i ) ( f[i][j][k] += 1ll * ( j + 1 ) * ( j + 1 ) % mod * 
					 								   f[i - 1][j + 1][k - 2 * i] % mod ) %= mod;
			}
	int ans = 0;
	for( int i = M ; i <= N * N ; i ++ ) ( ans += f[N][0][i] ) %= mod;
	for( int i = 2 ; i <= N ; i ++ ) ans = 1ll * ans * i % mod;
	write( ans ), putchar( '\n' );
	return 0;
}
posted @ 2020-03-27 19:22  crashed  阅读(82)  评论(0编辑  收藏  举报