题解:P7275 计树

题意:给出一个数 \(n\),问对于所有节点 \(x\) 都满足存在 \(|x-y|=1\) 使得 \((x,y)\) 有边的树有多少个,\(n\le 10^5\)

做法:

我们肯定是考虑直接钦定连续段然后去计算,但是因为有可能在相邻段间连出边就爆炸了,考虑容斥,那么现在有两个问题。

  • 如何对于一种特定的连续段划分计算答案。

  • 如何分配容斥系数。

首先对于第一个问题比较简单,这是经典结论。现在有 \(n\) 个点,假设被分成 \(A_1,A_2,\cdots A_k\) 这些连通块,那么使他们联通的方案数为 \(n^{k-2}\prod\limits_{i=1}^k A_i\)。这个比较经典就不再证明。

然后在这个题中,假设我们划分出来 \(A_1,A_2,\cdots A_n\) 这些连续段,那么我们不妨把一个段的权值设为 \(A_i\times n\),最后除以 \(n^2\) 就是答案。

然后对于第二个我们考虑应该怎么做,我们考虑一个连续段的权值应该是 \([len\ge 2]\),我们枚举内部划分的情况,希望要得到一个柿子:

\[[len\ge 2] = \sum_{A_1+A_2+\cdots+A_k=len}\prod_{i=1}^k f(A_i) \]

这里 \(f(x)\)\(x\) 的容斥系数。

我们用 \(F\) 这个多项式来改写一下,\(F\)\(k\) 次项系数即是 \(f(k)\)

我们枚举 \(k\),那么其实我们可以把后面这个东西写成 \([x^{len}]F^k(x)\)。整合一下右侧,其实右侧就等于 \(\frac{1}{1-F(x)}-1\)。左侧其实就是 \((1+x+x^2+\cdots)-(1+x)=\frac{1}{1-x}-(1+x)\)

把这个方程解出来可以得到 \(F(x)=\frac{x^2}{x^2-x+1}\)

那么剩下就很好做了,我们直接记 \(G(x) = \sum i\times n\times f(i)x^i\),答案就是 \(\frac1{n^2}\times \frac{1}{1-G(x)}[x^n]\)

代码:

#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxn = 5e5 + 5, mod = 998244353, gb = 3, gi = (mod + 1) / gb;
int qpow(int x, int k, int p) {
	int res = 1;
	while(k) {
		if(k & 1)
			res = res * x % p;
		x = x * x % p, k >>= 1;
	}
	return res;
}
int rev[maxn];
void init(int len) {
	for (int i = 0; i < len; i++) {
		rev[i] = rev[i >> 1] >> 1;
		if(i & 1)
			rev[i] |= (len >> 1);
	}
}
struct Poly {
	vector<int> a;
	void resize(int N) {
		a.resize(N);
	}
	int size() {
		return a.size();
	}
	int& operator[](int x) {
		return a[x];
	}
	void NTT(int f) {
		for (int i = 0; i < size(); i++)
			if(i < rev[i])
				swap(a[i], a[rev[i]]);
		for (int h = 2; h <= size(); h <<= 1) {
			int d = qpow((f == 1 ? gb : gi), (mod - 1) / h, mod);
			for (int i = 0; i < size(); i += h) {
				int nw = 1;
				for (int j = i; j < i + h / 2; j++) {
					int a0 = a[j], a1 = a[j + h / 2] * nw % mod;
					a[j] = (a0 + a1) % mod, a[j + h / 2] = (a0 - a1 + mod) % mod;
					nw = nw * d % mod;
				}
			}
		}
		if(f == -1) {
			int inv = qpow(size(), mod - 2, mod);
			for (int i = 0; i < size(); i++)	
				a[i] = a[i] * inv % mod;
		}
	}
	friend Poly operator*(Poly f, Poly g) {
		int len = 1, t = f.size() + g.size() - 1;
		while(len < t)
			len <<= 1;
		init(len), f.resize(len), g.resize(len);
		f.NTT(1), g.NTT(1);
		for (int i = 0; i < len; i++)
			f[i] = f[i] * g[i] % mod;
		f.NTT(-1);
		f.resize(t);
		return f; 
	}
	void get_neg() {
		for (int i = 0; i < size(); i++)
			a[i] = (mod - a[i]) % mod;
	}
	friend Poly operator+(Poly f, int v) {
		f[0] = (f[0] + v) % mod;
		return f;
	}
	void print() {
		for (int i = 0; i < size(); i++)
			cout << a[i] << " ";
		cout << endl;
	}
} f;
Poly get_inv(Poly f, int lim) {
	if(lim == 1) {
		f.resize(1);
		f[0] = qpow(f[0], mod - 2, mod);
		return f;
	}
	Poly g = get_inv(f, lim + 1 >> 1);
	int len = 1;
	while(len < lim * 2)
		len <<= 1;
	init(len);
	f.resize(lim), f.resize(len), g.resize(len);
	f.NTT(1), g.NTT(1);
	for (int i = 0; i < len; i++)
		f[i] = (2 * g[i] - f[i] * g[i] % mod * g[i] % mod + mod) % mod;
	f.NTT(-1);
	f.resize(lim);
	return f;
}
int n;
signed main() {
	cin >> n;
	f.resize(n + 1);
	f[0] = 1, f[1] = mod - 1, f[2] = 1;
	f = get_inv(f, n + 1);
	for (int i = n; i >= 2; i--)
		f[i] = f[i - 2];
	f[0] = f[1] = 0;
	for (int i = 1; i <= n; i++)
		f[i] = i * n % mod * f[i] % mod;
	f.get_neg();
	//(f + 1).print();
	f = get_inv(f + 1, n + 1);
	cout << f[n] * qpow(n, mod - 3, mod) % mod << endl;
	return 0; 
}
posted @ 2025-10-17 14:49  LUlululu1616  阅读(12)  评论(0)    收藏  举报