点分治

\(\text{luogu-3806}\)

给定一棵有 \(n\) 个点的树,询问树上距离为 \(k\) 的点对是否存在。

\(1 \leq n\leq 10^4\)\(1 \leq m\leq 100\)\(1 \leq k \leq 10^7\)\(1 \leq u, v \leq n\)\(1 \leq w \leq 10^4\)


点分治模板题。

下面部分题解来自于 题解:P3806 【模板】点分治 1 - 洛谷专栏

OI wiki:点分治适合处理大规模的树上路径信息问题。这道题目就是一个树上路径信息问题。

点分治的基本思路是:定一个根,把路径分为三类。

  • 以根节点为其中一个端点的;
  • 只是经过根节点的;
  • 不经过根节点的。

首先弄一棵树出来,定根为 \(1\)。(这图真大)

容易看出第一类和第二类都与根节点相关,于是我们先处理它们。处理方式因题而异,下面会解释,但是通常我们会把第二类路径拆成两个第一类。

假设我们已经处理完了这些路径,现在还剩第三类路径。这时,根节点已经没有存在的意义了。于是我们把它删了。图变成了这样:

这时我们发现图变成了三颗不连通的树。然后,我们就可以贯彻分治思想,递归地处理这三棵树。于是,所有的第三类路径就都被转化成了第一类或者第二类路径。

这就是点分治的思想。

以下是这道题目的题解。样例太水,给一组样例:

6 5
1 2 2
1 3 4
2 5 3
2 4 7
1 6 8
5
6
9
10
14
AYE
AYE
AYE
AYE
NAY

图例:(就是刚刚的图)

本题多测,然而给每一个询问跑一遍点分治太麻烦而且可能会被这题的神奇时限卡掉,于是离线掉询问一起处理。

于是我们给路径分类,探讨第一类和第二类路径的处理方法。

  • 对于一个第一类路径(一个端点是根的路径):
    我们直接对于每一个询问判断是否符合条件。在实现中为了减少码量,我们通常把它与一个根到根的长度为 \(0\) 的路径拼成一个第二类路径处理。
  • 对于一个第二类路径(一个经过根的路径):
    我们把它分成两个第一类路径。然后 DFS 出第一类路径,建立一个 bool 数组(或者 bitset,但是笔者不会用)存储一个距离对应的路径是否存在。然后,遍历询问,对于路径判断不在该子树内的路径中有没有长度符合(即为询问要求的长度减去当前路径的长度)的路径。若有,记录答案。

由于此后第一类路径被转化为了第二类路径,下面统称“路径”。

举个例子:

在处理以 \(1\) 为根,点集为 \(\{1, 2, 3, 4, 5, 6\}\) 的子树时,我们有路径 \(1 \to 5\) 满足询问 1 的需求,有路径 \(3 \to 2\) 满足询问 2 的需求,有路径 \(3 \to 5\) 满足询问 3 的需求,有路径 \(6 \to 2\) 满足询问 4 的需求。它们都经过了当前根 \(1\)

对于这个过程,实现中直接对每个子树跑一遍 DFS 就行了。注意这是一个递归过程,不要跑到已经处理过的点就行了。这里切记不能偷懒对整体 DFS,否则你的算法会认为存在 \(4 \to 2 \to 1 \to 2 \to 5\) 的路径使得长度为 \(14\),满足询问 4。然而,这并不是一个路径。

然后,给当前根节点打一个标记表示已经处理,然后递归处理子树就行了。每次处理的复杂度是 \(\mathcal{O}(n)\),期望树高有 \(\mathcal{O}(\log n)\),期望复杂度 \(\mathcal{O}(n \log n)\)

但是这就完了吗?出题人给你一个链就能把你的算法卡成 \(\mathcal{O}(n^2)\)。那怎么办?

树上有个东西叫重心,其定义为:如果在树中选择某个节点并删除,这棵树将分为若干棵子树,统计子树节点数并记录最大值。取遍树上所有节点,使此最大值取到最小的节点被称为整个树的重心。

显然就有性质:以树的重心为根时,所有子树的大小都不超过整棵树大小的一半。

于是我们每次定根的时候用 \(\mathcal{O}(n)\) 的时间求出重心,以重心为根就可以了。易证树高不大于 \(\log n\)。于是复杂度就稳在了 \(\mathcal{O}(n \log n)\)

#include<iostream>
#include<cstdio>
#include<vector>
using namespace std;
#define MAXN 10005
#define MAXM 10000005
#define ll long long 
#define pii pair<ll, ll>
#define fi first
#define se second

long long read() {
	long long x = 0, f = 1;
	char c = getchar();
	while(c > 57 || c < 48) { if(c == 45) f = -1; c = getchar(); }
	while(c >= 48 && c <= 57) { x = (x << 1) + (x << 3) + (c - 48); c = getchar(); }
	return x * f;
}

ll n, m, k, q[MAXN], f[MAXN], si, sz[MAXN], rt, cnt, d[MAXN], dis[MAXN], qs[MAXN];
bool vis[MAXN], ans[MAXN], p[MAXM];
vector<pii > v[MAXN];

void getrt(ll x, ll fa) {
	sz[x] = 1, f[x] = 0;
	for(auto it : v[x]) if(it.fi != fa && !vis[it.fi]) {
		ll y = it.fi; getrt(y, x), sz[x] += sz[y];
		f[x] = max(f[x], sz[y]);
	}
	f[x] = max(f[x], si - sz[x]);
	if(f[x] < f[rt]) rt = x;
	return;
}

void dfs(ll x, ll fa) {
	dis[++ cnt] = d[x];
	for(auto it : v[x]) if(it.fi != fa && !vis[it.fi]) 
		d[it.fi] = d[x] + it.se, dfs(it.fi, x);
	return;
}

void getans(ll x) {
	p[0] = 1; ll px = 0;
	for(auto it : v[x]) {
		ll y = it.fi, w = it.se;
		if(vis[y]) continue;
		cnt = 0, d[y] = w, dfs(y, x);
		for(int i = 1; i <= cnt; i ++) for(int j = 1; j <= m; j ++)
			if(q[j] >= dis[i]) ans[j] |= p[q[j] - dis[i]];
		for(int i = 1; i <= cnt; i ++) if(dis[i] <= 1e7)
			p[dis[i]] = 1, qs[++ px] = dis[i];
	}
	for(int i = 1; i <= px; i ++) p[qs[i]] = 0;
	return;
}

void solve(ll x) {
	vis[x] = 1, getans(x);
	for(auto it : v[x]) if(!vis[it.fi])
		si = sz[it.fi], f[rt = 0] = 0x3f3f3f3f,
		getrt(it.fi, 0), solve(rt);
	return;
}

int main() {
	n = read(), m = read();
	for(int i = 1; i < n; i ++) {
		ll x = read(), y = read(), w = read();
		v[x].push_back({y, w}), v[y].push_back({x, w});
	}
	for(int i = 1; i <= m; i ++) q[i] = read();
	f[0] = 0x3f3f3f3f, si = n;
	getrt(1, 0), solve(rt);
	for(int i = 1; i <= m; i ++) 
		if(ans[i]) cout << "AYE\n";
		else cout << "NAY\n";
	return 0;
}
posted @ 2026-02-08 16:19  So_noSlack  阅读(7)  评论(0)    收藏  举报