Loading

树上邻域理论(树上圆理论) 小记

  • 邻域:记 \(f(u, r)\) 表示距离 \(u\) 不超过 \(r\) 的点组成的邻域。

\(x, y\) 为点集 \(S\) 中两个距离最远的点,设 \(u\)\(x, y\) 中点(可能是一条边的中心),设 \(d\)\(x, y\) 的距离,那么覆盖 \(S\) 的最小邻域为 \(f(u, \frac d2)\)

  • 邻域 \(f(u_1, r_1)\) 包含邻域 \(f(u_2, r_2)\),当且仅当 \(r_1 \ge r_2 + \text{dist} (u_1, u_2)\)

事实上我们可以把树上邻域视作平面上的圆,令 \(d = \text{dist}(u_1, u_2)\),那么有

显然有 \(r_1 \ge r_2 + \text{dist}(u_1, u_2)\)

  • \(c(S) = f(u, r)\) 为包含 \(S\) 的最小邻域,求 \(c(S_1 \cup S_2)\) 的合并操作(其中 \(S_1 \cap S_2 = \emptyset\)):

\(S_1\) 包含 \(S_2\)\(c(S_1 \cup S_2) = c(S_1)\)\(S_2\) 包含 \(S_1\) 同理。
否则令 \(d = \text{dist} (u, v)\),则 \(c(S_1 \cup S_2) = f(\text{mov} (u_1, u_2, \frac {d - r_1 + r_2} 2), \frac {d + r_1 + r_2} 2)\)。其中 \(\text{mov}(u, v, k)\) 表示 \(u\)\(v\) 方向移动 \(k\) 步到达的点。

  • 距离查询:对于点集 \(S\),包含其的最小邻域为 \(c(S) = f(u, r)\),则任意一个点 \(v\) 到达 \(S\) 中的最远点距离为 \(\text{dist(u, v)} + r\)

所有直径中点为 \(u\),且点 \(v\) 到直径其中一端距离最大,必然经过直径中心。


先分治,对于 \(i \in [l, mid]\),设 \(h_i\) 为点集 \([i, mid]\) 的最小邻域,\(i \in [mid + 1, r]\) 同理。

现在需要统计所有 \(i \in [l, mid], j \in [mid + 1, r]\)\(h_i\)\(h_j\) 合并后的邻域半径大小之和,邻域合并操作需要分三种情况讨论:

  • \(h_i\) 包含 \(h_j\)

  • \(h_i\)\(h_j\) 不存在包含关系。

  • \(h_i\) 被包含于 \(h_j\)

一个性质:\(\forall i \in [l, mid - 1]\)\(h_i\) 包含 \(h_{i + 1}\)\(i \in [mid + 2, r]\) 同理。

所以 \(h_{mid + 1 \sim r}\) 中,存在两个分界线 \(p, q\) 满足 \(h_i\) 包含 \(h_{mid + 1\sim p}\)\(h_i\)\(h_{p + 1\sim q}\) 不存在包含关系,\(h_i\) 被包含于 \(h_{q + 1\sim r}\)

并且,随着 \(i\) 的减小,\(h_i\) 越来越大,\(p, q\) 应是单调不增的,所以可以直接维护 \(p, q\)

考虑计算答案。

  • 对于 \(h_{mid + 1\sim p}\),合并后邻域仍为 \(h_i\),贡献为 \(h_i\) 的直径。

  • 对于 \(h_{p + 1\sim q}\),合并后为 \(h_i\) 的半径,加上 \(h_j\) 的半径,加上两个中心点之间的距离,贡献乘上 \(\frac 12\)。前两者是容易的,第三者需要使用全局平衡二叉树 / 点分树。

  • 对于 \(h_{q + 1\sim r}\),合并邻域为 \(h_{q + 1\sim r}\),贡献为对应直径。

时间复杂度 \(\mathcal O(n\log^2n)\),注意事先需要给每条边中心额外加一个虚点。

点击查看代码
#include <bits/stdc++.h>
namespace Initial {
	#define ll int
	#define ull unsigned long long
	#define fi first
	#define se second
	#define mkp make_pair
	#define pir pair <ll, ll>
	#define pb push_back
	#define i128 __int128
	using namespace std;
	const ll maxn = 2e5 + 10, inf = 1e9, mod = 998244353, iv = mod - mod / 2;
	ll power(ll a, ll b = mod - 2, ll p = mod) {
		ll s = 1;
		while(b) {
			if(b & 1) s = 1ll * s * a %p;
			a = 1ll * a * a %p, b >>= 1;
		} return s;
	}
	template <class T>
	const inline ll pls(const T x, const T y) { return x + y >= mod? x + y - mod : x + y; }
	template <class T>
	const inline void add(T &x, const T y) { x = x + y >= mod? x + y - mod : x + y; }
	template <class T>
	const inline void chkmax(T &x, const T y) { x = x < y? y : x; }
	template <class T>
	const inline void chkmin(T &x, const T y) { x = x < y? x : y; }
} using namespace Initial;

namespace Read {
	char buf[1 << 22], *p1, *p2;
	// #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
	template <class T>
	const inline void rd(T &x) {
		char ch; bool neg = 0;
		while(!isdigit(ch = getchar()))
			if(ch == '-') neg = 1;
		x = ch - '0';
		while(isdigit(ch = getchar()))
			x = (x << 1) + (x << 3) + ch - '0';
		if(neg) x = -x;
	}
} using Read::rd;

ll n; long long s[maxn]; vector <ll> to[maxn];

