loj 6391 「THUPC2018」淘米神的树 / Tommy - 多项式

题目传送门

  传送门

  当 $a = b$ 的时候答案是:

$$
n! \prod_i \frac{1}{sz_i}
$$

  现在考虑 $a\neq b$ 的情形,考虑 $a \rightarrow b$ 的链。不妨设这些点依次是 $v_1 = a, v_2, \cdots, v_m = b$。

  注意到,如果直接硬点 $v_k$ 是链上最后一个被染红的点,方案数非常难算。

  考虑忽略掉其中某条边,然后硬点它右侧的点满足 $v_{i + 1}$ 在 $v_i$ 之前被染红,左侧的点满足 $v_i$ 在 $v_{i + 1} $ 之前被染红。如果认为 $v_0$ 和 $v_{1}$ 有边相连, $v_{m}$ 和 $v_{m + 1}$ 有边相连。这些边每条边被忽略的方案数之和等于所求的方案数的两倍。

  考虑怎么求这些方案数。注意到断掉一条边只会改变这条链上的点的 size。设断掉链上所有边后,前 $i$ 个点的 size 之和为 $s_i$(其中 $s_0 = 0$),那么断掉 $v_i$ 和 $v_{i + 1}$ 之间的边,相当于是求 $\prod_{j = 0, j \neq i}^m |s_i - s_j|$(答案中用到的是它的逆元)。讨论一下拆掉绝对值,然后用洛必达法则多点求值就行了。

  时间复杂度 $O(n\log^2 n)$

Code

/**
 * loj
 * Problem#6391
 * Accepted
 * Time: 27112ms
 * Memory: 80392k
 */
#include <bits/stdc++.h>
using namespace std;
// Alias so that flags and predicates read as `boolean` throughout the file.
typedef bool boolean;

// Shorthand for the 64-bit integer type used for intermediate products.
#define ll long long

/**
 * Writes `val` into every slot of the half-open range [pst, ped).
 */
template <typename T>
void pfill(T* pst, const T* ped, T val) {
	while (pst != ped) {
		*pst = val;
		++pst;
	}
}

/**
 * Copies consecutive elements starting at `pval` into the half-open
 * destination range [pst, ped), front to back.
 */
template <typename T>
void pcopy(T* pst, const T* ped, T* pval) {
	while (pst != ped) {
		*pst = *pval;
		++pst;
		++pval;
	}
}

// Capacity of the static work buffers and lookup tables (2^19 entries).
const int N = 524288;
// Prime modulus for all arithmetic; 998244353 = 119 * 2^23 + 1.
const int Mod = 998244353;
// Largest index i for which the NTT class precomputes a root-table entry.
const int bzmax = 20;
// Base handed to qpow when the NTT root tables are built.
const int g = 3;

/**
 * Extended Euclidean algorithm: on return, a * x + b * y equals gcd(a, b).
 *
 * @param a,b  the two integers
 * @param x,y  outputs receiving the Bezout coefficients of a and b
 */
void exgcd(int a, int b, int& x, int& y) {
	if (b == 0) {
		x = 1;
		y = 0;
		return;
	}
	const int quotient = a / b;
	// Solve for (b, a mod b) with the coefficient roles swapped, then fix up y.
	exgcd(b, a - quotient * b, y, x);
	y -= quotient * x;
}

/**
 * Modular inverse of `a` modulo `Mod`, obtained from the Bezout coefficient
 * that exgcd produces for `a`. A negative coefficient is shifted up by Mod
 * once before it is returned.
 */
int inv(int a, int Mod) {
	int coef_a = 0, coef_mod = 0;
	exgcd(a, Mod, coef_a, coef_mod);
	if (coef_a < 0)
		coef_a += Mod;
	return coef_a;
}

/**
 * Integer modulo `Mod`, stored in the public field `v`.
 *
 * The int constructor stores its argument as given; the 64-bit constructor
 * reduces it with `%`. The arithmetic operators assume both operands
 * already hold values in [0, Mod).
 */
template <const int Mod = :: Mod>
class Z {
	public:
		int v;

