P14254 分割(divide)

牛逼题。个人认为如果在 CSP 或 NOIP 一般会放在 T2。

原题链接:P14254 分割(divide)

同时这是好题集第十八篇。

题目描述

你是洛咕咕王国的土地测绘官。洛咕咕王国并购了一块新的领土,这块新的领土正等待被分配。

这块领土可被认为是一棵有 \(n\) 个结点、结点编号为 \(1\)\(n\) 的树,根为编号 \(1\)。为了便于表述,我们把每个结点 \(i\) 在原树中的深度记作 \(d_i\),并规定根的深度为 \(1\)

你的王国有若干位诸侯希望购买土地,因此现在要从这棵树中选出 \(k\)两两不同的结点,并把它们的编号排成一个有序序列 \(b=(b_1,b_2,\dots,b_k)\)。这个序列必须满足两个条件:

第一,每个被选的结点都不是根,并且它们的深度是非降的,也就是对所有 \(1\le i<k\)\(1 < d_{b_i} \le d_{b_{i+1}}\)

第二,按照序列里每一个结点 \(b_i\)\(i=1,2,\cdots,k\)),把它们各自与父亲的连边断开。断开这些 \(k\) 条边后,原树会被分成 \(k+1\) 棵互不相交的连通子树。我们把这 \(k+1\) 棵子树依次编号。其中,第 \(1\) 棵到第 \(k\) 棵对应于根为 \(b_1,\dots,b_k\) 的那 \(k\) 棵子树,而第 \(k+1\) 棵子树则是剩下的、包含原来的树根 \(1\) 的那一棵(它的根仍记为 \(1\))。对于第 \(i\) 棵子树,把该子树中所有结点在原树中的深度值去重后组成一个集合,记为 \(S_i\)。要求这次分割满足等式:

\[S_1 = \bigcap_{i=2}^{k+1} S_i \]

换言之,第 \(1\) 棵子树中出现的所有深度恰好是“出现在所有其他子树中的深度”的交集。

我们把任意两个序列 \(b\) 视为不同的方案当且仅当它们作为序列不同(即结点相同但顺序不同视为不同方案)。你的任务是计算满足上述条件的序列 \(b\) 的个数,对 \(998244353\) 取模后输出结果。

对于 \(100\%\) 的数据,保证 \(2\le k< n\le10^6\)\(1 \leq p_i \le i\)。树保证连通。

Solution

反正这个一大段题面我累计读了半个小时吧,我的理解能力是时候该练练了。

首先非常容易发现,你分割出来的第 \(i\) 棵子树的深度一定是某个区间 \([l_i,r_i]\)。显然每棵子树的根节点的深度 \(d_u=l_i\)。那么题目所求的 \(S_1 = \bigcap\limits_{i=2}^{k+1} S_i\) 就是说 \([l_1,r_1]=[\max\limits_{i=2}^{k+1}l_i,\min\limits_{i=2}^{k+1}r_i]\)

然后又因为 \(l_1=d_{b_1}\) 是除以 \(1\) 为根的子树外最小的,结果你要求 \(l_1=\max\limits_{i=2}^{k+1}l_i\),那么是不是每棵子树的根节点深度必须一样啊。

那么现在就好办了,我们只需要关心 \(r_1=\min\limits_{i=2}^{k+1}r_i\) 的部分了。

容易想到枚举 \(r_1\) 的值来去掉 \(\max\)。那么此时有两种点(以下 \(1\le i\le k+1\)):

  • \(r_i>r_1\)。这些点我们设有 \(a\) 个。
  • \(r_i=r_1\)。这些点我们设有 \(b\) 个。只有当 \(b\ge 2\) 的时候才有方案,因为 \(r_1\) 自己要占一个,剩下的点至少要占一个。

