P5434 有标号荒漠计数

P5434 有标号荒漠计数

数学、组合计数、生成函数


其实森林也是沙漠

挺难的一道题,并非独立切。

\(f_n\)\(n\) 个点仙人掌数量,\(F(x)\)\(\{f_n\}\) 的 EGF,答案即为 \([\frac{x^n}{n!}]\exp(F(x))\)

图论计数,一般考虑设函数方程或者将一个组合对象拆分为多个小组合对象,这里考虑设函数方程。

发现 \(F(x)\) 其实不能转移,考虑设

\[G(x)=\sum\limits_{i=0}^{+\infty}\frac{f_i}{(i-1)!} \]

其组合意义为有根仙人掌的生成函数。

对一个仙人掌,考虑任取一个点 \(u\) 作为根。与 \(u\) 相连的桥,其贡献为 \(G(x)\);与 \(u\) 相连的环,其贡献为 \(\frac{1}{2}\sum\limits_{i=2}^{+\infty}G^i(x)=\frac{G^2(x)}{2(1-G(x))}\) 这里除以 \(2\)、不除阶乘是因为与 \(u\) 相连的环是可翻转的排列。
将桥和环任意组合(无序无数量限制,即 \(\exp\))可以得到关于 \(G(x)\) 的函数方程。

\[G(x)=x\exp\left(G(x)+\frac{G^2(x)}{2\left(1-G(x)\right)}\right) \]

牛顿迭代即可,时间复杂度 \(\Theta(n\log n)\)