		Z() : v(0) {	}
		Z(int x) : v(x) {	}
		Z(ll x) : v(x % Mod) {	}

		// Sum with a single conditional subtraction instead of a division.
		friend Z operator + (const Z& a, const Z& b) {
			int s = a.v + b.v;
			if (s >= Mod)
				s -= Mod;
			return Z(s);
		}
		// Difference with a single conditional addition.
		friend Z operator - (const Z& a, const Z& b) {
			int d = a.v - b.v;
			if (d < 0)
				d += Mod;
			return Z(d);
		}
		// Product formed in 64 bits, reduced by the 64-bit constructor.
		friend Z operator * (const Z& a, const Z& b) {
			return Z(1ll * a.v * b.v);
		}
		// Unary ~ is the multiplicative inverse (via the free function inv).
		friend Z operator ~ (const Z& a) {
			return Z(inv(a.v, Mod));
		}
		// Unary minus is computed as 0 - a.
		friend Z operator - (const Z& a) {
			return Z(0) - a;
		}
		Z& operator += (Z b) {
			*this = *this + b;
			return *this;
		}
		Z& operator -= (Z b) {
			*this = *this - b;
			return *this;
		}
		Z& operator *= (Z b) {
			*this = *this * b;
			return *this;
		}
		friend boolean operator == (const Z& a, const Z& b) {
			return a.v == b.v;
		}
};

// Residues modulo the file-wide modulus.
typedef Z<> Zi;

/**
 * Computes a^p modulo Mod by binary exponentiation.
 *
 * A negative exponent is replaced by the equivalent non-negative exponent
 * modulo Mod - 1 (Fermat's little theorem), so qpow(a, -k) is the inverse
 * of qpow(a, k) whenever a is not a multiple of Mod.
 *
 * Only negative exponents are normalised. Shifting small non-negative
 * exponents by Mod - 1 as well would roughly double the number of loop
 * iterations and would make qpow(0, 0) evaluate to 0 instead of 1.
 *
 * @param a  base
 * @param p  exponent, may be negative
 * @return   a^p as a residue
 */
Zi qpow(Zi a, int p) {
	if (p < 0) {
		p %= Mod - 1;
		if (p < 0)
			p += Mod - 1;
	}
	Zi rt = 1, pa = a;
	for ( ; p; p >>= 1, pa = pa * pa) {
		if (p & 1) {
			rt = rt * pa;
		}
	}
	return rt;
}

// Multiplicative inverse of 2 modulo Mod, i.e. (Mod + 1) / 2.
const Zi inv2 ((Mod + 1) >> 1);

/**
 * Number-theoretic transform modulo Mod. The single global instance is also
 * named NTT and is invoked as NTT(buffer, len, sgn).
 */
class NTT {
	private:
		// gn[i] = g^((Mod - 1) >> i); _gn[i] is the same power with a negated exponent.
		Zi gn[bzmax + 4], _gn[bzmax + 4];
	public:
		
		// Fills both root tables for i = 0 .. bzmax.
		NTT() {
			for (int i = 0; i <= bzmax; i++) {
				gn[i] = qpow(Zi(g), (Mod - 1) >> i);
				_gn[i] = qpow(Zi(g), -((Mod - 1) >> i));
			}
		}

		/**
		 * Transforms f[0 .. len) in place.
		 *
		 * sgn > 0 uses the gn table; any other value uses the _gn table.
		 * Only when sgn < 0 is the result additionally multiplied by the
		 * inverse of len.
		 * NOTE(review): len is presumably a power of two no larger than
		 * 2^bzmax; nothing here checks it.
		 */
		void operator () (Zi* f, int len, int sgn) {
			// Reorder the input; j is advanced like a counter whose carries run from the top bit down.
			for (int i = 1, j = len >> 1, k; i < len - 1; i++, j += k) {
				if (i < j)
					swap(f[i], f[j]);
				for (k = len >> 1; k <= j; j -= k, k >>= 1);
			}
			
			// wn walks the chosen root table: entry 1 for blocks of length 2, entry 2 for length 4, and so on.
			Zi *wn = (sgn > 0) ? (gn + 1) : (_gn + 1), w, a, b;
			for (int l = 2, hl; l <= len; l <<= 1, wn++) {
				hl = l >> 1, w = 1;
				// Each block of length l is combined from its two halves; w restarts at 1 per block.
				for (int i = 0; i < len; i += l, w = 1) {
					for (int j = 0; j < hl; j++, w *= *wn) {
						a = f[i + j], b = f[i + j + hl] * w;
						f[i + j] = a + b;
						f[i + j + hl] = a - b;
					}
				}
			}

			// Scaling by 1/len completes the inverse transform.
			if (sgn < 0) {
				Zi invlen = ~Zi(len);
				for (int i = 0; i < len; i++) {
					f[i] *= invlen;
				}
			}
		}

