BZOJ 3625: [Codeforces Round #250]小朋友和二叉树

3625: [Codeforces Round #250]小朋友和二叉树

Time Limit: 40 Sec  Memory Limit: 256 MB
Submit: 343  Solved: 143
[Submit][Status][Discuss]

Description

我们的小朋友很喜欢计算机科学,而且尤其喜欢二叉树。
考虑一个含有n个互异正整数的序列c[1],c[2],...,c[n]。如果一棵带点权的有根二叉树满足其所有顶点的权值都在集合{c[1],c[2],...,c[n]}中,我们的小朋友就会将其称作神犇的。并且他认为,一棵带点权的树的权值,是其所有顶点权值的总和。
给出一个整数m,你能对于任意的s(1<=s<=m)计算出权值为s的神犇二叉树的个数吗?请参照样例以更好的理解什么样的两棵二叉树会被视为不同的。
我们只需要知道答案关于998244353(7*17*2^23+1,一个质数)取模后的值。

Input

第一行有2个整数 n,m(1<=n<=10^5; 1<=m<=10^5)。
第二行有n个用空格隔开的互异的整数 c[1],c[2],...,c[n](1<=c[i]<=10^5)。

Output

输出m行,每行有一个整数。第i行应当含有权值恰为i的神犇二叉树的总数。请输出答案关于998244353(=7*17*2^23+1,一个质数)取模后的结果。

Sample Input

样例一:
2 3
1 2
样例二:
3 10
9 4 3
样例三:
5 10
13 10 6 4 15

Sample Output

样例一:
1
3
9
样例二:
0
0
1
1
0
2
4
2
6
15
样例三:
0
0
0
1
0
1
0
2
0
5

HINT

 

对于第一个样例,有9个权值恰好为3的神犇二叉树:

 

Source

分析:

定义$f(x)$为权值为$x$的二叉树个数,$g(x)$为$c$集合中是否存在一个权值为$x$的元素,也就是一个$01$序列...

然后考虑$f$和$g$的关系,$f=f^2g+1$,其中$g$是枚举根节点的权值,$f^2$分别枚举左右子树的权值,$+1$是加上一个空树的情况...

这样我们解出来$f=\frac {1± \sqrt {1-4g}}{2g}$,然后把$g$的生成函数带进去求解即可...

这样就需要多项式求逆和多项式开方...

多项式求逆:

我们考虑倍增的思想,我们现在已经知道了$A(x)B(x)=1 (mod  x^m)$,求$C(x)$满足$A(x)C(x)=1 (mod  x^{2m})$...

可以得出$C(x)=B(x)(2-A(x)B(x))$...

多项式开方:

依旧是倍增的思想,我们现在已经知道了$B(x)B(x)=A(x) (mod  x^m)$,求$C(x)$满足$C(x)C(x)=A(x) (mod  x^{2m})$...

可以得出$C(x)=\frac {B^2(x)+A(x)}{2B(x)}$...

代码:

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
//by NeighThorn
using namespace std;

const int maxn=500000+5,mod=998244353,M=499122177,G=5;

int n,L,num,R[maxn],a[maxn],b[maxn],c[maxn],d[maxn];

inline int power(int x,int y){
    int res=1;
    while(y){
	    if(y&1)
	        res=1LL*res*x%mod;
	    x=1LL*x*x%mod,y>>=1;
    }
    return res;
}

inline void NTT(int *a,int f,int n,int L){
//	for(int i=0;i<n;i++) cout<<a[i]<<" ";puts("");
    for(int i=0;i<n;i++)
        R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
    for(int i=0;i<n;i++) 
        if(i<R[i]) swap(a[i],a[R[i]]);
    for(int i=1;i<n;i<<=1){
        int wn=power(G,(mod-1)/(i<<1));
        if(f==-1) wn=power(wn,mod-2);
        for(int j=0;j<n;j+=(i<<1)){
            int w=1;
            for(int k=0;k<i;k++,w=1LL*w*wn%mod){
                int x=a[j+k],y=1LL*w*a[j+k+i]%mod;
                a[j+k]=((x+y)%mod+mod)%mod;
                a[j+k+i]=((x-y)%mod+mod)%mod;
            }
        }
    }
    if(f==-1){
        int tmp=power(n,mod-2);
        for(int i=0;i<n;i++)
            a[i]=1LL*a[i]*tmp%mod;
    }
//    for(int i=0;i<n;i++) cout<<a[i]<<" ";puts("");puts("");
}

//b(2-a*b)

inline void inverse(int *a,int *b,int n,int L){
    if(n==1){
        b[0]=power(a[0],mod-2);return;
    }
    inverse(a,b,n>>1,L-1);
    memcpy(c,a,n*sizeof(int));
    memset(c+n,0,n*sizeof(int));
    /*cout<<"inverse: "<<endl;*/NTT(c,1,n<<1,L+1);NTT(b,1,n<<1,L+1);
    for(int i=0;i<n<<1;i++) b[i]=1LL*b[i]*((2-1LL*c[i]*b[i]%mod+mod)%mod)%mod;
    /*cout<<"inverse: "<<endl;*/NTT(b,-1,n<<1,L+1);
    memset(b+n,0,n*sizeof(int));
}

//(b^2+a)/(2*b)

inline void sqrt(int *a,int *b,int n,int L){
	if(n==1){
		b[0]=1;return;
	}
	sqrt(a,b,n>>1,L-1);
	memset(d,0,n*2*sizeof(int));
	inverse(b,d,n,L);
//	cout<<"d: "<<endl;for(int i=0;i<n;i++) cout<<d[i]<<" ";cout<<endl;
	memcpy(c,a,n*sizeof(int));
	memset(c+n,0,n*sizeof(int));
	/*cout<<"sqrt: "<<endl;*/NTT(c,1,n<<1,L+1),NTT(b,1,n<<1,L+1);NTT(d,1,n<<1,L+1);
	for(int i=0;i<n<<1;i++) b[i]=(1LL*c[i]*d[i]%mod+b[i])%mod*M%mod;
//	cout<<"b: "<<endl;for(int i=0;i<n<<1;i++) cout<<b[i]<<" ";cout<<endl;
	/*cout<<"sqrt: "<<endl;*/NTT(b,-1,n<<1,L+1);
	memset(b+n,0,n*sizeof(int));
}

signed main(void){
//	freopen("out.txt","w",stdout);
    scanf("%d%d",&n,&num);a[0]=1;
    for(int i=1,x;i<=n;i++){
        scanf("%d",&x);
        if(x<=num)
            a[x]=mod-4;
    }
    for(n=1;n<=num;n<<=1) L++;
    sqrt(a,b,n,L);
    memcpy(a,b,n*sizeof(int));a[0]++;
    memset(b,0,n*sizeof(int));inverse(a,b,n,L);
    for(int i=1;i<=num;i++)
        printf("%d\n",(b[i]<<1)%mod);
    return 0;
}

  


By NeighThorn

posted @ 2017-03-03 16:35  NeighThorn  阅读(1162)  评论(0编辑  收藏  举报