bzoj4555

ntt+cdq分治

原来zwh出的cf是斯特林

第二类斯特林数的定义是S(i,j)表示将i个物品分到j个无序集合的方案数,那么这道题中S(i,j)*j!*2^j是指将i个物品分到j个有序集合中并且每个集合可以选或不选的方案数,那么我们改变这个公式,得出

F[i]=∑F[j]*2*C(i,j),j=0-n,意思是第一个集合选n-j个的方案数,那么这个集合有两种情况选或不选,乘上2,再乘上选出元素的方案数。然后展开组合数,得出F[i]=∑F[j]*2*i!/(i-j)!/j!,移项得出F[i]/i!=∑F[j]/j!*2/(i-j)!

设新的函数G[i]=F[i]/i!,那么G[i]=∑G[j]*2/(i-j)!,后面是卷积形式,用ntt优化,又因为两边都有G,所以我们用cdq分治求和,复杂度nlog^2n

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = (1 << 18) + 5, mod = 998244353;
int n;
ll ans;
int rev[N];
ll a[N], b[N], fac[N], facinv[N], inv[N], f[N];
ll power(ll x, ll t)
{
    ll ret = 1;
    for(; t; t >>= 1, x = x * x % mod) if(t & 1) ret = ret * x % mod;
    return ret;
}
void ntt(ll *a, int n, int f) 
{
    for(int i = 0; i < n; ++i) if(i < rev[i]) swap(a[i], a[rev[i]]);
    for(int m = 2; m <= n; m <<= 1) 
    {
        int mid = (m >> 1);
        ll wn = power(3, f == 1 ? (mod - 1) / m : mod - 1 - (mod - 1) / m);
        for(int i = 0; i < n; i += m) 
        {
            ll w = 1;
            for(int j = 0; j < mid; ++j) 
            {
                ll u = a[i + j], v = a[i + j + mid] * w % mod;
                a[i + j] = (u + v) % mod;
                a[i + j + mid] = (u - v + mod) % mod;
                w = w * wn % mod;
            }
        }       
    }   
    if(f == -1) 
    {
        ll inv = power(n, mod - 2);
        for(int i = 0; i < n; ++i) a[i] = a[i] * inv % mod;
    }
}
void cdq(int l, int r)
{
    if(l == r) return;
    int mid = (l + r) >> 1;
    cdq(l, mid);
    int lim = r - l + 1, n = 1, k = 0;
    while(n < lim) 
    {
        n <<= 1;
        ++k;
    }
    for(int i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (k - 1)); 
    for(int i = 0; i < n; ++i) a[i] = b[i] = 0;
    for(int i = l; i <= mid; ++i) a[i - l] = f[i];
    for(int i = 0; i < lim; ++i) b[i] = facinv[i];
    ntt(a, n, 1);
    ntt(b, n, 1);
    for(int i = 0; i < n; ++i) a[i] = a[i] * b[i] % mod;
    ntt(a, n, -1);
    for(int i = mid + 1; i <= r; ++i) f[i] = (f[i] + 2 * a[i - l]) % mod;
    cdq(mid + 1, r);     
}
int main()
{
    scanf("%d", &n);
    fac[0] = inv[1] = facinv[0] = 1;
    for(int i = 1; i <= n; ++i) 
    {
        fac[i] = fac[i - 1] * i % mod;
        if(i != 1) inv[i] = (mod - mod / i) * inv[mod % i] % mod;
        facinv[i] = facinv[i - 1] * inv[i] % mod;
    }   
    f[0] = 1;
    cdq(0, n);  
    for(int i = 0; i <= n; ++i) ans = (ans + f[i] * fac[i] % mod) % mod;
    printf("%lld\n", ans);
    return 0;
}
View Code

 

posted @ 2017-11-15 14:52  19992147  阅读(139)  评论(0编辑  收藏  举报