然后现在考虑 \(r_1\) 会在哪个地方取到最小值。

  • \([2,k]\) 取到最小
    这个也就是说在分割出去的子树里面取到。我们会在 \(b\) 里取一个出来当作 \(r_1\),然后在剩下所有大于等于 \(r_1\)\(a+b-1\) 个里面选择 \(k-1\) 个出来任意排列。但是如果你全部选的是大于 \(r_1\) 的,这是不合法的,需要减掉。那么这一部分的方案就是 \(b\times(A_{a+b-1}^{k-1}-A_{a}^{k-1})\)

  • \(k+1\) 取到最小
    这就是说在剩下的那棵子树取到。这种情况就相当于我们钦定在 \(k+1\) 的地方取到最小,需要满足 \(\min\limits_{i=2}^{k}r_i>r_1\)\(r_{k+1}=r_1\)。此时 \(a\) 只有等于 \(k-1\) 才合法,因为如果 \(a\) 小了就不够填 \(r_2\sim r_k\) 了,大了就说明我们把一部分大于 \(r_1\) 的填到了 \(r_{k+1}\) 中使 \(r_{k+1}>r_1\)。所以这一部分的方案就是 \(b\times a!\)

这个预处理和计算过程可以在 \(O(n\log n)\) 的时间完成,具体可以看代码。

Code
#include <bits/stdc++.h>
#define Mem(a,b) memset((a),(b),sizeof((a)))
#define eb emplace_back
#define pb push_back
using namespace std;
using i64 = long long;
using uint = unsigned int;
using ui64 = unsigned long long;
using i128 = __int128;
constexpr int N = 2e6 + 15, mod = 998244353;
constexpr i64 inf = 1e18;
constexpr double eps = 1e-6, PI = acos (-1);

namespace FAST_IO {
#define IOSIZE 300000
    char ibuf[IOSIZE], obuf[IOSIZE], *XXH1 = ibuf, *XXH2 = ibuf, *XXH3 = obuf;
#define getchar() ((XXH1==XXH2)and(XXH2=(XXH1=ibuf)+fread(ibuf,1,IOSIZE,stdin),XXH1==XXH2)?(EOF):(*XXH1++))
#define putchar(x) ((XXH3==obuf+IOSIZE)&&(fwrite(obuf,XXH3-obuf,1,stdout),XXH3=obuf),*XXH3++=x)
#define isdigit(ch) (ch>47&&ch<58)
#define isspace(ch) (ch<33)
    template <typename T> inline void read (T &x) {	x = 0; T f = 1;char ch = getchar ();while (!isdigit (ch)) {if (ch == '-') f = -1; ch = getchar ();}while (isdigit (ch)) {x = x * 10 + (ch ^ '0'); ch = getchar ();} x *= f;}
	template <> inline void read (double &x) { x = 0; int f = 1;char ch = getchar ();while (!isdigit (ch)) { if (ch == '-') f = -1; ch = getchar ();} while (isdigit (ch)) x = x * 10 + (ch - '0'), ch = getchar ();if (ch == '.') {ch = getchar (); for (double t = 0.1; isdigit (ch); t *= 0.1) x += t * (ch - '0'), ch = getchar ();}x *= f;}
	inline void read(char &x) { char ch;while ((ch = getchar()) != EOF && isspace(ch));if (ch != EOF) x = ch;}
	inline bool read(char *s) { char ch; while (ch = getchar(), isspace(ch)); if (ch == EOF) return false; while (!isspace(ch)) *s++ = ch, ch = getchar(); *s = '\000'; return true; }
	inline bool read(string& s) { s = ""; char ch; while (ch = getchar(), isspace(ch)); if (ch == EOF) return false; while (!isspace(ch)) s += ch, ch = getchar(); return true; }
	template <typename T, typename ...Args> inline void read (T &x, Args &...args) {read(x); read(args...);}
	template <typename T> inline void write (T x) {	if (x < 0) putchar ('-'), x = -x;static short stk[50], top(0);do stk[++ top] = x % 10, x /= 10;while (x);while (top)putchar (stk[top --] | 48);}
	inline void write(string x) { for (int i = 0, n = x.size(); i < n; i++) putchar(x[i]); }
	inline void write(char *x) { while (*x) putchar(*x++); }
	inline void write(const char *x) { while (*x) putchar(*x++); }
	inline void write(char x) { putchar(x); }
	struct fio { ~fio () {if (XXH3 != obuf) { fwrite(obuf, 1, XXH3 - obuf, stdout);XXH3 = obuf;}}} io;
	template <typename T> fio& operator >> (fio &io, T &x) {return read (x), io;}
	template <typename T> fio& operator << (fio &io, const T &x) {return write (x), io;}
#define cin io
#define cout io
#define endl '\n'
}
using namespace FAST_IO;

