[CF438E] The Child and Binary Tree

link

$solution:$

考虑朴素 $dp$,设 $f_i$ 为权值为 $i$ 的二叉树个数,$C_i$ 表示 $i$ 出现的次数。

则 $$f_i=\sum_{j=0}^i C_j\sum_{k=0}^{i-j} f_k\times f_{i-j-k}$$ 

时间复杂度 $O(n^3)$ 。

而简单思考发现 $j+k+(i-j-k)=i$ ,其实整个 $dp$ 过程是三个式子的卷积。

考虑将$f,c$ 写成生成函数。

则 $$F=1+C\times F^2$$ $+1$ 是因为 $f_0=1$ 。

则 $F=\dfrac{2}{1+\sqrt{1-4F}}$ 。

直接多项式操作即可,时间复杂度 $O(n\log n)$ 。

#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#define int long long
#define mod 998244353
using namespace std;
inline int read(){
    int f=1,ans=0;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();}
    return f*ans;
}
const int MAXN=2000001;
const int Inv2=499122177;
namespace Poly{
    int flip[MAXN];
    inline int ksm(int a,int b){
        int ans=1;
        while(b){
            if(b&1) ans*=a,ans%=mod;
            a*=a,a%=mod;
            b>>=1;
        }return ans;
    }
    inline int Get(int tmp){
        int lim=1;while(lim<=tmp) lim<<=1;
        for(int i=0;i<lim;i++) flip[i]=((flip[i>>1]>>1)|(i&1?lim>>1:0));
        return lim;
    }
    inline void NTT(int *f,int lim,int opt){
        for(int i=0;i<lim;i++) if(i<flip[i]) swap(f[i],f[flip[i]]);
        for(register int p=2;p<=lim;p<<=1){
            int len=p>>1,buf=ksm(3,(mod-1)/p);
            if(opt==-1) buf=ksm(buf,mod-2);
            for(register int be=0;be<lim;be+=p){
                int tmp=1;
                for(register int l=be;l<be+len;l++){
                    int t=f[l+len]*tmp;t%=mod;
                    f[l+len]=(f[l]-t+mod)%mod,f[l]=(f[l]+t)%mod;
                    tmp*=buf,tmp%=mod;
                }
            }
        }if(opt==-1){
            int Inv=ksm(lim,mod-2);
            for(int i=0;i<lim;i++) f[i]*=Inv,f[i]%=mod;
        }return;
    }
    int f[MAXN],g[MAXN];
    inline void Getinv(int lim,int *a,int *b){
        if(lim==1){memset(b,0,sizeof(b));b[0]=ksm(a[0],mod-2);return;}
        Getinv(lim>>1,a,b);
        int len=Get(lim);
        for(register int i=0;i<len;i++) f[i]=a[i],g[i]=b[i];for(int i=lim;i<=len;i++) f[i]=g[i]=0;
        NTT(f,len,1),NTT(g,len,1);
        for(register int i=0;i<len;i++) g[i]=(g[i]*(((2-f[i]*g[i])%mod)+mod))%mod;
        NTT(g,len,-1);
        for(register int i=0;i<lim;i++) b[i]=g[i];
        return;
    }
    int F[MAXN],G[MAXN],R[MAXN];
    inline void Getsqrt(int lim,int *a,int *b){
        if(lim==1){b[0]=1;return;}
        Getsqrt(lim>>1,a,b);
        int len=Get(lim);
        for(register int i=0;i<lim;i++) F[i]=a[i],F[i]%=mod;for(int i=lim;i<len;i++) F[i]=0;
        Getinv(len,b,R);
        len=Get(lim);for(register int i=lim;i<len;i++) R[i]=0;
        NTT(F,len,1),NTT(R,len,1);
        for(int i=0;i<len;i++) F[i]*=R[i],F[i]%=mod,R[i]=0;
        NTT(F,len,-1);
        for(int i=0;i<lim;i++) b[i]=((F[i]+b[i])*Inv2)%mod;
        return;
    }
}
int n,m,N,M,f[MAXN],g[MAXN],num[MAXN];
signed main(){
    n=read(),m=read();
    for(register int i=1;i<=n;i++) f[read()]++;
    int len=Poly::Get(m);
    for(register int i=1;i<len;i++) f[i]=(((-4*f[i])%mod)+mod)%mod; f[0]++;
    Poly::Getsqrt(len,f,g);
    g[0]+=1;
    for(register int i=0;i<len;i++) f[i]=g[i];memset(g,0,sizeof(g));
    Poly::Getinv(len,f,g);
    for(register int i=0;i<len;i++) g[i]<<=1,g[i]%=mod;
    for(register int i=1;i<=m;i++) printf("%lld\n",g[i]);
    return 0;
}
View Code

 

posted @ 2019-08-10 19:19  siruiyang_sry  阅读(203)  评论(1编辑  收藏  举报