bzoj 3625: [Codeforces Round #250]小朋友和二叉树【NTT+多项式开根求逆】

参考:https://www.cnblogs.com/2016gdgzoi509/p/8999460.html
列出生成函数方程,g(x)是价值x的个数

\[f(x)=g(x)*f^2(x)+1 \]

+1是f[0]=1
根据公式解出

\[f(x)=\frac{1+(-)\sqrt{1-4*g(x)}}{2*g(x)} \]

舍去+的答案,分式上下同乘\( 1-\sqrt{1-4*g(x)} \)

\[f(x)=\frac{2}{1+\sqrt{1-4*g(x)}} \]

然后套多项式开跟和求逆的板子即可

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
const int N=500005,mod=998244353,inv2=499122177;
int n,m,bt,lm,re[N],a[N],b[N],c[N],t[N];//,aa[N],bb[N],cc[N];
int read()
{
	int r=0,f=1;
	char p=getchar();
	while(p>'9'||p<'0')
	{
		if(p=='-')
			f=-1;
		p=getchar();
	}
	while(p>='0'&&p<='9')
	{
		r=r*10+p-48;
		p=getchar();
	}
	return r*f;
}
int ksm(int a,int b)
{
	int r=1;
	while(b)
	{
		if(b&1)
			r=1ll*r*a%mod;
		a=1ll*a*a%mod;
		b>>=1;
	}
	return r;
}
void dft(int a[],int f,int lm)
{
	// for(int i=0;i<lm;i++)
		// cerr<<a[i]<<" ";cerr<<endl;
	// bt=log2(lm);//cerr<<" "<<lm<<" "<<bt<<endl;
	// for(int i=0;i<lm;i++)
		// re[i]=(re[i>>1]>>1)|((i&1)<<(bt-1));
	for(int i=0;i<lm;i++)
		if(i<re[i])
			swap(a[i],a[re[i]]);
	for(int i=1;i<lm;i<<=1)
	{
		int wi=ksm(3,(mod-1)/(i*2));
		if(f==-1)
			wi=ksm(wi,mod-2);
		for(int k=0;k<lm;k+=(i<<1))
		{
			int w=1,x,y;
			for(int j=0;j<i;j++)
			{
				x=a[j+k],y=1ll*w*a[i+j+k]%mod;
				a[j+k]=(x+y)%mod,a[i+j+k]=(x-y+mod)%mod;
				w=1ll*w*wi%mod;
			}
		}
	}
	if(f==-1)
	{
		int ni=ksm(lm,mod-2);
		for(int i=0;i<lm;i++)
			a[i]=1ll*a[i]*ni%mod;
	}
	// for(int i=0;i<lm;i++)
		// cerr<<a[i]<<" ";cerr<<endl<<endl;;
}
void qiuni(int len)
{
	if(len==1)
	{
		c[0]=ksm(b[0],mod-2);
		return;
	}
	qiuni(len>>1);
	memcpy(t,b,sizeof(int)*len);
	memset(t+len,0,sizeof(int)*len);
	int bt=-1,lm=1;
	while(lm<len<<1)
		lm<<=1,bt++;
	for(int i=0;i<lm;i++) 
		re[i]=(re[i>>1]>>1)|((i&1)<<bt);
	dft(t,1,lm);
	dft(c,1,lm);
	for(int i=0;i<lm;i++) 
		c[i]=1ll*c[i]*(2-1ll*t[i]*c[i]%mod+mod)%mod;
	dft(c,-1,lm);
	memset(c+len,0,sizeof(int)*len);
}
void kaigen(int len)
{
	if(len==1)
	{
		b[0]=1;
		return;
	}
	kaigen(len>>1);
	memset(c,0,sizeof(int)*len);
	qiuni(len);
	memcpy(t,a,sizeof(int)*len);
	memset(t+len,0,sizeof(int)*len);
	int bt=-1,lm=1;
	while(lm<len<<1)
		lm<<=1,bt++;
	for(int i=0;i<lm;i++) 
		re[i]=(re[i>>1]>>1)|((i&1)<<bt);
	dft(t,1,lm);
	dft(b,1,lm);
	dft(c,1,lm);
	for(int i=0;i<len*2;i++)
		b[i]=1ll*(1ll*b[i]*b[i]+t[i])%mod*c[i]%mod*inv2%mod;
	dft(b,-1,lm);
	memset(b+len,0,sizeof(int)*len);
}
int main()
{
	n=read(),m=read();
	for(int i=1;i<=n;i++)
	{
		int x=read();
		a[x]++;
	}
	for(int i=1;i<=m;i++)
		a[i]=(-a[i]*4+mod)%mod;
	for(bt=1;(1<<bt)<=m;bt++);
	lm=(1<<bt);
	for(int i=0;i<lm;i++)
		if(a[i])
			a[i]=mod-4;
	a[0]++;
	kaigen(lm);
	// for(int i=0;i<n;i++)
		// cerr<<a[i]<<" "<<b[i]<<" "<<c[i]<<endl;
	b[0]=(b[0]+1)%mod;
	memset(c,0,sizeof(c));
	qiuni(lm);
	for(int i=1;i<=m;i++)
		printf("%d\n",c[i]*2%mod);
	return 0;
}
posted @ 2019-01-19 10:34  lokiii  阅读(...)  评论(...编辑  收藏