#define cdt std::conditional_t

template <int M>
class modint {
public:
	using T = typename cdt <M <= INT_MAX, int, i64>;
	T x;
	modint (int X = 0) {while (X >= M) X -= M;while (X < 0) X += M;x = X;}
	modint (i64 X) {x = X % M;if (x < 0) x += M;}
	modint (i128 X) {x = X % M;if (x < 0) x += M;}
	modint (ui64 X) {x = X % M;}
	modint (const modint &w) : x(w.x) {}
	modint (modint &&w) : x(w.x) {}
	inline modint& operator = (int &&v) {return *this = modint(v);}
	inline modint& operator = (const int &w) {return *this = modint(w);}
	inline modint& operator = (modint &&w) {x = w.x; return *this;}
	inline modint& operator = (const modint &w) {x = w.x; return *this;}
	inline modint& operator ++ () {x = (x + 1 == M ? 0 : x + 1); return *this;}
	inline modint& operator -- () {x = (x == 0 ? M - 1 : x - 1); return *this;}
	inline modint operator + (int &&w) const {return x + w;}
	inline modint operator + (const int &w) const {return x + w;}
	inline modint& operator += (int &&w) {x = (x + w >= M ? x + w - M : x + w);return *this;}
	inline modint& operator += (const int &w) {x = (x + w >= M ? x + w - M : x + w);return *this;}
	inline modint operator + (modint &&w) const {return x + w.x;}
	inline modint operator + (const modint &w) const {return x + w.x;}
	inline modint& operator += (modint &&w) {x = (x + w.x >= M ? x + w.x - M : x + w.x);return *this;}
	inline modint& operator += (const modint &w) {x = (x + w.x >= M ? x + w.x - M : x + w.x);return *this;}
	friend inline modint operator + (const int &v, modint &&w) {return v + w.x;}
	friend inline modint operator + (const int &v, const modint &w) {return v + w.x;}
	inline modint operator - (int &&w) const {return x - w;}
	inline modint operator - (const int &w) const {return x - w;}
	inline modint& operator -= (int &&w) {x = (x - w < 0 ? x - w + M : x - w);return *this;}
	inline modint& operator -= (const int &w) {x = (x - w < 0 ? x - w + M : x - w);return *this;}
	inline modint operator - (modint &&w) const {return x - w.x;}
	inline modint operator - (const modint &w) const {return x - w.x;}
	inline modint& operator -= (modint &&w) {x = (x - w.x < 0 ? x - w.x + M : x - w.x);return *this;}
	inline modint& operator -= (const modint &w) {x = (x - w.x < 0 ? x - w.x + M : x - w.x);return *this;}
	friend inline modint operator - (const int &v, modint &&w) {return v - w.x;}
	friend inline modint operator - (const int &v, const modint &w) {return v - w.x;}
	inline modint operator * (int &&w) const {return (i64)x * w;}
	inline modint operator * (const int &w) const {return (i64)x * w;}
	inline modint& operator *= (int &&w) {x = (i64)x * w % M;return *this;}
	inline modint& operator *= (const int &w) {x = (i64)x * w % M;return *this;}
	inline modint operator * (modint &&w) const {return (i64)x * w.x;}
	inline modint operator * (const modint &w) const {return (i64)x * w.x;}
	inline modint& operator *= (modint &&w) {x = (i64)x * w.x % M;return *this;}
	inline modint& operator *= (const modint &w) {x = (i64)x * w.x % M;return *this;}
	friend inline modint operator * (const int &v, modint &&w) {return (i64)v * w.x;}
	friend inline modint operator * (const int &v, const modint &w) {return (i64)v * w.x;}
	inline modint operator ^ (int b) const {modint a = x, ans = 1;for (; b; b >>= 1, a *= a) if (b & 1) ans *= a;return ans;}
	inline modint& operator ^= (int b) {modint a = x;x = 1;for (; b; b >>= 1, a *= a) if (b & 1) *this *= a;return *this;}
	inline modint operator ~ () const {return *this ^ (M - 2);}
	inline modint operator / (int &&b) const {return *this * ~modint(b);}
	inline modint operator / (const int &b) const {return *this * ~modint(b);}
	inline modint operator / (modint &&b) const {return *this * ~b;}
	inline modint operator / (const modint &b) const {return *this * ~b;}
	inline modint& operator /= (modint &&b) {return *this *= ~b;}
	inline modint& operator /= (const modint &b) {return *this *= ~b;}
	inline modint operator - () const {return -x;}
	friend bool operator == (modint a, int &&w) {return a.x == w;}
	friend bool operator == (modint a, const int &w) {return a.x == w;}
	friend bool operator == (modint a, modint &&w) {return a.x == w.x;}
	friend bool operator == (modint a, const modint &w) {return a.x == w.x;}
	friend bool operator != (modint a, int &&w) {return a.x != w;}
	friend bool operator != (modint a, const int &w) {return a.x != w;}
	friend bool operator != (modint a, modint &&w) {return a.x != w.x;}
	friend bool operator != (modint a, const modint &w) {return a.x != w.x;}
	inline bool operator ! () {return !x;}
	friend fio& operator >> (fio& io, modint& a) {return read (a.x), io;}
	friend fio& operator << (fio& io, const modint& a) {return write (a.x), io;}
};
typedef modint<mod> mo;