		// Returns the smallest power of two strictly greater than len (so 4 maps to 8).
		int correct_len(int len) {
			int m = 1;
			for ( ; m <= len; m <<= 1);
			return m;
		}
} NTT;

/**
 * Writes into g the first n coefficients of the power-series inverse of f,
 * doubling the precision at each level via g <- g * (2 - f * g).
 *
 * @param f  input coefficients; at least n entries are read
 * @param g  output buffer; it is used as NTT scratch space, so it must hold
 *           NTT.correct_len(2n + 1) entries. Entries from n up to that
 *           length are zero on return.
 * @param n  number of coefficients wanted (n >= 1)
 */
void pol_inverse(Zi* f, Zi* g, int n) {
	static Zi A[N];
	if (n == 1) {
		g[0] = ~f[0];
		return;
	}

	const int half = (n + 1) >> 1;
	const int len = NTT.correct_len(n << 1 | 1);
	pol_inverse(f, g, half);

	// A <- f truncated to n terms, zero-padded to the transform length.
	pcopy(A, A + n, f);
	pfill(A + n, A + len, Zi(0));
	// Clear everything in g above the half-precision result.
	pfill(g + half, g + len, Zi(0));

	NTT(A, len, 1);
	NTT(g, len, 1);
	for (int i = 0; i < len; i++) {
		const Zi fg = g[i] * A[i];
		g[i] = g[i] * (Zi(2) - fg);
	}
	NTT(g, len, -1);

	// Drop the terms beyond the requested precision.
	pfill(g + n, g + len, Zi(0));
}

/**
 * Writes into g the first n coefficients of a power-series square root of f,
 * refining with g <- g / 2 + f / (2g) at each doubling step.
 *
 * The base case copies f[0] into g[0] unchanged, so the constant term is
 * only a true square root when f[0] is its own square (e.g. 1).
 *
 * @param f  input coefficients; at least n entries are read
 * @param g  output buffer with room for n entries
 * @param n  number of coefficients wanted (n >= 1)
 */
void pol_sqrt(Zi* f, Zi* g, int n) {
	static Zi A[N], B[N];
	if (n == 1) {
		g[0] = f[0];
		return;
	}

	const int half = (n + 1) >> 1;
	const int len = NTT.correct_len(n + n);

	pol_sqrt(f, g, half);
	pfill(g + half, g + n, Zi(0));

	// B <- 1 / (2g) to n terms.
	for (int i = 0; i < half; i++) {
		A[i] = g[i] + g[i];
	}
	pfill(A + half, A + len, Zi(0));
	pol_inverse(A, B, n);

	// A <- f * B, computed through the transform.
	pcopy(A, A + n, f);
	pfill(A + n, A + len, Zi(0));
	NTT(A, len, 1);
	NTT(B, len, 1);
	for (int i = 0; i < len; i++) {
		A[i] *= B[i];
	}
	NTT(A, len, -1);

	for (int i = 0; i < n; i++) {
		g[i] = g[i] * inv2 + A[i];
	}
}

typedef class Poly : public vector<Zi> {
	public:
		using vector<Zi>::vector;

		Poly& fix(int sz) {
			resize(sz);
			return *this;
		}
} Poly;

/**
 * Coefficient-wise sum. Both operands are zero-extended to the longer
 * length, which is also the length of the result.
 */
