题解:QOJ9840 [CCPCF21D] Tree Partition

题解:QOJ9840 [CCPCF21D] Tree Partition

题目描述

给定⼀棵 n 个节点的树, 树上的点按照 1 到 n 标号。要求删除树上的 k 条边将树划分成 k+1 连通块, 满足每个连通块中的节点标号所构成的集合都为一段连续的整数,求方案数。 \(n\leq 200000, k\leq 400\)

题解

这是由 IOI2025 国家队刘恒熙提出的新做法,涉及的理论可以在他的论文《浅谈连续段及相关问题》找到。这里复述一下做法:

设树的点集为 \(U=\{1,2,\cdots,n\}\)。树的所有连通块是一个 \(U\) 的子集族 \(S\subseteq 2^U\),满足:

如果 \(A, B\in S, A\cap B\neq\varnothing\),则 \(A\cap B\in S\)(II 类封闭)且 \(A\cup B\in S\)(I 类封闭)。

这些性质都是直观的。

由于我们只关心连通块的点集是一个区间的情况,我们不妨将所有不是区间的连通块踢出 \(S\),这样不会影响 I、II 类封闭的性质。

然后论文提出,我们可以对每个 \([i, i+1]\),找到包含它们的极小区间:

\[[L_i, R_i]=\bigcap_{[i, i+1]\subseteq [l, r]\in S}[l, r] \]

然后声称:

\[[l, r]\in S\iff \left(\bigcup_{i=l}^{r-1}[L_i, R_i]\right)\subseteq [l, r] \]

右推左是符合直观的,左推右的证明论文中有(定理 4.2)。到这里,我们就可以快速判断 \([l, r]\) 是否 \(\in S\)

为了解决这道题,我们还需要对固定的右端点 \(r\) 找出所有 \([l, r]\in S\)。这部分的方法在“4.5 求右端点对应的左端点处信息之和”中有介绍。算法流程很符合直觉,假如我们维护了固定 \(r-1\) 时所有 \([l, r-1]\in S\) 的区间,现在转移到 \(r\),原来的区间都会加入一个 \([i, i+1]\) 导致被 \([L_i, R_i]\) 限制,那么我们删掉所有 \(l>L_i\) 的区间就行了。除此之外还有一些新的区间会加入,主要是因为在 \(r-1\) 时这些区间的 \(\max_{i=l}^{r-1}R_i>r-1\) 导致它们被判为非法,现在限制松了,就会成为新的合法区间。那如果这样的话,我们在 \(r-1\) 时就不要删除因为 \(\max_{i=l}^{r-1}R_i>r-1\) 而不合法的区间,而只删除 \(\min_{i=l}^{r-1}L_i<l\) 的区间。提取信息的时候,在区间序列上做前缀和,查询两个前缀和的差就可以了。记得特判单点区间。

具体流程可以看代码。\(c_i\)\([1, i]\) 中最大的 \(k\) 满足 \(R_k>i+1\)(这是给计算 \(i+1\) 的答案用的),使用单调栈;然后计算的时候也是维护单调栈,在栈上做前缀和,维护栈,然后快速找 \(\leq c_i\) 里面最大的在栈中的位置(并查集),做差分,就解决了问题。

最后是和本题相关的很迫真的题解:设 \(dp_{i,j}\) 表示 \([1, i]\) 划分为 \(j\) 段的方案数,枚举 \(j\),用上面的算法优化转移即可。

复杂度 \(O(n(k+\log n))\)\(O(\log n)\) 来自并查集,可以用线性并查集的理论优化掉。实际代码写成 \(O(nk\log n)\) 了。

代码

