BZOJ 4555: [Tjoi2016&Heoi2016]求和 [FFT 组合计数 容斥原理]

4555: [Tjoi2016&Heoi2016]求和

题意:求\[ \sum_{i=0}^n \sum_{j=0}^i S(i,j)\cdot 2^j\cdot j! \\ S是第二类斯特林数 \]


首先你要把这个组合计数肝出来,于是我去翻了一波《组合数学》

分治fft做法见上一篇,本篇是容斥原理+fft做法


组合计数

斯特林数 \(S(n,i)\)表示将n个不同元素划分成i个相同集合非空的方案数


考虑集合不相同情况\(S'(n,i)=S(n,i)*i!\),我们用容斥原理推♂倒她
\[ 每个集合非空的限制太强了,我们弱化它,可以有\ge k个空集合 \\ ans = \ge 0个空集合 - \ge 1个空集合 + \ge 2 个空集合 \\ S'(n,i) = \sum_{k=0}^{i} (-1)^k \binom{i}{k} (i-k)^n \\ \]
最后的\((i-k)^n\)含义是n个元素每个可以放入任意一个集合中



然后把这个式子带进去化啊化,具体过程WerKeyTom_FTD大爷已经写过了
注意有一步把第一个带着i的求和移到最后,是一个等比数列求和
最后得到的是
\[ ans=\sum_{j=0}^nj!*2^j*\sum_{k=0}^j\frac{(-1)^k}{k!}*\frac{\sum_{i=0}^n(j-k)^i}{(j-k)!} \]
后面是卷积的形式,一遍ntt就行了

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
using namespace std;
typedef long long ll;
const int N=(1<<18)+5, INF=1e9;
const ll P=998244353, g=3;
inline int read(){
    char c=getchar();int x=0,f=1;
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){x=x*10+c-'0';c=getchar();}
    return x*f;
}

ll Pow(ll a, ll b) {
    ll ans=1;
    for(; b; b>>=1, a=a*a%P)
        if(b&1) ans=ans*a%P;
    return ans;
}

namespace ntt{
    int n, rev[N];
    void ini(int lim) {
        n=1; int k=0;
        while(n<lim) n<<=1, k++;
        for(int i=0; i<n; i++) rev[i] = (rev[i>>1]>>1) | ((i&1)<<(k-1));
    }
    void dft(ll *a, int flag) {
        for(int i=0; i<n; i++) if(i<rev[i]) swap(a[i], a[rev[i]]);
        for(int l=2; l<=n; l<<=1) {
            int m=l>>1;
            ll wn = Pow(g, flag==1 ? (P-1)/l : P-1-(P-1)/l);
            for(ll *p=a; p!=a+n; p+=l) {
                ll w=1;
                for(int k=0; k<m; k++) {
                    ll t = w * p[k+m]%P;
                    p[k+m]=(p[k]-t+P)%P;
                    p[k]=(p[k]+t)%P;
                    w=w*wn%P;
                }
            }
        }
        if(flag==-1) {
            ll inv=Pow(n, P-2);
            for(int i=0; i<n; i++) a[i]=a[i]*inv%P;
        }
    }
    void mul(ll *a, ll *b) {
        dft(a, 1); dft(b, 1);
        for(int i=0; i<n; i++) a[i]=a[i]*b[i];
        dft(a, -1);
    }
}using ntt::ini; using ntt::mul;

int n, rev[N];
ll inv[N], fac[N], facInv[N];
ll f[N], a[N], b[N];

int main() {
    freopen("in","r",stdin);
    n=read(); 
    inv[1]=1; fac[0]=facInv[0]=1;
    for(int i=1; i<=n; i++) {
        if(i!=1) inv[i] = (P-P/i)*inv[P%i]%P;
        fac[i] = fac[i-1]*i%P;
        facInv[i] = facInv[i-1]*inv[i]%P;
    }
    a[0]=1; b[0]=1; b[1]=n+1;
    for(int i=1; i<=n; i++) a[i] = (i&1 ? -1 : 1) * facInv[i];
    for(int i=2; i<=n; i++) b[i] = (Pow(i, n+1)-1) * inv[i-1] %P * facInv[i] %P;
    ini(n+n+1); mul(a, b);
    ll ans=0;
    for(int i=0; i<=n; i++) ( ans += Pow(2, i)*fac[i]%P * a[i]%P )%=P;
    if(ans<0) ans+=P;
    printf("%lld", ans);
}
posted @ 2017-03-30 21:49 Candy? 阅读(...) 评论(...) 编辑 收藏