Poly operator + (Poly A, Poly B) {
	const int lenA = A.size();
	const int lenB = B.size();
	const int len = max(lenA, lenB);
	A.resize(len);
	B.resize(len);
	for (int i = 0; i < len; i++) {
		A[i] = A[i] + B[i];
	}
	return A;
}

/**
 * Coefficient-wise difference A - B. Both operands are zero-extended to the
 * longer length, which is also the length of the result.
 */
Poly operator - (Poly A, Poly B) {
	const int lenA = A.size();
	const int lenB = B.size();
	const int len = max(lenA, lenB);
	A.resize(len);
	B.resize(len);
	for (int i = 0; i < len; i++) {
		A[i] = A[i] - B[i];
	}
	return A;
}

/**
 * Vector-level wrapper around pol_sqrt: returns as many coefficients of the
 * square root as `a` has.
 */
Poly sqrt(Poly a) {
	const int len = a.size();
	Poly rt (len);
	pol_sqrt(a.data(), rt.data(), len);
	return rt;
}

/**
 * Polynomial product of A and B.
 *
 * The result has A.size() + B.size() - 1 coefficients. If either operand has
 * no coefficients the product is the empty polynomial; without this guard
 * the result size would be computed as -1 and turn into an enormous
 * allocation request. Operands with fewer than 20 coefficients are
 * multiplied directly; otherwise the product goes through the NTT.
 */
Poly operator * (Poly A, Poly B) {
	int n = A.size(), m = B.size();
	if (n == 0 || m == 0) {
		return Poly();
	}
	if (n < 20 || m < 20) {
		// Schoolbook multiplication is cheaper than three transforms here.
		Poly rt (n + m - 1);
		for (int i = 0; i < n; i++) {
			for (int j = 0; j < m; j++) {
				rt[i + j] += A[i] * B[j];
			}
		}
		return rt;
	}
	// The transform length is only needed on the NTT path.
	int k = NTT.correct_len(n + m - 1);
	A.resize(k), B.resize(k);
	NTT(A.data(), k, 1);
	NTT(B.data(), k, 1);
	for (int i = 0; i < k; i++) {
		A[i] *= B[i];
	}
	NTT(A.data(), k, -1);
	A.resize(n + m - 1);
	return A;
}

/**
 * Power-series inverse of f, truncated to f.size() coefficients.
 * Both buffers are enlarged to the scratch length pol_inverse works with.
 */
Poly operator ~ (Poly f) {
	const int n = f.size();
	const int cap = NTT.correct_len((n << 1) | 1);
	f.resize(cap);
	Poly rt (cap);
	pol_inverse(f.data(), rt.data(), n);
	rt.resize(n);
	return rt;
}

/**
 * Quotient of the polynomial division A / B.
 *
 * Uses the reversal trick: reverse both coefficient lists, multiply by the
 * series inverse of the reversed divisor, keep the first
 * A.size() - B.size() + 1 terms, and reverse back. When A is shorter than B
 * the quotient is the single coefficient 0.
 */
Poly operator / (Poly A, Poly B) {
	const int n = A.size();
	const int m = B.size();
	if (n < m) {
		return Poly {0};
	}
	const int r = n - m + 1;
	reverse(A.begin(), A.end());
	reverse(B.begin(), B.end());
	A.resize(r);
	B.resize(r);
	Poly q = A * ~B;
	q.resize(r);
	reverse(q.begin(), q.end());
	return q;
}

/**
 * Remainder of the polynomial division A mod B.
 *
 * A is returned unchanged when it is already shorter than B; a one-term
 * divisor yields the single coefficient 0. Otherwise the result is
 * A - (A / B) * B cut down to B.size() - 1 coefficients.
 */
Poly operator % (Poly A, Poly B) {
	const int n = A.size();
	const int m = B.size();
	if (n < m) {
		return A;
	}
	if (m == 1) {
		return Poly {0};
	}
	Poly rem = A - A / B * B;
	rem.resize(m - 1);
	return rem;
}