#include <bits/stdc++.h>
using namespace std;
#ifdef LOCAL
#define debug(...) fprintf(stderr, ##__VA_ARGS__)
#else
#define endl "\n"
#define debug(...) void(0)
#endif
using LL = long long;
template <unsigned umod>
struct modint {/*{{{*/
  static constexpr int mod = umod;
  unsigned v;
  modint() = default;
  template <class T, enable_if_t<is_integral<T>::value, int> = 0>
    modint(const T& y) : v((unsigned)(y % mod + (is_signed<T>() && y < 0 ? mod : 0))) {}
  modint& operator+=(const modint& rhs) { v += rhs.v; if (v >= umod) v -= umod; return *this; }
  modint& operator-=(const modint& rhs) { v -= rhs.v; if (v >= umod) v += umod; return *this; }
  modint& operator*=(const modint& rhs) { v = (unsigned)(1ull * v * rhs.v % umod); return *this; }
  modint& operator/=(const modint& rhs) { assert(rhs.v); return *this *= qpow(rhs, mod - 2); }
  friend modint operator+(modint lhs, const modint& rhs) { return lhs += rhs; }
  friend modint operator-(modint lhs, const modint& rhs) { return lhs -= rhs; }
  friend modint operator*(modint lhs, const modint& rhs) { return lhs *= rhs; }
  friend modint operator/(modint lhs, const modint& rhs) { return lhs /= rhs; }
  template <class T> friend modint qpow(modint a, T b) {
    modint r = 1;
    for (assert(b >= 0); b; b >>= 1, a *= a) if (b & 1) r *= a;
    return r;
  }
  friend int raw(const modint& self) { return self.v; }
  friend ostream& operator<<(ostream& os, const modint& self) { return os << raw(self); }
  explicit operator bool() const { return v != 0; }
  modint operator-() const { return modint(0) - *this; }
  bool operator==(const modint& rhs) const { return v == rhs.v; }
  bool operator!=(const modint& rhs) const { return v != rhs.v; }
};/*}}}*/
using mint = modint<998244353>;
constexpr int N = 2e5 + 10;
struct range {
  int l, r;
};
int n, mn[18][N], mx[18][N], anc[18][N], dep[N], m;
basic_string<int> tr[N];
void dfs(int u, int fa) {
  anc[0][u] = fa, dep[u] = dep[fa] + 1;
  for (int v : tr[u]) if (v != fa) dfs(v, u);
}
range query(int u, int v) {
  range res{min(u, v), max(u, v)};
  auto jmp = [&](int &x, int j) {
    res.l = min(res.l, mn[j][x]);
    res.r = max(res.r, mx[j][x]);
    x = anc[j][x];
  };
  if (dep[u] < dep[v]) swap(u, v);
  int d = dep[u] - dep[v];
  for (int j = 17; j >= 0; j--) if (d >> j & 1) jmp(u, j);
  if (u == v) return res;
  for (int j = 17; j >= 0; j--) if (anc[j][u] != anc[j][v]) jmp(u, j), jmp(v, j);
  jmp(u, 1), jmp(v, 0);
  return res;
}
range b[N];
int c[N];
mint f[N], g[N];
int main() {
#ifndef LOCAL
  cin.tie(nullptr)->sync_with_stdio(false);
#endif
  cin >> n >> m;
  for (int i = 1, u, v; i < n; i++) cin >> u >> v, tr[u] += v, tr[v] += u;
  for (int i = 1; i <= n; i++) mx[0][i] = mn[0][i] = i;
  dfs(1, 0);
  for (int j = 1; j < 18; j++) {
    for (int i = 1; i <= n; i++) {
      anc[j][i] = anc[j - 1][anc[j - 1][i]];
      mx[j][i] = max(mx[j - 1][i], mx[j - 1][anc[j - 1][i]]);
      mn[j][i] = min(mn[j - 1][i], mn[j - 1][anc[j - 1][i]]);
    }
  }
  for (int i = 1; i < n; i++) b[i] = query(i, i + 1), debug("b[%d] = [%d, %d]\n", i, b[i].l, b[i].r);
  static int stk[N];
  int top = 0;
  for (int i = 1; i < n; i++) {
    while (top && b[stk[top]].r < b[i].r) --top;
    stk[++top] = i;
    while (top && b[stk[top]].r <= i + 1) --top;
    c[i] = stk[top];
    debug("c[%d] = %d\n", i, c[i]);
  }
  static mint p[N];
  static int fa[N];
  f[0] = 1;
  while (m--) {
    for (int i = 0; i <= n; i++) fa[i] = i;
    p[0] = 0;
    g[0] = 0;
    g[1] = f[0];
    top = 0;
    auto find = [&](auto self, int x) -> int { return fa[x] == x ? x : fa[x] = self(self, fa[x]); };
    for (int i = 1; i < n; i++) {
      p[i] = p[stk[top]] + f[i - 1];
      stk[++top] = i;
      while (stk[top] > b[i].l) fa[stk[top--]] -= 1;
//    int v = find(find, c[i]);
//    debug("[v = %d] g[%d]: ", v, i + 1);
//    for (int j = 1; j <= top; j++) debug("%d ", stk[j]);
//    debug("\n");
      g[i + 1] = p[stk[top]] - p[find(find, c[i])] + f[i];
    }
    cout << g[n] << endl;
    memcpy(f, g, sizeof f);
  }
  return 0;
}
posted @ 2025-06-02 18:22  caijianhong  阅读(40)  评论(0)    收藏  举报