CF1603E

令 $M=\max{a_i},m=\min{a_i},S=\sum a_i$ ,则 $M \cdot m \geq S$ 。

又因为 $S \geq m \cdot n$ ,所以 $M \geq n$ ,当且仅当所有 $a_i$ 相同时取等号。

又因为 $a_i \leq n+1$ ,所以 $M=n+1,(n+1)m \geq S$ 。

变形得 $\sum a_i-m \leq m$ 。

可以假设序列单调不降,然后考虑每一个数出现次数,设 $i$ 在序列中出现了 $num_i$ 次。

则这种情况下方案数为:

$$\frac{n!}{\prod num_i!}$$

于是令 $dp_{i,j,k}$ 表示最小的 $j$ 个数中,其中最大数减去最小数为 $i$ ,所有数减去最小数后和为 $k$ 的方案数,枚举最大数的个数转移。

$\prod num_i!$ 在转移过程中顺带计算。

因为 $(n-j) \cdot i \leq n$ 所以时间复杂度不超过 $O(n^3logn)$ 。

但是这样会把不可行的答案包含进来。

原题要求对于全部子序列都有 $M \cdot m \geq S$ ,但刚才的算法只要求整个序列满足条件。

仍然假设序列单调不降,于是只要对于任何前缀都有 $M \cdot m \geq S$ 即可,因为加入一些小于原有最大值的数只会让限制更强。

每一个前缀序列可以用一系列dp状态表示,于是考虑dp状态 ${i,j,k}$ 。

在此状态中, $M=m+i,S=jm+k$ ,于是 $m(m+i) \geq jm+k$ ,即 $m+i \geq j+\frac{k}{m}$ 。

因为 $k=\sum a_i-m \leq m$ 所以 $\frac{k}{m} \in (0,1]$ (将$k=0$即所有数相等特判掉之后)。

$m,i,j$ 是整数,于是 $j-i+1 \leq m$ 。枚举 $m$ 后再dp可以保证答案正确,但是时间复杂度过高。

观察发现最小值 $m$ 不可以太小,事实上由于对于最小的 $x$ 个数有 $M \geq x$ ,于是所有数之和不小于 $nm+\frac{(n-m)(n-m+1)}{2}$ 。

又因为 $\sum a_i-m \leq m$ 所以 $m \geq n-O(\sqrt{n})$ 。

可以枚举后dp,时间复杂度 $O(n^{3.5})$ 。

代码:

#include<cstdio>
#include<algorithm>
using namespace std;
#define ll long long int

int n,mod;
int fac[202],inv[202];
inline int inc(int a,int b)
{
    return (a+b)<mod?a+b:a+b-mod;
}
inline int mul(int a,int b)
{
    return (int)((ll)a*b%mod);
}
inline int pow(int a,int b)
{
    int res=1;
    while(b>0){
        if(b&1) res=mul(res,a);
        a=mul(a,a); b>>=1;
    }
    return res;
}

#include<cstring>
int dp[202][202][202];
int solve(int m)
{
    memset(dp,0,sizeof(dp));
    for(int i=1;i<=n;i++)
        dp[0][i][0]=inv[i];
    for(int i=1;i<=n+1;i++)
        for(int j=n-n/i;j<=n&&j<=m+i-1;j++)
            for(int k=0;k<=n;k++)
                for(int p=0;p<=j&&p<=k/i;p++)
                    dp[i][j][k]=inc(dp[i][j][k],mul(dp[i-1][j-p][k-i*p],inv[p]));
    int res=0;
    if(m>=n) res=inv[n];
    for(int k=1;k<=m;k++)
        res=inc(res,dp[n+1-m][n][k]);
    return res;
}

int main(void)
{
    scanf("%d%d",&n,&mod);
    fac[0]=inv[0]=1;
    for(int i=1;i<=n;i++){
        fac[i]=mul(fac[i-1],i);
        inv[i]=pow(fac[i],mod-2);
    }
    int ans=0;
    for(int m=n-17;m<=n+1;m++)
        ans=inc(ans,solve(m));
    printf("%d\n",mul(ans,fac[n]));
    return 0;
}
View Code

 

posted @ 2021-10-31 09:11  Miracle_Creater  阅读(144)  评论(0)    收藏  举报