[集训队互测]calc

题目

点这里看题目。

分析

首先不难想到可以枚举递增的序列,最后在答案里面乘上\(n!\),于是有\(O(nk)\)的暴力 DP 一枚:

\(f(i,j)\)表示长度为\(i\)、最大值\(\le j\)的序列的贡献和。

转移显然:

\[f(i,j)=j\times f(i-1,j-1)+f(i,j-1) \]

那么可以发现,当序列长度固定的时候,\(f(n,x)\)肯定是关于\(x\)的函数。环顾四周,DP 转移方程中并不存在除法、开方、作为指数乘方等运算,所以可以推测\(f(n,x)\)就是\(x\)的多项式函数。

那么,它的次数是多少呢?这直接决定了我们如何进行插值。设\(f(n,x)\)的次数为\(g(n)\),考虑到转移左右两边的次数应该是相等的,就有:

\[f(i,j)-f(i,j-1)=j\times f(i-1,j-1)\Rightarrow g(n)-1=g(n-1)+1 \]

补充一下,多项式函数做差分,即\(f(x)-f(x-1)\),得到的结果的次数会比原多项式的小一,可以直接用二项式定理展开证明。

然后发现,\(g(n)=g(n-1)+2\),由于\(f(0,x)=1\),所以\(g(0)=0\),得到通项公式\(g(n)=2n\)

然后我们就知道了\(f(n,x)\)是关于\(x\)\(2n\)次的多项式函数,因此,我们需要算出\(2n+1\)个点值,用于插值。总时间\(O(n^2)\)

//f(i,j)=f(i-1,j-1)*j+f(i,j-1)
#include <cstdio>

const int MAXN = 1005;

template<typename _T>
void read( _T &x )
{
	x = 0;char s = getchar();int f = 1;
	while( s > '9' || s < '0' ){if( s == '-' ) f = -1; s = getchar();}
	while( s >= '0' && s <= '9' ){x = ( x << 3 ) + ( x << 1 ) + ( s - '0' ), s = getchar();}
	x *= f;
}

template<typename _T>
void write( _T x )
{
	if( x < 0 ){ putchar( '-' ); x = ( ~ x ) + 1; }
	if( 9 < x ){ write( x / 10 ); }
	putchar( x % 10 + '0' );
}

int f[MAXN][MAXN], y[MAXN];
int N, K, M, mod;

int qkpow( int base, int indx )
{
	int ret = 1;
	while( indx )
	{
		if( indx & 1 ) ret = 1ll * ret * base % mod;
		base = 1ll * base * base % mod, indx >>= 1;
	}
	return ret;
}

int inver( const int a ) { return qkpow( a, mod - 2 ); }
void add( int &x, const int v ) { x = ( x + v >= mod ? x + v - mod : x + v ); }

int Lagrange()
{
	if( K <= M ) return y[K];
	int ans = 0, tmp;
	for( int i = 1 ; i <= M ; i ++ )
	{
		tmp = 1;
		for( int j = 1 ; j <= M ; j ++ )
			if( i != j )
				tmp = 1ll * tmp * ( K - j ) % mod * inver( i - j + mod ) % mod;
		add( ans, 1ll * tmp * y[i] % mod );
	}
	return ans;
}

int main()
{
	read( K ), read( N ), read( mod );
	M = 2 * N + 1;
	for( int j = 0 ; j <= M ; j ++ )
		f[0][j] = 1; 
	for( int i = 1 ; i <= N ; i ++ )
		for( int j = 1 ; j <= M ; j ++ )
			f[i][j] = ( 1ll * f[i - 1][j - 1] * j % mod + f[i][j - 1] ) % mod;
	for( int i = 0 ; i <= M ; i ++ ) y[i] = f[N][i];
	int fac = 1; 
	for( int i = 1 ; i <= N ; i ++ ) fac = 1ll * fac * i % mod;
	write( 1ll * fac * Lagrange() % mod ), putchar( '\n' );
	return 0;
}
posted @ 2020-06-14 22:34  crashed  阅读(233)  评论(0编辑  收藏  举报