点分治

引入

\(\;\)
假如现在我们得到了一棵\(n\)个节点的树,每条边都有长度。
现在我们要求这棵树中两个点之间距离小于\(k\)的点对个数。
\(n\leq 4×10^4\)

朴素做法

\(\;\)
先预处理好距离,再\(O(n^2)\)枚举点对。

重心

\(\;\)
我们找到这棵树的重心G,把这棵树分为若干个子树,那么发现满足条件的点对只有3种情况:
1.点对在某个子树中(直接递归求解)
2.两个点所构成的路径经过了重心G,但你会发现这两个点一定不能在同一个子树中。
所以我们处理出当前这棵树中每个点的d值,\(d_i\)表示点\(i\)到重心G的距离。
那么只需要用\(d_i+d_j\leq k\)这样\((i,j)\)的数量减去\(d_i+d_j\leq k\)且满足\(i,j\)在同一个子树中的数量
而你会发现,后者可以在递归子树中处理。
3.这条路经的一个端点是G,那么实质上和2.是一种情况,再加入一个\(d_G=0\)即可
\(\;\)

时间复杂度

\(\;\)
选重心来分割整棵树的目的:
你会发现,这若干棵子树中不会有子树的大小超过原树的一半(否则就与重心的定义不符)
所以最多只会递归\(log(n)\)层,每一层也是\(n\)个点。但在递归中还要将处理好的d排序。
总复杂度\(O(n log^2 n)\)

Code

\(\;\)
一定要注意:如果在函数里单独开变量而是开全局变量,一定要注意随时清空,防止上一层的答案对下面有影响。

#include <bits/stdc++.h>

const int N = 40010;
int n, k, head[N], tot, f[N], mn, W, vis[N], d[N], ans, q[N], cnt, sz[N];
struct node {
	int to, nxt, val;
}E[N << 1];
void add(int u, int v, int w) {
	E[++tot].to = v; E[tot].nxt = head[u]; E[tot].val = w; head[u] = tot;
}
void dfs(int total, int u, int fa) {
	f[u] = 0; // 一定注意初始化 
	for(int i=head[u];i;i=E[i].nxt) {
		int v = E[i].to;
		if(v == fa || vis[v]) continue;
		dfs(total, v, u);
		f[u] = std::max(f[u], sz[v]);
	}
	f[u] = std::max(f[u], total - sz[u]);
	if(f[u] < mn) {
		mn = f[u]; W = u;
	}
}
void dfs0(int u, int fa) {
	sz[u] = 1; q[++cnt] = d[u]; // 这是在减去2那一部分的时候的d值 
	for(int i=head[u];i;i=E[i].nxt) {
		int v = E[i].to;
		if(v == fa || vis[v]) continue;
		dfs0(v, u); 
		sz[u] += sz[v];
	}
} 
void getG(int rt) {  	
	cnt = 0; // 随时清空 
	dfs0(rt, 0); // 预处理好每个点的子树大小(因为随着划分重心,树的形态会变化) 
	mn = 1e9; // 注意初始化 
	dfs(sz[rt], rt, 0); // DP计算重心 
	vis[W] = 1; // 这个点作为重心,将其打上标记(相当于一个边界条件) 
}
void getd(int u, int fa) {
	q[++cnt] = d[u]; // 在求d值的过程中将其存入q数组中,这里是以这棵树为重心是的d值,与上面的d不一样 
	for(int i=head[u];i;i=E[i].nxt) {
		int v = E[i].to;
		if(vis[v] || v == fa) continue;
		d[v] = d[u] + E[i].val; 
		getd(v, u);
	}
}
void solve(int rt) {
	getG(rt);  // 得到重心 
	if(sz[rt] != n) { 
	// 对于2.情况,要减去在相同子树内(i,j)<=k的个数。这个过程是在递归到这个子树中时进行的 
	//但对于一开始整棵树的情况就没必要减了 
		std::sort(q + 1, q + cnt + 1); 
		// q里存储的是这棵子树内的d值
		// 因为树内点的编号不一定是连续的,所以需要开q这个数组存它 
		int e1 = 1, e2 = cnt;
		for(int e1=1;e1<=cnt;e1++) {
			while(e2 > e1 && q[e1] + q[e2] > k) e2 --; // 双指针枚举,复杂度是线性的 
			if(e2 <= e1) break;
			ans -= (e2 - e1);
		}
	}
	d[W] = 0; // 重心的d当然是0 
	cnt = 0; // 一定注意要随时清空 
	getd(W, 0);
	std::sort(q + 1, q + cnt + 1);
	int e1 = 1, e2 = cnt;
	for(int e1=1;e1<=cnt;e1++) {
		while(e2 > e1 && q[e1] + q[e2] > k) e2 --;
		if(e2 <= e1) break;
		ans += (e2 - e1);
	} 
	for(int i=head[W];i;i=E[i].nxt) {
		int v = E[i].to;
		if(!vis[v]) solve(v); // 如果这个点没被打上标记,一定要向下递归 
	}
}
int main() {
	scanf("%d", &n);
	for(int i=1;i<n;i++) {
		int u, v, w;
		scanf("%d%d%d", &u, &v, &w);
		add(u, v, w); add(v, u, w);
	}
	scanf("%d", &k);
	solve(1);
	printf("%d", ans);
	return 0;
}
posted @ 2021-02-25 22:41  czytysnow  阅读(36)  评论(0编辑  收藏