UOJ841 龙门探宝

先考虑如果已知了树怎么做。

对于 \(l\)\(r\) 的所有点建出虚树,直径的一个求法是找两次最远点,这代表直径的一个端点 \(p\) 一定最大化了深度 \(dep_p\)

设直径另一个端点为 \(x\),那么距离就是 \(\operatorname{dist}(x,p)=dep_x+dep_p-2dep_{\operatorname{LCA}(x,p)}\),即我们需要最大化的就是 \(dep_x-2dep_{\operatorname{LCA}(x,p)}\)

这个式子和 \(p\) 有关,不好处理。一个想法是把它拆成若干个和 \(p\) 关系不大的式子的 \(\max\)

这时候需要用上区间的性质,但是直接用 \(\operatorname{LCA}(x,x+1)\) 估计是有问题的,因为我们可能会把式子算得更大。

注意到 \(p\) 是区间最深点,因此考虑 \(nxt_x\) 为序列上 \(x\) 的下一个深度不小于 \(dep_x\) 的点。事实上 \(<p\) 部分用 \(\operatorname{LCA}(x,nxt_x)\) 估计就是对的。对于右侧同理,\(lst_x\) 为上一个深度大于 \(dep_x\) 的点即可。证明可以分讨一下,如果 \(\operatorname{LCA}\) 不对那么算出来的一定不优,并且这么算一定能找到一个答案。

\(A_i=dep_i-2dep_{\operatorname{LCA}(i,nxt_i)}\)\(B_i=dep_i-2dep_{\operatorname{LCA}(i,lst_i)}\)。那么区间 \(l\sim r\) 的直径就是 \(dep_p+\max(A_{l\sim (p-1)},B_{(p+1)\sim r})\)

之后就是一个数据结构问题,枚举较短的一边扫描线可以做到 \(O(n\log^2n)\)

#include <bits/stdc++.h>
#include "tree.h"
using namespace std;
using ll = long long;
template <typename T> void Chkmin(T &x, T y) { x = min(x, y); }
template <typename T> void Chkmax(T &x, T y) { x = max(x, y); }

const int inf = 1e9;
const int kN = 1e5 + 5, kS = kN * 4;
int n;
ll ans;
int dep[kN], prv[kN], nxt[kN], A[kN], B[kN];

void GetPrv() {
  stack<int> stk;
  for(int i = n; i >= 1; stk.push(i--)) {
    while(stk.size() && (dep[stk.top()] <= dep[i])) {
      prv[stk.top()] = i;
      stk.pop();
    }
  }
  for(; stk.size(); stk.pop()) prv[stk.top()] = 0;
  for(int i = 1; i <= n; i++) {
    if(prv[i] >= 1) B[i] = query(i, prv[i]) - dep[prv[i]];
    else B[i] = -inf;
  }
}
void GetNxt() {
  stack<int> stk;
  for(int i = 1; i <= n; stk.push(i++)) {
    while(stk.size() && (dep[stk.top()] < dep[i])) {
      nxt[stk.top()] = i;
      stk.pop();
    }
  }
  for(; stk.size(); stk.pop()) nxt[stk.top()] = n + 1;
  for(int i = 1; i <= n; i++) {
    if(nxt[i] <= n) A[i] = query(i, nxt[i]) - dep[nxt[i]];
    else A[i] = -inf;
  }
}

#define ls (o << 1)
#define rs (o << 1 | 1)
struct SGT {
  ll len[kS], sum[kS], tag[kS];
  void Up(int o) { sum[o] = sum[ls] + sum[rs]; }
  void Build(int o, int l, int r) {
    tag[o] = 0;
    len[o] = r - l + 1;
    if(l == r) return void(sum[o] = -inf);
    int mid = (l + r) >> 1;
    Build(ls, l, mid);
    Build(rs, mid + 1, r);
    Up(o);
  }
  void Adt(int o, ll t) { sum[o] += len[o] * t, tag[o] += t; }
  void Dn(int o) { if(ll &t = tag[o]) Adt(ls, t), Adt(rs, t), t = 0; }
  void Update(int o, int l, int r, int x, int y, ll v) {
    if((l > y) || (r < x)) return ;
    if((l >= x) && (r <= y)) return Adt(o, v);
    Dn(o);
    int mid = (l + r) >> 1;
    Update(ls, l, mid, x, y, v);
    Update(rs, mid + 1, r, x, y, v);
    Up(o);
  }
  ll Query(int o, int l, int r, int x, int y) {
    if((l > y) || (r < x)) return 0;
    if((l >= x) && (r <= y)) return sum[o];
    Dn(o);
    int mid = (l + r) >> 1;
    return Query(ls, l, mid, x, y) + Query(rs, mid + 1, r, x, y);
  }
}sgt;

struct Node {
  int l, r, v;
  Node() { }
  Node(int _l, int _r, int _v) {
    l = _l, r = _r, v = _v;
  }
};
Node seg[kN];

void CalcA() {
  sgt.Build(1, 1, n);
  for(int i = n, sc = 0; i >= 1; i--) {
    int pr = i;
    while(sc && (seg[sc].v <= B[i + 1])) {
      sgt.Update(1, 1, n, seg[sc].l, seg[sc].r, B[i + 1] - seg[sc].v);
      pr = seg[sc--].r;
    }
    if(pr > i) seg[++sc] = Node(i + 1, pr, B[i + 1]);
    seg[++sc] = Node(i, i, -inf);
    int l = prv[i] + 1, r = nxt[i] - 1;
    if(i - l > r - i) continue;
    for(int j = i, v = -inf; j >= l; Chkmax(v, A[--j])) {
      int L = 0, R = sc;
      while(L + 1 < R) {
        int mid = (L + R) >> 1;
        (seg[mid].v <= v) ? (R = mid) : (L = mid);
      }
      ans += (ll)v * (min(r, seg[R].r) - i + 1);
      if(seg[R].r < r) ans += sgt.Query(1, 1, n, seg[R].r + 1, r);
    }
  }
}
void CalcB() {
  sgt.Build(1, 1, n);
  for(int i = 1, sc = 0; i <= n; i++) {
    int pl = i;
    while(sc && (seg[sc].v <= A[i - 1])) {
      sgt.Update(1, 1, n, seg[sc].l, seg[sc].r, A[i - 1] - seg[sc].v);
      pl = seg[sc--].l;
    }
    if(pl < i) seg[++sc] = Node(pl, i - 1, A[i - 1]);
    seg[++sc] = Node(i, i, -inf);
    int l = prv[i] + 1, r = nxt[i] - 1;
    if(i - l <= r - i) continue;
    for(int j = i, v = -inf; j <= r; Chkmax(v, B[++j])) {
      int L = 0, R = sc;
      while(L + 1 < R) {
        int mid = (L + R) >> 1;
        (seg[mid].v <= v) ? (R = mid) : (L = mid);
      }
      ans += (ll)v * (i - max(l, seg[R].l) + 1);
      if(seg[R].l > l) ans += sgt.Query(1, 1, n, l, seg[R].l - 1);
    }
  }
}

ll solve(int typ, int N) {
  n = N;
  ans = 0;
  for(int i = 2; i <= n; i++) dep[i] = query(1, i);
  GetPrv();
  GetNxt();
  for(int i = 1; i <= n; i++) {
    int l = prv[i] + 1, r = nxt[i] - 1;
    ans += ((ll)(i - l + 1) * (r - i + 1) - 1) * dep[i];
  }
  CalcA();
  CalcB();
  return ans + (ll)n * inf;
}
posted @ 2025-08-22 14:23  CJzdc  阅读(15)  评论(0)    收藏  举报