#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int,int> pii;
typedef pair<int,ll> pil;
typedef pair<ll,int> pli;
typedef pair<ll,ll> pll;
const int MOD=998244353,G=3,INV2=(MOD+1)>>1;
void add(int &x,int y){
	x+=y;
	if(x>=MOD) x-=MOD;
}
int qpow(int a,int b){
	int mul=1;
	while(b){
		if(b&1) mul=(ll)mul*a%MOD;
		a=(ll)a*a%MOD;
		b>>=1;
	}
	return mul;
}
const int N=800005;
int fact[N],invfact[N],inv[N],stp[N],invstp[N];
void init(){
	fact[0]=1;
	for(int i=1;i<N;i++) fact[i]=(ll)i*fact[i-1]%MOD;
	invfact[N-1]=qpow(fact[N-1],MOD-2);
	for(int i=N-1;i>0;i--) invfact[i-1]=(ll)i*invfact[i]%MOD;
	for(int i=1;i<N;i++) inv[i]=(ll)invfact[i]*fact[i-1]%MOD;
	for(int w=1;w<N;w<<=1) stp[w]=qpow(G,(MOD-1)/w),invstp[w]=qpow(stp[w],MOD-2);
}
namespace Poly{
	void ntt(vector<int> &f,bool flag=false){
		int n=f.size();
		vector<int> rev(n);
		for(int i=1;i<n;i++){
			rev[i]=rev[i>>1]>>1;
			if(i&1) rev[i]|=(n>>1);
		}
		for(int i=1;i<n;i++) if(rev[i]<i) swap(f[rev[i]],f[i]);
		for(int w=1;w<n;w<<=1){
			int step=stp[w<<1];
			if(flag) step=invstp[w<<1];
			for(int i=0;i<n;i+=(w<<1)){
				int cur=1;
				for(int j=i;j<i+w;j++,cur=(ll)cur*step%MOD){
					int a=f[j],b=(ll)cur*f[j+w]%MOD;
					f[j]=(a+b)%MOD,f[j+w]=(a-b+MOD)%MOD;
				}
			}
		}
		if(flag){
			int inv=::inv[n];
			for(int i=0;i<n;i++) f[i]=(ll)inv*f[i]%MOD;
		}
	}
	vector<int> inv(vector<int> f,int n=-1){
		if(n==1) return {qpow(f[0],MOD-2)};
		if(n==-1) n=f.size();
		else f.resize(n);
		vector<int> h=inv(f,n>>1);
		f.resize(n<<1),h.resize(n<<1);
		ntt(f),ntt(h);
		for(int i=0;i<(n<<1);i++) f[i]=(2*h[i]%MOD-(ll)h[i]*h[i]%MOD*f[i]%MOD+MOD)%MOD;
		ntt(f,1);
		f.resize(n);
		return f;
	}
	vector<int> ln(vector<int> f){
		int n=f.size();
		vector<int> h=inv(f);
		for(int i=0;i+1<n;i++) f[i]=(ll)(i+1)*f[i+1]%MOD;
		f[n-1]=0;
		f.resize(n<<1),h.resize(n<<1);
		ntt(f),ntt(h);
		for(int i=0;i<(n<<1);i++) f[i]=(ll)f[i]*h[i]%MOD;
		ntt(f,1);
		for(int i=n-1;i>0;i--) f[i]=(ll)::inv[i]*f[i-1]%MOD;
		f[0]=0;
		f.resize(n);
		return f;
	}
	vector<int> exp(vector<int> f,int n=-1){
		if(n==1) return {1};
		if(n==-1) n=f.size();
		else f.resize(n);
		vector<int> h=exp(f,n>>1);
		h.resize(n);
		vector<int> g=ln(h);
		f.resize(n<<1),g.resize(n<<1),h.resize(n<<1);
		ntt(f),ntt(g),ntt(h);
		for(int i=0;i<(n<<1);i++) f[i]=(ll)h[i]*(1+MOD-g[i]+f[i])%MOD;
		ntt(f,1);
		f.resize(n);
		return f;
	}
	vector<int> mul(vector<int> f,vector<int> g,int n){
		f.resize(n<<1),g.resize(n<<1);
		ntt(f),ntt(g);
		for(int i=0;i<(n<<1);i++) f[i]=(ll)f[i]*g[i]%MOD;
		ntt(f,1);
		f.resize(n);
		return f;
	}
	vector<int> calc(int n){
		if(n==1) return {0};
		vector<int> f=calc(n>>1);
		f.resize(n);
		vector<int> g=f;
		for(int i=0;i<n;i++) g[i]=(ll)(MOD-2)*g[i]%MOD;
		add(g[0],2);
		g=mul(mul(f,f,n),inv(g),n);
		for(int i=0;i<n;i++) add(g[i],f[i]);
		vector<int> A=exp(g),tmp=f;
		for(int i=n-1;i>0;i--) A[i]=A[i-1];
		A[0]=0;
		A.resize(n<<2),f.resize(n<<2);
		ntt(A),ntt(f);
		vector<int> B(n<<2),C(n<<2);
		for(int i=0;i<(n<<2);i++){
			int tmp=(ll)(1+MOD-f[i])*(1+MOD-f[i])%MOD;
			B[i]=(ll)(A[i]-f[i]+MOD)*tmp%MOD;
			C[i]=((ll)A[i]*(tmp+(ll)f[i]*(1+MOD-f[i])%MOD+(ll)f[i]*f[i]%MOD*INV2%MOD)%MOD-tmp+MOD)%MOD;
		}
		ntt(B,1),ntt(C,1);
		B.resize(n),C.resize(n);
		B=mul(B,inv(C),n);
		B.resize(n);
		for(int i=0;i<n;i++) add(tmp[i],MOD-B[i]);
		return tmp;
	}
}
int n;
int main(){
	init();
	scanf("%d",&n);
	int len;
	for(len=1;len<=n;len<<=1);
	vector<int> f=Poly::calc(len);
	for(int i=1;i<len;i++) f[i]=(ll)inv[i]*f[i]%MOD;
	f=Poly::exp(f);
	printf("%lld\n",(ll)fact[n]*f[n]%MOD);
	return 0;
}
posted @ 2025-10-30 08:12  SmpaelFx  阅读(5)  评论(0)    收藏  举报