洛谷 P7275 - 计树

最终方案肯定是将 \([1,n]\) 划分为若干个长度 \(\ge 2\) 的区间 \([l_i,r_i]\) 满足 \(l_i\sim r_i\) 按照 \(l_i\to l_i+1\to l_i+2\to\cdots\to r_i\) 的方案连成一条链,剩余点之间连边满足不存在其他边两端点差为 \(1\)

考虑对后面那个限制进行容斥,即将 \([1,n]\) 划分为一些长度 \(\ge 2\) 的连续段并钦定它们必须连成一条链。然后乘以适当的容斥系数并求和。先考虑钦定一些连续段怎么算方案数,这是经典问题。答案是 \(n^{k-2}·\prod(r_i-l_i+1)\),其中 \(k\) 是你划分出的连续段个数。接下来考虑怎么钦定容斥系数 \(A(x)\)。因为对于一组方案,如果其中有一个长度为 \(len\) 的连续段,那么所有对其有贡献的拆分 \(b_i\) 均满足 \(\sum b_i=len\),它们的容斥系数之和为 \(\sum\limits_{\sum b_i=len}\prod[x^{b_i}]A(x)\),我们希望 \(\sum\limits_{\sum b_i=len}\prod[x^{b_i}]A(x)=[len\ge 2]\),而根据生成函数的知识,前者等于 \([x^{len}](A(x)+A^2(x)+A^3(x)+\cdots)=[x^{len}]\dfrac{1}{1-A(x)}-1\),因此 \(\dfrac{1}{1-A(x)}-1=\dfrac{1}{1-x}-1-x\),解得 \(A(x)=\dfrac{x^2}{x^2-x+1}\),手玩可知系数为 \(\{0,0,1,1,0,-1,-1,0,1,1,0,-1,-1\}\)(0-indexed)。

这样,考虑 \(B(x)=\sum\limits_{t}tnA_tx^t\),答案就是 \(n^{-2}[x^n]\sum\limits_{i}B^i(x)=n^{-2}[x^n]\dfrac{1}{1-B(x)}\),多项式求逆搞定。

const int MAXN=1e5;
const int MAXP=1<<18;
const int pr=3;
const int ipr=332748118;
const int MOD=998244353;
int calc(int x){
	x%=6;
	if(x==2||x==3)return 1;
	if(x==5||x==0)return MOD-1;
	return 0;
}
int qpow(int x,int e){int ret=1;for(;e;e>>=1,x=1ll*x*x%MOD)if(e&1)ret=1ll*ret*x%MOD;return ret;}
int n,rev[MAXP+5];
void NTT(vector<int> &a,int len,int type){
	int lg=31-__builtin_clz(len);
	for(int i=0;i<len;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<lg-1);
	for(int i=0;i<len;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
	for(int i=2;i<=len;i<<=1){
		int W=qpow((type<0)?ipr:pr,(MOD-1)/i);
		for(int j=0;j<len;j+=i){
			for(int k=0,w=1;k<(i>>1);k++,w=1ll*w*W%MOD){
				int X=a[j+k],Y=1ll*a[(i>>1)+j+k]*w%MOD;
				a[j+k]=(X+Y)%MOD;a[(i>>1)+j+k]=(X-Y+MOD)%MOD;
			}
		}
	}
	if(!~type){
		int iv=qpow(len,MOD-2);
		for(int i=0;i<len;i++)a[i]=1ll*a[i]*iv%MOD;
	}
}
vector<int>conv(vector<int>a,vector<int>b){
	int LEN=1;while(LEN<a.size()+b.size())LEN<<=1;
	a.resize(LEN,0);b.resize(LEN,0);NTT(a,LEN,1);NTT(b,LEN,1);
	for(int i=0;i<LEN;i++)a[i]=1ll*a[i]*b[i]%MOD;
	NTT(a,LEN,-1);return a;
}
vector<int>getinv(vector<int>a,int len){
	vector<int>b(len,0);b[0]=qpow(a[0],MOD-2);
	for(int i=2;i<=len;i<<=1){
		vector<int>c(b.begin(),b.begin()+(i>>1));
		vector<int>d(a.begin(),a.begin()+i);
		c=conv(c,c);d=conv(c,d);
		for(int j=0;j<i;j++)b[j]=(2*b[j]%MOD-d[j]+MOD)%MOD;
	}return b;
}
int main(){
	scanf("%d",&n);int LEN=1;while(LEN<=n)LEN<<=1;
	vector<int>a(LEN),b;
	for(int i=0;i<LEN;i++){
		if(!i)a[i]=1;
		else a[i]=(-1ll*calc(i)*i%MOD*n%MOD+MOD)%MOD;
	}b=getinv(a,LEN);
	printf("%d\n",1ll*b[n]*qpow(n,MOD-3)%MOD);
	return 0;
}
posted @ 2023-07-21 10:50  tzc_wk  阅读(67)  评论(0)    收藏  举报