// Inv[i] holds the modular inverse of i once init_inv has covered i.
Zi Inv[N];

/**
 * Fills Inv[0 .. n] using the recurrence
 * inv(i) = (Mod - Mod / i) * inv(Mod mod i). Inv[0] is set to 0.
 */
void init_inv(int n) {
	Inv[0] = 0;
	Inv[1] = 1;
	int i = 2;
	while (i <= n) {
		const int q = Mod / i;
		const int r = Mod % i;
		Inv[i] = Inv[r] * Zi(Mod - q);
		i++;
	}
}

/**
 * Replaces f by its formal derivative, in place.
 *
 * A polynomial with no coefficients is left untouched; without this guard
 * the final resize would be asked for size() - 1 on an unsigned zero. A
 * constant polynomial becomes the single coefficient 0. Otherwise the
 * result has one coefficient fewer than the input.
 */
void diff(Poly& f) {
	if (f.empty()) {
		return;
	}
	if (f.size() == 1) {
		f[0] = 0;
		return;
	}
	const int len = (signed) f.size();
	for (int i = 1; i < len; i++) {
		// The coefficient of x^i contributes i * f[i] to x^(i-1).
		f[i - 1] = f[i] * Zi(i);
	}
	f.pop_back();
}
/**
 * Replaces f by its formal antiderivative with constant term 0, in place.
 * Reads Inv[1 .. f.size()], so init_inv must already cover that range.
 */
void integ(Poly& f) {
	f.push_back(Zi());
	// Walk downwards so each source coefficient is read before it is overwritten.
	for (int i = (signed) f.size() - 1; i > 0; i--) {
		f[i] = f[i - 1] * Inv[i];
	}
	f[0] = 0;
}

/**
 * Power-series logarithm of f to f.size() coefficients, computed as the
 * antiderivative of f' / f.
 */
Poly ln(Poly f) {
	const int n = f.size();
	Poly deriv = f;
	diff(deriv);
	Poly res = deriv * ~f;
	res.resize(n - 1);
	integ(res);
	return res;
}

/**
 * Computes into g the first n coefficients of exp(f), doubling the
 * precision at each level via g <- g * (1 - ln g + f).
 *
 * @param f  exponent series; its first n coefficients are read
 * @param g  output, resized to n coefficients
 * @param n  number of coefficients wanted (n >= 1)
 */
void pol_exp(Poly& f, Poly& g, int n) {
	if (n == 1) {
		g.resize(1);
		g[0] = 1;
		return;
	}

	pol_exp(f, g, (n + 1) >> 1);

	// h <- f truncated to n terms.
	Poly h (n);
	g.resize(n);
	pcopy(h.data(), h.data() + n, f.data());

	g = g * (Poly{1} - ln(g) + h);
	g.resize(n);
}

/**
 * Vector-level wrapper around pol_exp: returns exp(f) with as many
 * coefficients as f has.
 */
Poly exp(Poly f) {
	Poly rt;
	const int len = f.size();
	pol_exp(f, rt, len);
	return rt;
}

/**
 * Multipoint evaluation through a product tree. The single global instance
 * is also named PolyBuilder.
 */
class PolyBuilder {
	protected:
		// Next free slot of P while building; next node to visit while evaluating.
		int num;
		// Product-tree nodes, stored in the order in which _init visits them.
		Poly P[N << 1];
		
		// Builds the node for x[l..r]: the product of (X - x[i]) over that range.
		void _init(int *x, int l, int r) {
			if (l == r) {
				P[num++] = Poly{-Zi(x[l]), Zi(1)};
				return;
			}
			int mid = (l + r) >> 1;
			// Reserve this node's slot first; the left child then lands at curid + 1.
			int curid = num++;
			_init(x, l, mid);
			// The right child starts at the first slot after the whole left subtree.
			int rid = num;
			_init(x, mid + 1, r);
			P[curid] = P[curid + 1] * P[rid];
		}

