树上邻域理论(树上圆理论) 小记
- 邻域:记 \(f(u, r)\) 表示距离 \(u\) 不超过 \(r\) 的点组成的邻域。
令 \(x, y\) 为点集 \(S\) 中两个距离最远的点,设 \(u\) 为 \(x, y\) 中点(可能是一条边的中心),设 \(d\) 为 \(x, y\) 的距离,那么覆盖 \(S\) 的最小邻域为 \(f(u, \frac d2)\)。
- 邻域 \(f(u_1, r_1)\) 包含邻域 \(f(u_2, r_2)\),当且仅当 \(r_1 \ge r_2 + \text{dist} (u_1, u_2)\)。
事实上我们可以把树上邻域视作平面上的圆,令 \(d = \text{dist}(u_1, u_2)\),那么有

显然有 \(r_1 \ge r_2 + \text{dist}(u_1, u_2)\)。
- 设 \(c(S) = f(u, r)\) 为包含 \(S\) 的最小邻域,求 \(c(S_1 \cup S_2)\) 的合并操作(其中 \(S_1 \cap S_2 = \emptyset\)):
若 \(S_1\) 包含 \(S_2\) 则 \(c(S_1 \cup S_2) = c(S_1)\),\(S_2\) 包含 \(S_1\) 同理。
否则令 \(d = \text{dist} (u, v)\),则 \(c(S_1 \cup S_2) = f(\text{mov} (u_1, u_2, \frac {d - r_1 + r_2} 2), \frac {d + r_1 + r_2} 2)\)。其中 \(\text{mov}(u, v, k)\) 表示 \(u\) 向 \(v\) 方向移动 \(k\) 步到达的点。

