P14254 分割(divide)
牛逼题。个人认为如果在 CSP 或 NOIP 一般会放在 T2。
原题链接:P14254 分割(divide)。
同时这是好题集第十八篇。
题目描述
你是洛咕咕王国的土地测绘官。洛咕咕王国并购了一块新的领土,这块新的领土正等待被分配。
这块领土可被认为是一棵有 \(n\) 个结点、结点编号为 \(1\) 到 \(n\) 的树,根为编号 \(1\)。为了便于表述,我们把每个结点 \(i\) 在原树中的深度记作 \(d_i\),并规定根的深度为 \(1\)。
你的王国有若干位诸侯希望购买土地,因此现在要从这棵树中选出 \(k\) 个两两不同的结点,并把它们的编号排成一个有序序列 \(b=(b_1,b_2,\dots,b_k)\)。这个序列必须满足两个条件:
第一,每个被选的结点都不是根,并且它们的深度是非降的,也就是对所有 \(1\le i<k\) 有 \(1 < d_{b_i} \le d_{b_{i+1}}\)。
第二,按照序列里每一个结点 \(b_i\)(\(i=1,2,\cdots,k\)),把它们各自与父亲的连边断开。断开这些 \(k\) 条边后,原树会被分成 \(k+1\) 棵互不相交的连通子树。我们把这 \(k+1\) 棵子树依次编号。其中,第 \(1\) 棵到第 \(k\) 棵对应于根为 \(b_1,\dots,b_k\) 的那 \(k\) 棵子树,而第 \(k+1\) 棵子树则是剩下的、包含原来的树根 \(1\) 的那一棵(它的根仍记为 \(1\))。对于第 \(i\) 棵子树,把该子树中所有结点在原树中的深度值去重后组成一个集合,记为 \(S_i\)。要求这次分割满足等式:
换言之,第 \(1\) 棵子树中出现的所有深度恰好是“出现在所有其他子树中的深度”的交集。
我们把任意两个序列 \(b\) 视为不同的方案当且仅当它们作为序列不同(即结点相同但顺序不同视为不同方案)。你的任务是计算满足上述条件的序列 \(b\) 的个数,对 \(998244353\) 取模后输出结果。
对于 \(100\%\) 的数据,保证 \(2\le k< n\le10^6\),\(1 \leq p_i \le i\)。树保证连通。
Solution
反正这个一大段题面我累计读了半个小时吧,我的理解能力是时候该练练了。
首先非常容易发现,你分割出来的第 \(i\) 棵子树的深度一定是某个区间 \([l_i,r_i]\)。显然每棵子树的根节点的深度 \(d_u=l_i\)。那么题目所求的 \(S_1 = \bigcap\limits_{i=2}^{k+1} S_i\) 就是说 \([l_1,r_1]=[\max\limits_{i=2}^{k+1}l_i,\min\limits_{i=2}^{k+1}r_i]\)。
然后又因为 \(l_1=d_{b_1}\) 是除以 \(1\) 为根的子树外最小的,结果你要求 \(l_1=\max\limits_{i=2}^{k+1}l_i\),那么是不是每棵子树的根节点深度必须一样啊。
那么现在就好办了,我们只需要关心 \(r_1=\min\limits_{i=2}^{k+1}r_i\) 的部分了。
容易想到枚举 \(r_1\) 的值来去掉 \(\max\)。那么此时有两种点(以下 \(1\le i\le k+1\)):
- \(r_i>r_1\)。这些点我们设有 \(a\) 个。
- \(r_i=r_1\)。这些点我们设有 \(b\) 个。只有当 \(b\ge 2\) 的时候才有方案,因为 \(r_1\) 自己要占一个,剩下的点至少要占一个。
然后现在考虑 \(r_1\) 会在哪个地方取到最小值。
-
在 \([2,k]\) 取到最小
这个也就是说在分割出去的子树里面取到。我们会在 \(b\) 里取一个出来当作 \(r_1\),然后在剩下所有大于等于 \(r_1\) 的 \(a+b-1\) 个里面选择 \(k-1\) 个出来任意排列。但是如果你全部选的是大于 \(r_1\) 的,这是不合法的,需要减掉。那么这一部分的方案就是 \(b\times(A_{a+b-1}^{k-1}-A_{a}^{k-1})\)。 -
在 \(k+1\) 取到最小
这就是说在剩下的那棵子树取到。这种情况就相当于我们钦定在 \(k+1\) 的地方取到最小,需要满足 \(\min\limits_{i=2}^{k}r_i>r_1\),\(r_{k+1}=r_1\)。此时 \(a\) 只有等于 \(k-1\) 才合法,因为如果 \(a\) 小了就不够填 \(r_2\sim r_k\) 了,大了就说明我们把一部分大于 \(r_1\) 的填到了 \(r_{k+1}\) 中使 \(r_{k+1}>r_1\)。所以这一部分的方案就是 \(b\times a!\)。
这个预处理和计算过程可以在 \(O(n\log n)\) 的时间完成,具体可以看代码。
Code
#include <bits/stdc++.h>
#define Mem(a,b) memset((a),(b),sizeof((a)))
#define eb emplace_back
#define pb push_back
using namespace std;
using i64 = long long;
using uint = unsigned int;
using ui64 = unsigned long long;
using i128 = __int128;
constexpr int N = 2e6 + 15, mod = 998244353;
constexpr i64 inf = 1e18;
constexpr double eps = 1e-6, PI = acos (-1);
namespace FAST_IO {
#define IOSIZE 300000
char ibuf[IOSIZE], obuf[IOSIZE], *XXH1 = ibuf, *XXH2 = ibuf, *XXH3 = obuf;
#define getchar() ((XXH1==XXH2)and(XXH2=(XXH1=ibuf)+fread(ibuf,1,IOSIZE,stdin),XXH1==XXH2)?(EOF):(*XXH1++))
#define putchar(x) ((XXH3==obuf+IOSIZE)&&(fwrite(obuf,XXH3-obuf,1,stdout),XXH3=obuf),*XXH3++=x)
#define isdigit(ch) (ch>47&&ch<58)
#define isspace(ch) (ch<33)
template <typename T> inline void read (T &x) { x = 0; T f = 1;char ch = getchar ();while (!isdigit (ch)) {if (ch == '-') f = -1; ch = getchar ();}while (isdigit (ch)) {x = x * 10 + (ch ^ '0'); ch = getchar ();} x *= f;}
template <> inline void read (double &x) { x = 0; int f = 1;char ch = getchar ();while (!isdigit (ch)) { if (ch == '-') f = -1; ch = getchar ();} while (isdigit (ch)) x = x * 10 + (ch - '0'), ch = getchar ();if (ch == '.') {ch = getchar (); for (double t = 0.1; isdigit (ch); t *= 0.1) x += t * (ch - '0'), ch = getchar ();}x *= f;}
inline void read(char &x) { char ch;while ((ch = getchar()) != EOF && isspace(ch));if (ch != EOF) x = ch;}
inline bool read(char *s) { char ch; while (ch = getchar(), isspace(ch)); if (ch == EOF) return false; while (!isspace(ch)) *s++ = ch, ch = getchar(); *s = '\000'; return true; }
inline bool read(string& s) { s = ""; char ch; while (ch = getchar(), isspace(ch)); if (ch == EOF) return false; while (!isspace(ch)) s += ch, ch = getchar(); return true; }
template <typename T, typename ...Args> inline void read (T &x, Args &...args) {read(x); read(args...);}
template <typename T> inline void write (T x) { if (x < 0) putchar ('-'), x = -x;static short stk[50], top(0);do stk[++ top] = x % 10, x /= 10;while (x);while (top)putchar (stk[top --] | 48);}
inline void write(string x) { for (int i = 0, n = x.size(); i < n; i++) putchar(x[i]); }
inline void write(char *x) { while (*x) putchar(*x++); }
inline void write(const char *x) { while (*x) putchar(*x++); }
inline void write(char x) { putchar(x); }
struct fio { ~fio () {if (XXH3 != obuf) { fwrite(obuf, 1, XXH3 - obuf, stdout);XXH3 = obuf;}}} io;
template <typename T> fio& operator >> (fio &io, T &x) {return read (x), io;}
template <typename T> fio& operator << (fio &io, const T &x) {return write (x), io;}
#define cin io
#define cout io
#define endl '\n'
}
using namespace FAST_IO;
#define cdt std::conditional_t
template <int M>
class modint {
public:
using T = typename cdt <M <= INT_MAX, int, i64>;
T x;
modint (int X = 0) {while (X >= M) X -= M;while (X < 0) X += M;x = X;}
modint (i64 X) {x = X % M;if (x < 0) x += M;}
modint (i128 X) {x = X % M;if (x < 0) x += M;}
modint (ui64 X) {x = X % M;}
modint (const modint &w) : x(w.x) {}
modint (modint &&w) : x(w.x) {}
inline modint& operator = (int &&v) {return *this = modint(v);}
inline modint& operator = (const int &w) {return *this = modint(w);}
inline modint& operator = (modint &&w) {x = w.x; return *this;}
inline modint& operator = (const modint &w) {x = w.x; return *this;}
inline modint& operator ++ () {x = (x + 1 == M ? 0 : x + 1); return *this;}
inline modint& operator -- () {x = (x == 0 ? M - 1 : x - 1); return *this;}
inline modint operator + (int &&w) const {return x + w;}
inline modint operator + (const int &w) const {return x + w;}
inline modint& operator += (int &&w) {x = (x + w >= M ? x + w - M : x + w);return *this;}
inline modint& operator += (const int &w) {x = (x + w >= M ? x + w - M : x + w);return *this;}
inline modint operator + (modint &&w) const {return x + w.x;}
inline modint operator + (const modint &w) const {return x + w.x;}
inline modint& operator += (modint &&w) {x = (x + w.x >= M ? x + w.x - M : x + w.x);return *this;}
inline modint& operator += (const modint &w) {x = (x + w.x >= M ? x + w.x - M : x + w.x);return *this;}
friend inline modint operator + (const int &v, modint &&w) {return v + w.x;}
friend inline modint operator + (const int &v, const modint &w) {return v + w.x;}
inline modint operator - (int &&w) const {return x - w;}
inline modint operator - (const int &w) const {return x - w;}
inline modint& operator -= (int &&w) {x = (x - w < 0 ? x - w + M : x - w);return *this;}
inline modint& operator -= (const int &w) {x = (x - w < 0 ? x - w + M : x - w);return *this;}
inline modint operator - (modint &&w) const {return x - w.x;}
inline modint operator - (const modint &w) const {return x - w.x;}
inline modint& operator -= (modint &&w) {x = (x - w.x < 0 ? x - w.x + M : x - w.x);return *this;}
inline modint& operator -= (const modint &w) {x = (x - w.x < 0 ? x - w.x + M : x - w.x);return *this;}
friend inline modint operator - (const int &v, modint &&w) {return v - w.x;}
friend inline modint operator - (const int &v, const modint &w) {return v - w.x;}
inline modint operator * (int &&w) const {return (i64)x * w;}
inline modint operator * (const int &w) const {return (i64)x * w;}
inline modint& operator *= (int &&w) {x = (i64)x * w % M;return *this;}
inline modint& operator *= (const int &w) {x = (i64)x * w % M;return *this;}
inline modint operator * (modint &&w) const {return (i64)x * w.x;}
inline modint operator * (const modint &w) const {return (i64)x * w.x;}
inline modint& operator *= (modint &&w) {x = (i64)x * w.x % M;return *this;}
inline modint& operator *= (const modint &w) {x = (i64)x * w.x % M;return *this;}
friend inline modint operator * (const int &v, modint &&w) {return (i64)v * w.x;}
friend inline modint operator * (const int &v, const modint &w) {return (i64)v * w.x;}
inline modint operator ^ (int b) const {modint a = x, ans = 1;for (; b; b >>= 1, a *= a) if (b & 1) ans *= a;return ans;}
inline modint& operator ^= (int b) {modint a = x;x = 1;for (; b; b >>= 1, a *= a) if (b & 1) *this *= a;return *this;}
inline modint operator ~ () const {return *this ^ (M - 2);}
inline modint operator / (int &&b) const {return *this * ~modint(b);}
inline modint operator / (const int &b) const {return *this * ~modint(b);}
inline modint operator / (modint &&b) const {return *this * ~b;}
inline modint operator / (const modint &b) const {return *this * ~b;}
inline modint& operator /= (modint &&b) {return *this *= ~b;}
inline modint& operator /= (const modint &b) {return *this *= ~b;}
inline modint operator - () const {return -x;}
friend bool operator == (modint a, int &&w) {return a.x == w;}
friend bool operator == (modint a, const int &w) {return a.x == w;}
friend bool operator == (modint a, modint &&w) {return a.x == w.x;}
friend bool operator == (modint a, const modint &w) {return a.x == w.x;}
friend bool operator != (modint a, int &&w) {return a.x != w;}
friend bool operator != (modint a, const int &w) {return a.x != w;}
friend bool operator != (modint a, modint &&w) {return a.x != w.x;}
friend bool operator != (modint a, const modint &w) {return a.x != w.x;}
inline bool operator ! () {return !x;}
friend fio& operator >> (fio& io, modint& a) {return read (a.x), io;}
friend fio& operator << (fio& io, const modint& a) {return write (a.x), io;}
};
typedef modint<mod> mo;
bool MS;
int n, k;
int p[N];
vector <int> g[N];
int dep[N], mxdep[N];//每个点的深度、以 i 为根的子树内所有点的最大深度(注意到这就是 r[i])
mo fac[N], inv[N];
vector <int> xxh[N];
vector <pair <int, int>> XXh[N];
inline void init () {
fac[0] = 1;
for (int i = 1; i <= N - 15; ++ i) fac[i] = fac[i - 1] * i;
inv[N - 15] = ~fac[N - 15];
for (int i = N - 15; i >= 1; -- i) inv[i - 1] = inv[i] * i;
}
inline mo C (int n, int m) {
if (m < 0 || m > n) return 0;
return fac[n] * inv[m] * inv[n - m];
}
inline mo A (int n, int m) {
if (m < 0 || m > n) return 0;
return fac[n] * inv[n - m];
}
void dfs (int u) {
mxdep[u] = dep[u] = dep[p[u]] + 1;
for (int v : g[u]) {
dfs (v);
mxdep[u] = max (mxdep[u], mxdep[v]);
}
}
bool MT;
int main () {
// freopen ("divide.in", "r", stdin);
// freopen ("divide.out", "w", stdout);
init ();
cin >> n >> k;
for (int i = 2; i <= n; ++ i) {
cin >> p[i];
g[p[i]].eb (i);
}
dfs (1);//显然需要预处理深度
for (int i = 1; i <= n; ++ i) xxh[dep[i]].eb (mxdep[i]);//分层
for (int i = 1; i <= n; ++ i) sort (xxh[i].begin (), xxh[i].end ());
for (int i = 1; i <= n; ++ i) {
if (xxh[i].size ()) {
XXh[i].eb (xxh[i][0], 1);
for (int j = 1; j < xxh[i].size(); ++ j) {
if (xxh[i][j] == xxh[i][j - 1]) XXh[i][XXh[i].size() - 1].second ++;
else XXh[i].eb (xxh[i][j], 1);
}
}
}
mo ans = 0;
for (int i = 2; i <= n; ++ i) {
if (XXh[i].size ()) {
reverse (XXh[i].begin (), XXh[i].end ());//你肯定要从深度最大的开始向上合并啊
int a = 0;
for (auto &p : XXh[i]) {
int b = p.second;
if (b >= 2 && a + b >= k + 1) ans += b * (A(a + b - 1, k - 1) - A(a, k - 1));
if (b >= 2 && a == k - 1) ans += b * fac[a];
a += b;
}
}
}
cout << ans << endl;
cerr << "Memory:" << (&MS - &MT) / 1048576.0 << "MB Time:" << clock() << "ms" << endl;
return 0;
}//8.88KB,哈哈

浙公网安备 33010602011771号