洛谷 P5162 WD与积木【多项式求逆】

设f[i]为i个积木能堆出来的种类,g[i]为i个积木能堆出来的种类和

\[f[n]=\sum_{i=1}^{n}C_{n}^{i}g[n-i] \]

\[g[n]=\sum_{i=1}^{n}C_{n}^{i}f[n-i]+g[n] \]

理解就是选出包含最后一个的块,然后剩下的按照之前的拼
化简,设s为\( \frac{1}{n!} \),G为\( \frac{g[n]}{n!} \),F为\( \frac{fn]}{n!} \),把组合数拆开,变成卷积形式,然后化简就变成

\[F=\frac{1}{1-S} \]

\[G=F*(F-1) \]

用多项式求逆即可

#include<iostream>
#include<cstdio>
using namespace std;
const int N=1000005,mod=998244353;
int T,n=1e5+5,s[N],f[N],g[N],a[N],b[N],c[N],t[N],fac[N],inv[N],re[N],lm,bt,ans[N];
int read()
{
	int r=0,f=1;
	char p=getchar();
	while(p>'9'||p<'0')
	{
		if(p=='-')
			f=-1;
		p=getchar();
	}
	while(p>='0'&&p<='9')
	{
		r=r*10+p-48;
		p=getchar();
	}
	return r*f;
}
int ksm(int a,int b)
{
	int r=1;
	while(b)
	{
		if(b&1)
			r=1ll*r*a%mod;
		a=1ll*a*a%mod;
		b>>=1;
	}
	return r;
}
void dft(int a[],int f)
{
	for(int i=0;i<lm;i++)
		if(i<re[i])
			swap(a[i],a[re[i]]);
	for(int i=1;i<lm;i<<=1)
	{
		int wi=ksm(3,(mod-1)/(i*2));
		if(f==-1)
			wi=ksm(wi,mod-2);
		for(int k=0;k<lm;k+=(i<<1))
		{
			int w=1,x,y;
			for(int j=0;j<i;j++)
			{
				x=a[j+k];
				y=1ll*a[i+j+k]*w%mod;
				a[j+k]=(x+y)%mod;
				a[i+j+k]=(x-y+mod)%mod;
				w=1ll*w*wi%mod;
			}
		}
	}
	if(f==-1)
	{
		int ni=ksm(lm,mod-2);
		for(int i=0;i<lm;i++)
			a[i]=1ll*a[i]*ni%mod;
	}
}
void clc(int len)
{//cerr<<len<<endl;
	if(len==0)
	{
		c[0]=ksm(s[0],mod-2);
		return;
	}
	clc(len>>1);
	for(bt=1;(1<<bt)<=len;bt++);
	lm=(1<<bt);
	for(int i=0;i<lm;i++)
		re[i]=(re[i>>1]>>1)|((i&1)<<(bt-1));
	for(int i=0;i<len;i++)
		t[i]=s[i];
	dft(t,1);
	dft(c,1);
	for(int i=0;i<lm;i++)
		c[i]=1ll*c[i]*(mod+2-1ll*c[i]*t[i]%mod)%mod;
	dft(c,-1);
	for(int i=len;i<=2*lm+1;i++)
		c[i]=0;
	for(int i=0;i<=2*lm+1;i++)
		t[i]=0;
}
int main()
{
	n=1e5+5;
	fac[0]=inv[0]=1;
	for(int i=1;i<=n;i++)
		fac[i]=1ll*fac[i-1]*i%mod;
	inv[n]=ksm(fac[n],mod-2);
	for(int i=n-1;i>=1;i--)
		inv[i]=1ll*inv[i+1]*(i+1)%mod;
	for(int i=1;i<=n;i++)
		s[i]=mod-inv[i];
	s[0]++;
	for(bt=1;(1<<bt)<=2*n;bt++);
	lm=(1<<bt);
	clc(lm);
	for(int i=1;i<=n;i++)
		a[i]=b[i]=f[i]=c[i];
	f[0]=a[0]=1,b[0]=0;
	// for(int i=0;i<=10;i++)
		// cerr<<f[i]<<" ";cerr<<endl;
	for(bt=1;(1<<bt)<=2*n;bt++);
	lm=(1<<bt);
	for(int i=0;i<lm;i++)
		re[i]=(re[i>>1]>>1)|((i&1)<<(bt-1));
	dft(a,1);
	dft(b,1);
	for(int i=0;i<lm;i++)
		g[i]=1ll*a[i]*b[i]%mod;
	dft(g,-1);
	for(int i=1;i<=n;i++)
		ans[i]=1ll*g[i]*ksm(f[i],mod-2)%mod;//,printf("%d\n",ans[i]);
	T=read();
	while(T--)
		printf("%d\n",ans[read()]);
	return 0;
}
posted @ 2019-06-17 09:33  lokiii  阅读(165)  评论(0编辑  收藏  举报