namespace LCA{
	ll d[maxn][20], dep[maxn], st[20][maxn], Log[maxn], ti, dfn[maxn];
	void dfs(ll u, ll fa = 0) {
		d[u][0] = fa, dep[u] = dep[fa] + 1;
		st[0][++ti] = fa, dfn[u] = ti;
		for(ll i = 1; i < 20; i++) d[u][i] = d[d[u][i - 1]][i - 1];
		for(ll v: to[u])
			if(v ^ fa) dfs(v, u);
	}
	ll Min(ll u, ll v) {return dep[u] < dep[v]? u : v;}
	void Init() {
		for(ll i = 2; i <= ti; i++) Log[i] = Log[i >> 1] + 1;
		for(ll i = 1; (1 << i) <= ti; i++)
			for(ll j = 1; j + (1 << i) - 1 <= ti; j++)
				st[i][j] = Min(st[i - 1][j], st[i - 1][j + (1 << i - 1)]);
	}
	ll lca(ll u, ll v) {
		if(u == v) return u;
		ll l = min(dfn[u], dfn[v]) + 1, r = max(dfn[u], dfn[v]);
		ll k = Log[r - l + 1];
		return Min(st[k][l], st[k][r - (1 << k) + 1]);
	}
	ll jump(ll u, ll k) {
		for(ll i = 0; i < 20; i++)
			if(k & (1 << i)) u = d[u][i];
		return u;
	}
	ll mov(ll u, ll v, ll k) {
		ll c = lca(u, v);
		if(k <= dep[u] - dep[c]) return jump(u, k);
		return jump(v, dep[u] + dep[v] - 2 * dep[c] - k);
	}
	ll dist(ll u, ll v) {return dep[u] + dep[v] - 2 * dep[lca(u, v)];}
} using namespace LCA;

namespace Centroid_Divide {
	ll rt, bs, siz[maxn], par[maxn]; bool vis[maxn];
	void findrt(ll u, ll N, ll fa = 0) {
		siz[u] = 1; ll mx = 0;
		for(ll v: to[u])
			if(v != fa && !vis[v]) {
				findrt(v, N, u), siz[u] += siz[v];
				chkmax(mx, siz[v]);
			} chkmax(mx, N - siz[u]);
		if(mx < bs) bs = mx, rt = u;
	}
	void getsiz(ll u, ll fa = 0) {
		siz[u] = 1;
		for(ll v: to[u])
			if(v != fa && !vis[v])
				getsiz(v, u), siz[u] += siz[v]; 
	}
	ll build(ll u, ll N) {
		bs = inf, findrt(u, N);
		vis[u = rt] = true, getsiz(u);
		for(ll v: to[u])
			if(!vis[v]) par[build(v, siz[v])] = u;
		return u;
	}
	ll cnt[maxn]; long long sum[maxn], _sum[maxn];
	long long qry(ll u) {
		long long ret = 0;
		for(ll x = u, y = 0; x; y = x, x = par[x]) {
			ll d = dist(u, x);
			ret += 1ll * (cnt[x] - cnt[y]) * d + sum[x] - _sum[y];
		} return ret;
	}
	void add(ll u, ll w) {
		for(ll x = u, y = 0; x; y = x, x = par[x]) {
			ll d = dist(u, x);
			cnt[x] += w, sum[x] += w * d, _sum[y] += w * d;
		}
	}
} using namespace Centroid_Divide;

struct Circle {ll u, r;} h[maxn]; long long ans;
bool contain(const Circle A, const Circle B) {
	ll d = dist(A.u, B.u);
	return A.r >= B.r + d;
}
Circle operator + (const Circle A, const Circle B) {
	ll d = dist(A.u, B.u);
	if(contain(A, B)) return A;
	if(contain(B, A)) return B;
	return (Circle) {mov(A.u, B.u, (d + B.r - A.r) >> 1), (A.r + B.r + d) >> 1};
}

void solve(ll l, ll r) {
	if(l == r) return; ll mid = l + r >> 1;
	solve(l, mid), solve(mid + 1, r); s[l - 1] = 0;
	h[mid] = (Circle) {mid, 0}, h[mid + 1] = (Circle) {mid + 1, 0};
	for(ll i = mid - 1; i >= l; i--) h[i] = h[i + 1] + (Circle) {i, 0};
	for(ll i = mid + 2; i <= r; i++) h[i] = h[i - 1] + (Circle) {i, 0};
	for(ll i = l; i <= r; i++) s[i] = s[i - 1] + h[i].r;
	for(ll i = mid, j = mid, k = mid; i >= l; i--) {
		while(k < r && !contain(h[k + 1], h[i])) add(h[++k].u, 1);
		while(j < k && contain(h[i], h[j + 1])) add(h[++j].u, -1);
		ans += 1ll * h[i].r * (j - mid);
		ans += s[r] - s[k];
		ans += (s[k] - s[j] + 1ll * h[i].r * (k - j) + qry(h[i].u)) >> 1;
		if(i == l)
			while(j < k) add(h[++j].u, -1);
	}
}

int main() {
	rd(n);
	for(ll i = 1; i < n; i++) {
		ll u, v; rd(u), rd(v);
		to[n + i].pb(u), to[u].pb(n + i);
		to[n + i].pb(v), to[v].pb(n + i);
	} dfs(1), build(1, 2 * n - 1);
	Init(), solve(1, n);
	printf("%lld\n", ans);
	return 0;
}
posted @ 2025-02-07 19:01  Sktn0089  阅读(356)  评论(0)    收藏  举报