		// Reduces f modulo the current node, then recurses in the same order _init used.
		void _evalute(Poly f, Zi* y, int l, int r) {
			f = f % P[num++];
			if (l == r) {
				// At a leaf the remainder modulo (X - x[l]) is the value at x[l].
				y[l] = f[0];
				return;
			}
			int mid = (l + r) >> 1;
			_evalute(f, y, l, mid);
			_evalute(f, y, mid + 1, r);
		}
	public:
		/**
		 * Evaluates f at x[0 .. n) and returns the n values in order.
		 * num is reset before each pass so that both passes number the
		 * nodes identically.
		 */
		Poly evalute(Poly f, int* x, int n) {
			Poly rt(n);
			num = 0;
			_init(x, 0, n - 1);
			num = 0;
			_evalute(f, rt.data(), 0, n - 1);
			return rt;
		}
} PolyBuilder;

/**
 * Returns the product of (X - a[i]) for i in [l, r], built by
 * divide-and-conquer over the index range.
 */
Poly dividing(int *a, int l, int r) {
	if (l != r) {
		const int mid = (l + r) >> 1;
		Poly left = dividing(a, l, mid);
		Poly right = dividing(a, mid + 1, r);
		return left * right;
	}
	return Poly {-Zi(a[l]), 1};
}

// Vertex count and the two vertices a and b, all read in main.
int n, a, b;
// on[v] is set in main for every vertex on the parent chain from b up to the root.
boolean on[N];
// Filled by dfs: fa[v] is the parent of v, sz[v] the size of v's subtree.
int fa[N], sz[N];
// Adjacency lists, indexed by vertex label.
vector<int> G[N];

/**
 * Roots the tree at p and fills ::fa (parent) and sz (subtree size) for
 * every vertex reachable from p without stepping back to `fa`.
 *
 * The traversal is iterative: a recursive version needs one stack frame per
 * tree level, and a path-shaped tree makes the depth equal to the vertex
 * count.
 *
 * @param p   root of the (sub)tree to process
 * @param fa  parent of p, or a label no vertex uses (main passes 0)
 */
void dfs(int p, int fa) {
	// Visit order in which every vertex appears after its parent.
	vector<int> order;
	::fa[p] = fa;
	order.push_back(p);
	for (size_t head = 0; head < order.size(); head++) {
		const int u = order[head];
		sz[u] = 1;
		for (auto e : G[u]) {
			if (e != ::fa[u]) {
				::fa[e] = u;
				order.push_back(e);
			}
		}
	}
	// Accumulate sizes children-first; index 0 is the root, which has no parent to update.
	for (size_t i = order.size(); i-- > 1; ) {
		const int u = order[i];
		sz[::fa[u]] += sz[u];
	}
}

int main() {
	scanf("%d%d%d", &n, &a, &b);
	for (int i = 1, u, v; i < n; i++) {
		scanf("%d%d", &u, &v);
		G[u].push_back(v);
		G[v].push_back(u);
	}
	dfs(a, 0);
	init_inv(n);
	if (a == b) {
		Zi ans = 1;
		for (int i = 1; i <= n; i++)
			ans = ans * i;
		for (int i = 1; i <= n; i++)
			ans *= Inv[sz[i]];
		printf("%d\n", ans.v);
		return 0;
	}
	vector<int> c {0};
	for (int p = b; p; p = fa[p]) {
		on[p] = true;
		c.push_back(sz[p]);
	}
	Poly F = dividing(c.data(), 0, (signed) c.size() - 1);
	diff(F);
	Poly y = PolyBuilder.evalute(F, c.data(), c.size());
	Zi sgn = (F.size() & 1) ? (Zi(Mod - 1)) : Zi(1), sum = 0;
	for (auto t : y) {
		sum = sum + ~t * (sgn = -sgn);
	}
	for (int i = 1; i <= n; i++)
		sum = sum * i;
	for (int i = 1; i <= n; i++)
		sum = sum * ((!on[i]) ? (Inv[sz[i]]) : Zi(1));
	sum = sum * ((Mod + 1) >> 1);
	printf("%d\n", sum.v);
	return 0;
} 
posted @ 2020-02-22 18:15  阿波罗2003  阅读(291)  评论(0编辑  收藏  举报