- 距离查询:对于点集 \(S\),包含其的最小邻域为 \(c(S) = f(u, r)\),则任意一个点 \(v\) 到达 \(S\) 中的最远点距离为 \(\text{dist(u, v)} + r\)。
所有直径中点为 \(u\),且点 \(v\) 到直径其中一端距离最大,必然经过直径中心。
先分治,对于 \(i \in [l, mid]\),设 \(h_i\) 为点集 \([i, mid]\) 的最小邻域,\(i \in [mid + 1, r]\) 同理。
现在需要统计所有 \(i \in [l, mid], j \in [mid + 1, r]\),\(h_i\) 与 \(h_j\) 合并后的邻域半径大小之和,邻域合并操作需要分三种情况讨论:
-
\(h_i\) 包含 \(h_j\)。
-
\(h_i\) 和 \(h_j\) 不存在包含关系。
-
\(h_i\) 被包含于 \(h_j\)。
一个性质:\(\forall i \in [l, mid - 1]\),\(h_i\) 包含 \(h_{i + 1}\)。\(i \in [mid + 2, r]\) 同理。
所以 \(h_{mid + 1 \sim r}\) 中,存在两个分界线 \(p, q\) 满足 \(h_i\) 包含 \(h_{mid + 1\sim p}\),\(h_i\) 与 \(h_{p + 1\sim q}\) 不存在包含关系,\(h_i\) 被包含于 \(h_{q + 1\sim r}\)。
并且,随着 \(i\) 的减小,\(h_i\) 越来越大,\(p, q\) 应是单调不增的,所以可以直接维护 \(p, q\)。
考虑计算答案。
-
对于 \(h_{mid + 1\sim p}\),合并后邻域仍为 \(h_i\),贡献为 \(h_i\) 的直径。
-
对于 \(h_{p + 1\sim q}\),合并后为 \(h_i\) 的半径,加上 \(h_j\) 的半径,加上两个中心点之间的距离,贡献乘上 \(\frac 12\)。前两者是容易的,第三者需要使用全局平衡二叉树 / 点分树。
-
对于 \(h_{q + 1\sim r}\),合并邻域为 \(h_{q + 1\sim r}\),贡献为对应直径。
时间复杂度 \(\mathcal O(n\log^2n)\),注意事先需要给每条边中心额外加一个虚点。
点击查看代码
#include <bits/stdc++.h>
namespace Initial {
#define ll int
#define ull unsigned long long
#define fi first
#define se second
#define mkp make_pair
#define pir pair <ll, ll>
#define pb push_back
#define i128 __int128
using namespace std;
const ll maxn = 2e5 + 10, inf = 1e9, mod = 998244353, iv = mod - mod / 2;
ll power(ll a, ll b = mod - 2, ll p = mod) {
ll s = 1;
while(b) {
if(b & 1) s = 1ll * s * a %p;
a = 1ll * a * a %p, b >>= 1;
} return s;
}
template <class T>
const inline ll pls(const T x, const T y) { return x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void add(T &x, const T y) { x = x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void chkmax(T &x, const T y) { x = x < y? y : x; }
template <class T>
const inline void chkmin(T &x, const T y) { x = x < y? x : y; }
} using namespace Initial;
namespace Read {
char buf[1 << 22], *p1, *p2;
// #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
template <class T>
const inline void rd(T &x) {
char ch; bool neg = 0;
while(!isdigit(ch = getchar()))
if(ch == '-') neg = 1;
x = ch - '0';
while(isdigit(ch = getchar()))
x = (x << 1) + (x << 3) + ch - '0';
if(neg) x = -x;
}
} using Read::rd;
ll n; long long s[maxn]; vector <ll> to[maxn];
namespace LCA{
ll d[maxn][20], dep[maxn], st[20][maxn], Log[maxn], ti, dfn[maxn];
void dfs(ll u, ll fa = 0) {
d[u][0] = fa, dep[u] = dep[fa] + 1;
st[0][++ti] = fa, dfn[u] = ti;
for(ll i = 1; i < 20; i++) d[u][i] = d[d[u][i - 1]][i - 1];
for(ll v: to[u])
if(v ^ fa) dfs(v, u);
}
ll Min(ll u, ll v) {return dep[u] < dep[v]? u : v;}
void Init() {
for(ll i = 2; i <= ti; i++) Log[i] = Log[i >> 1] + 1;
for(ll i = 1; (1 << i) <= ti; i++)
for(ll j = 1; j + (1 << i) - 1 <= ti; j++)
st[i][j] = Min(st[i - 1][j], st[i - 1][j + (1 << i - 1)]);
}
ll lca(ll u, ll v) {
if(u == v) return u;
ll l = min(dfn[u], dfn[v]) + 1, r = max(dfn[u], dfn[v]);
ll k = Log[r - l + 1];
return Min(st[k][l], st[k][r - (1 << k) + 1]);
}
ll jump(ll u, ll k) {
for(ll i = 0; i < 20; i++)
if(k & (1 << i)) u = d[u][i];
return u;
}
ll mov(ll u, ll v, ll k) {
ll c = lca(u, v);
if(k <= dep[u] - dep[c]) return jump(u, k);
return jump(v, dep[u] + dep[v] - 2 * dep[c] - k);
}
ll dist(ll u, ll v) {return dep[u] + dep[v] - 2 * dep[lca(u, v)];}
} using namespace LCA;
namespace Centroid_Divide {
ll rt, bs, siz[maxn], par[maxn]; bool vis[maxn];
void findrt(ll u, ll N, ll fa = 0) {
siz[u] = 1; ll mx = 0;
for(ll v: to[u])
if(v != fa && !vis[v]) {
findrt(v, N, u), siz[u] += siz[v];
chkmax(mx, siz[v]);
} chkmax(mx, N - siz[u]);
if(mx < bs) bs = mx, rt = u;
}
void getsiz(ll u, ll fa = 0) {
siz[u] = 1;
for(ll v: to[u])
if(v != fa && !vis[v])
getsiz(v, u), siz[u] += siz[v];
}
ll build(ll u, ll N) {
bs = inf, findrt(u, N);
vis[u = rt] = true, getsiz(u);
for(ll v: to[u])
if(!vis[v]) par[build(v, siz[v])] = u;
return u;
}
ll cnt[maxn]; long long sum[maxn], _sum[maxn];
long long qry(ll u) {
long long ret = 0;
for(ll x = u, y = 0; x; y = x, x = par[x]) {
ll d = dist(u, x);
ret += 1ll * (cnt[x] - cnt[y]) * d + sum[x] - _sum[y];
} return ret;
}
void add(ll u, ll w) {
for(ll x = u, y = 0; x; y = x, x = par[x]) {
ll d = dist(u, x);
cnt[x] += w, sum[x] += w * d, _sum[y] += w * d;
}
}
} using namespace Centroid_Divide;
struct Circle {ll u, r;} h[maxn]; long long ans;
bool contain(const Circle A, const Circle B) {
ll d = dist(A.u, B.u);
return A.r >= B.r + d;
}
Circle operator + (const Circle A, const Circle B) {
ll d = dist(A.u, B.u);
if(contain(A, B)) return A;
if(contain(B, A)) return B;
return (Circle) {mov(A.u, B.u, (d + B.r - A.r) >> 1), (A.r + B.r + d) >> 1};
}
void solve(ll l, ll r) {
if(l == r) return; ll mid = l + r >> 1;
solve(l, mid), solve(mid + 1, r); s[l - 1] = 0;
h[mid] = (Circle) {mid, 0}, h[mid + 1] = (Circle) {mid + 1, 0};
for(ll i = mid - 1; i >= l; i--) h[i] = h[i + 1] + (Circle) {i, 0};
for(ll i = mid + 2; i <= r; i++) h[i] = h[i - 1] + (Circle) {i, 0};
for(ll i = l; i <= r; i++) s[i] = s[i - 1] + h[i].r;
for(ll i = mid, j = mid, k = mid; i >= l; i--) {
while(k < r && !contain(h[k + 1], h[i])) add(h[++k].u, 1);
while(j < k && contain(h[i], h[j + 1])) add(h[++j].u, -1);
ans += 1ll * h[i].r * (j - mid);
ans += s[r] - s[k];
ans += (s[k] - s[j] + 1ll * h[i].r * (k - j) + qry(h[i].u)) >> 1;
if(i == l)
while(j < k) add(h[++j].u, -1);
}
}
int main() {
rd(n);
for(ll i = 1; i < n; i++) {
ll u, v; rd(u), rd(v);
to[n + i].pb(u), to[u].pb(n + i);
to[n + i].pb(v), to[v].pb(n + i);
} dfs(1), build(1, 2 * n - 1);
Init(), solve(1, n);
printf("%lld\n", ans);
return 0;
}

浙公网安备 33010602011771号