bool MS;

int n, k;
int p[N];
vector <int> g[N];
int dep[N], mxdep[N];//每个点的深度、以 i 为根的子树内所有点的最大深度(注意到这就是 r[i])
mo fac[N], inv[N];
vector <int> xxh[N];
vector <pair <int, int>> XXh[N];

inline void init () {
	fac[0] = 1;
	for (int i = 1; i <= N - 15; ++ i) fac[i] = fac[i - 1] * i;
	inv[N - 15] = ~fac[N - 15];
	for (int i = N - 15; i >= 1; -- i) inv[i - 1] = inv[i] * i;
}

inline mo C (int n, int m) {
	if (m < 0 || m > n) return 0;
	return fac[n] * inv[m] * inv[n - m];
}

inline mo A (int n, int m) {
	if (m < 0 || m > n) return 0;
	return fac[n] * inv[n - m];
}

void dfs (int u) {
	mxdep[u] = dep[u] = dep[p[u]] + 1;
	for (int v : g[u]) {
		dfs (v);
		mxdep[u] = max (mxdep[u], mxdep[v]);
	}
}

bool MT;

int main () {
//	freopen ("divide.in", "r", stdin);
//	freopen ("divide.out", "w", stdout);
	init ();
	cin >> n >> k;
	for (int i = 2; i <= n; ++ i) {
		cin >> p[i];
		g[p[i]].eb (i);
	}
	dfs (1);//显然需要预处理深度
	for (int i = 1; i <= n; ++ i) xxh[dep[i]].eb (mxdep[i]);//分层
	for (int i = 1; i <= n; ++ i) sort (xxh[i].begin (), xxh[i].end ());
	for (int i = 1; i <= n; ++ i) {
		if (xxh[i].size ()) {
			XXh[i].eb (xxh[i][0], 1);
			for (int j = 1; j < xxh[i].size(); ++ j) {
				if (xxh[i][j] == xxh[i][j - 1]) XXh[i][XXh[i].size() - 1].second ++;
				else XXh[i].eb (xxh[i][j], 1);
			}
		}
	}
	mo ans = 0;
	for (int i = 2; i <= n; ++ i) {
		if (XXh[i].size ()) {
			reverse (XXh[i].begin (), XXh[i].end ());//你肯定要从深度最大的开始向上合并啊
			int a = 0;
			for (auto &p : XXh[i]) {
				int b = p.second;
				if (b >= 2 && a + b >= k + 1) ans += b * (A(a + b - 1, k - 1) - A(a, k - 1));
				if (b >= 2 && a == k - 1) ans += b * fac[a];
				a += b;
			}
		}
	}
	cout << ans << endl;
	cerr << "Memory:" << (&MS - &MT) / 1048576.0 << "MB Time:" << clock() << "ms" << endl;
	return 0;
}//8.88KB,哈哈
posted @ 2026-04-10 17:36  XXh_Laoxu  阅读(14)  评论(0)    收藏  举报

转载请注明出处!


#页面摧毁游戏#
使用【上下左右】控制飞行器的运动
使用【空格】发射导弹
点击开始摧毁