P10787 [NOI2024] 树的定向

考虑贪心,尝试每一位是否能填 \(0\),瓶颈再判断局面是否有解。

A 性质可以交替定向做,这提示我们如果当前还没有边被定向,那么当所有限制的距离都 \(\ge 2\) 就一定合法。考虑把这个拓展到存在被定向边的情况,可以发现,当我们定向了 \(u\to v\),那么把已经满足的限制删除后,就可以将 \(u\)\(v\) 缩成一个点。于是只要所有未被删除的路径上都有 \(\ge 2\) 条未定向边,当前局面就一定有解。

而如果存在一个限制,上面只有一条边未定向,那么我们就可以把这条边定向并递归。否则考虑编号最小的边,怎么定向都是合法的,填 \(0\) 即可。

于是现在我们要维护的就是给一条边定向,同时找到所有只剩一条未定向边且当前仍然合法的限制。注意到我们并不需要维护限制是否合法,而可以在当未定向边数 \(=1\) 时被取出时判断。那么我们只要支持找出所有未定向边 \(=1\) 的限制和判断其是否还合法,然后找出其中那条未定向的边,直接暴力实现可以做到 \(O(nm)\)

考虑这个怎么优化。一个想法是把限制拆成若干块,只有当一块的未定向边数 \(\le 1\) 时才需要判断这个限制是否合法,这个可以用倍增把限制拆成 \(O(\log)\) 条链。维护一条链是否合法以及找到未定向的边都可以用带权并查集实现做到 \(O(n\log n)\)

于是总复杂度就是 \(O(n\log n)\),可能有点卡常。一个卡常办法是使用线性空间的倍增,那么总共就只有 \(O(n)\) 块。

#include <bits/stdc++.h>
using namespace std;

#define gc getchar_unlocked()
void Read(int &x) {
  x = 0;
  char ch = gc;
  while(!isdigit(ch)) ch = gc;
  while(isdigit(ch)) x = x * 10 + ch - 48, ch = gc;
}

const int kN = 5e5 + 5, kLog = 20;
int cid, n, m;
int fa[kN], dep[kN];
int u[kN], v[kN], eid[kN], dn[kN];
int s[kN], t[kN], lca[kN];
int jp[kN], cnt[kN * 2], tag[kN * 2];
vector<int> g[kN];
vector<int> buc[kN * 2], tra[kN * 2];
int ans[kN];

void DFS(int x, int fa) {
  cnt[x] = 1;
  ::fa[x] = fa;
  dep[x] = dep[fa] + 1;
  int jpf = jp[fa];
  int jpff = jp[jpf];
  tra[x].push_back(x + n);
  if(jpff && (dep[fa] - dep[jpf] == dep[jpf] - dep[jpff])) {
    tra[fa + n].push_back(x + n);
    tra[jpf + n].push_back(x + n);
    jp[x] = jpff;
    cnt[x + n] = cnt[fa + n] + cnt[jpf + n] + 1;
  }else jp[x] = fa, cnt[x + n] = 1;
  for(int to : g[x]) {
    if(to != fa) DFS(to, x);
  }
}
int Push(int s, int t, int id) {
  if(dep[s] < dep[t]) swap(s, t);
  while(dep[s] > dep[t]) {
    if(dep[jp[s]] >= dep[t]) {
      buc[s + n].push_back(id);
      s = jp[s];
    }else {
      buc[s].push_back(id);
      s = fa[s];
    }
  }
  while(s != t) {
    if(jp[s] != jp[t]) {
      buc[s + n].push_back(id);
      buc[t + n].push_back(id);
      s = jp[s], t = jp[t];
    }else {
      buc[s].push_back(id);
      buc[t].push_back(id);
      s = fa[s], t = fa[t];
    }
  }
  return s;
}

struct DSU {
  int fa[kN], sum[kN];
  DSU() {
    iota(fa, fa + kN, 0);
    memset(sum, 0, sizeof(sum));
  }
  int Find(int x) {
    int f = fa[x];
    if(f == x) return x;
    fa[x] = Find(f);
    return sum[x] += sum[f], fa[x];
  }
  void Merge(int x, int y, int v) {
    assert(::fa[x] == y);
    int fx = Find(x), fy = Find(y);
    sum[fx] += sum[y] + v;
    fa[fx] = fy;
  }
  int Query(int x) { return Find(x), sum[x]; }
}dsu;
bool Check(int id) {
  int cnt = 0, sx, sy;
  int x = s[id], y = t[id], p = lca[id];
  int fx = dsu.Find(x);
  int fy = dsu.Find(y);
  int fp = dsu.Find(p);
  if((dsu.Find(fa[fx]) != fp) && (dsu.Find(fa[fy]) != fp)) return 0;
  if(fx == fp) sx = dsu.sum[x] - dsu.sum[p];
  else sx = dsu.sum[x] + dsu.sum[fa[fx]] - dsu.sum[p], cnt++;
  if(sx + 1 < dep[x] - dep[p]) return 0;
  if(fy == fp) sy = dsu.sum[y] - dsu.sum[p];
  else sy = dsu.sum[y] + dsu.sum[fa[fy]] - dsu.sum[p], cnt++;
  if(-sy + 1 < dep[y] - dep[p]) return 0;
  return cnt == 1;
}

int Warn(int id) {
  if(!Check(id)) return 0;
  int ps = dsu.Find(s[id]);
  int pt = dsu.Find(t[id]);
  return (dep[ps] > dep[pt]) ? 2 * ps + 1 : 2 * pt;
}

void Solve(int, bool) ;
void WarnF(int x, int v) {
  tag[x] += v;
  if((cnt[x] -= v) > 1) return ;
  for(int to : tra[x]) WarnF(to, tag[x]);
  tag[x] = 0;
  for(int id : buc[x]) {
    if(int p = Warn(id)) Solve(p / 2, p & 1);
  }
}
void Solve(int p, bool ty) {
  dsu.Merge(p, fa[p], ty ? -1 : 1);
  ans[p] = (ty ^ (v[eid[p]] == p));
  WarnF(p, 1);
}

int main() {
  ios::sync_with_stdio(0), cin.tie(0);
  Read(cid), Read(n), Read(m);
  for(int i = 1; i < n; i++) {
    Read(u[i]), Read(v[i]);
    g[u[i]].push_back(v[i]);
    g[v[i]].push_back(u[i]);
  }
  DFS(1, 0);
  for(int i = 1; i <= m; i++) {
    Read(s[i]), Read(t[i]);
    lca[i] = Push(s[i], t[i], i);
  }
  for(int i = 1; i < n; i++) {
    dn[i] = (dep[u[i]] > dep[v[i]]) ? u[i] : v[i];
    eid[dn[i]] = i;
  }
  memset(ans, -1, sizeof(ans));
  for(int i = 1; i <= m; i++) {
    if(int p = Warn(i)) Solve(p / 2, p & 1);
  }
  for(int i = 1; i < n; i++) {
    if(ans[dn[i]] == -1) {
      Solve(dn[i], v[i] == dn[i]);
    }
  }
  for(int i = 1; i < n; i++) cout << ans[dn[i]];
  cout << "\n";
  return 0;
}
posted @ 2025-07-06 08:08  CJzdc  阅读(48)  评论(0)    收藏  举报