编程之美 2013 全国挑战赛 资格赛 题目三 树上的三角形

题目三 树上的三角形

时间限制: 2000ms 内存限制: 256MB

描述

有一棵树,树上有只毛毛虫。它在这棵树上生活了很久,对它的构造了如指掌。所以它在树上从来都是走最短路,不会绕路。它还还特别喜欢三角形,所以当它在树上爬来爬去的时候总会在想,如果把刚才爬过的那几根树枝/树干锯下来,能不能从中选三根出来拼成一个三角形呢?

输入

输入数据的第一行包含一个整数 T,表示数据组数。

接下来有 T 组数据,每组数据中:

第一行包含一个整数 N,表示树上节点的个数(从 1 到 N 标号)。

接下来的 N-1 行包含三个整数 a, b, len,表示有一根长度为 len 的树枝/树干在节点 a 和节点 b 之间。

接下来一行包含一个整数 M,表示询问数。

接下来M行每行两个整数 S, T,表示毛毛虫从 S 爬行到了 T,询问这段路程中的树枝/树干是否能拼成三角形。

输出

对于每组数据,先输出一行"Case #X:",其中X为数据组数编号,从 1 开始。

接下来对于每个询问输出一行,包含"Yes"或“No”,表示是否可以拼成三角形。

数据范围

1 ≤ T ≤ 5

小数据:1 ≤ N ≤ 100, 1 ≤ M ≤ 100, 1 ≤ len ≤ 10000

大数据:1 ≤ N ≤ 100000, 1 ≤ M ≤ 100000, 1 ≤ len ≤ 1000000000

样例输入

2
5
1 2 5
1 3 20
2 4 30
4 5 15
2
3 4
3 5
5
1 4 32
2 3 100
3 5 45
4 5 60
2
1 4
1 3

样例输出

Case #1:
No
Yes
Case #2:
No
Yes

解题思路

这道题如果直接按照题意去写,那么可以利用广度优先搜索得到最短路径(因为这是一颗树,而不是图,所以不必使用最短路算法),然后判断路径上的边是否能组成一个三角形(先对路径排序,然后用两边之和大于第三边进行判断)。不过搜索的时间复杂度是 $O(N)$,判断三角形的时间复杂度为 $O(l \lg l)$(其中 $l$ 是最短路径的长度),小数据没问题,但大数据肯定会挂。

判断三角形是否存在,我并没有更好的办法,那么只能在求最短路径上下手了,以下面的树作为例子(题目没说是几叉树,不过没有关系):

图 1 一颗树的示例

求一棵树上两个节点的最短路径,其实就是求两个节点的最近公共祖先(Least Common Ancestors,LCA)。最近公共祖先指的是在一颗有根树中,找到两个节点 u 和 v 最近的公共祖先。这个概念很容易理解,例如上面节点 5 和 10 的 LCA 就是 1,3 和 11 的 LCA 是 3,7 和 9 的 LCA 是 3。

显然,两个节点与它们的最近公共祖先之间的路径(可以不断向上查找父节点得到)加起来,就是两个节点间的最短路径。上面节点 5 和 10 的最短路径就为 5、2、1、3、8、10;节点 3 和 11 的最短路径就为 3、8、9、11。

求 LCA 有两种算法,一种是离线的 Tarjan 算法,计算出所有 $M$ 个询问所需的时间复杂度是 $O(N + M)$;另一种是基于区间最值查询(Range Minimum/Maximum Query,RMQ)的在线算法,预处理时间是 $O(N \lg N)$,每次询问的时间复杂度为 $O(1)$,总得时间复杂度就是 $O(N \lg N + M)$。两个算法使用那个都可以,不过感觉还是用 Tarjan 更好点,占用内存更少,速度也更快。关于这两个算法的详细解释,可以参见算法之LCA与RMQ问题,这里就不详细说明了。

在线算法的代码

#include <stdio.h>
#include <cmath>
#include <algorithm>
#include <list>
#include <string.h>
using namespace std;
// 树的节点
struct Node {
	int next, len;
	Node (int n, int l):next(n), len(l) {}
};
int pow2[20];
list<Node> nodes[100010];
bool visit[100010];
int ns[200010];
int nIdx;
int length[100010];
int parent[100010];
int depth[200010];
int first[100010];
int mmin[20][200010];
int edges[100010];
// DFS 对树进行预处理
void dfs(int u, int dep)
{
	ns[++nIdx] = u; depth[nIdx] = dep;
	visit[u] = true;
	if (first[u] == -1) first[u] = nIdx;
	list<Node>::iterator it = nodes[u].begin(), end = nodes[u].end();
	for (;it != end; it++)
	{
		int v = it->next;
		if(!visit[v])
		{
			length[v] = it->len;
			parent[v] = u;
			dfs(v, dep + 1);
			ns[++nIdx] = u;
			depth[nIdx] = dep;
		}
	}
}
// 初始化 RMQ
void init_rmq()
{
	nIdx = 0;
	memset(visit, 0, sizeof(visit));
	memset(first, -1, sizeof(first));
	depth[0] = 0;
	length[1] = parent[1] = 0;
	dfs(1, 1);
	memset(mmin, 0, sizeof(mmin));
	for(int i = 1; i <= nIdx; i++) {
		mmin[0][i] = i;
	}
	int t1 = (int)(log((double)nIdx) / log(2.0));
	for(int i = 1; i <= t1; i++) {
		for(int j = 1; j + pow2[i] - 1 <= nIdx; j++) {
			int a = mmin[i-1][j], b = mmin[i-1][j+pow2[i-1]];
			if(depth[a] <= depth[b]) {
				mmin[i][j] = a;
			} else {
				mmin[i][j] = b;
			}
		}
	}
}
// RMQ 询问
int rmq(int u, int v)
{
	int i = first[u], j = first[v];
	if(i > j) swap(i, j);
	int t1 = (int)(log((double)j - i + 1) / log(2.0));
	int a = mmin[t1][i], b = mmin[t1][j - pow2[t1] + 1];
	if(depth[a] <= depth[b]) {
		return ns[a];
	} else {
		return ns[b];
	}
}

int main() {
	for(int i = 0; i < 20; i++) {
		 pow2[i] = 1 << i;
	}
	int T, n, m, a, b, len;
	scanf("%d ", &T);
	for (int caseIdx = 1;caseIdx <= T;caseIdx++) {
		scanf("%d", &n);
		for (int i = 0;i <= n;i++) {
			nodes[i].clear();
		}
		for (int i = 1;i < n;i++) {
			scanf("%d%d%d", &a, &b, &len);
			nodes[a].push_back(Node(b, len));
			nodes[b].push_back(Node(a, len));
		}
		init_rmq();
		scanf("%d", &m);
		printf("Case #%d:\n", caseIdx);
		for (int i = 0;i < m;i++) {
			scanf("%d%d", &a, &b);
			// 利用 RMQ 得到 LCA
			int root = rmq(a, b);
			bool success = false;
			int l = 0;
			while (a != root) {
				edges[l++] = length[a];
				a = parent[a];
			}
			while (b != root) {
				edges[l++] = length[b];
				b = parent[b];
			}
			if (l >= 3) {
				sort(edges, edges + l);
				for (int j = 2;j < l;j++) {
					if (edges[j - 2] + edges[j - 1] > edges[j]) {
						success = true;
						break;
					}
				}
			}
			if (success) {
				puts("Yes");
			} else {
				puts("No");
			}
		}
	}
	return 0;
}

离线算法的代码

#include <stdio.h>
#include <string.h>
#include <list>
#include <algorithm>
using namespace std;
// 树和查询的节点
struct Node {
	int next, len;
	Node (int n, int l):next(n), len(l) {}
};
list<Node> nodes[100010];
list<Node> querys[100010];
bool visit[100010];
int ancestor[100010];
int parent[100010];
int length[100010];
int edges[100010];
// 查询的结果
bool result[100010];
// 并查集
int uset[100010];
int find(int x) {
	int p = x, t;
	while (uset[p] >= 0) p = uset[p];
	while (x != p) { t = uset[x]; uset[x] = p; x = t; }
	return x;
}
void un_ion(int a, int b) {
	if ((a = find(a)) == (b = find(b))) return;
	if (uset[a] < uset[b]) { uset[a] += uset[b]; uset[b] = a; }
	else { uset[b] += uset[a]; uset[a] = b; }
}
void init_uset() {
	memset(uset, -1, sizeof(uset));
}

void tarjan(int u) {
	visit[u] = true;
	ancestor[find(u)] = u;
	list<Node>::iterator it = nodes[u].begin(), end = nodes[u].end();
	for (;it != end; it++)
	{
		int v = it->next;
		if(!visit[v])
		{
			length[v] = it->len;
			parent[v] = u;
			tarjan(v);
			un_ion(u, v);
			ancestor[find(u)] = u;
		}
	}
	it = querys[u].begin(); end = querys[u].end();
	for (;it != end; it++)
	{
		int v = it->next;
		if(visit[v])
		{
			// 处理从 u 起始的查询
			int root = ancestor[find(v)];
			int l = 0;
			int a = u;
			while (a != root) {
				edges[l++] = length[a];
				a = parent[a];
			}
			while (v != root) {
				edges[l++] = length[v];
				v = parent[v];
			}
			sort(edges, edges + l);
			for (int j = 2;j < l;j++) {
				if (edges[j - 2] + edges[j - 1] > edges[j]) {
					result[it->len] = true;
					break;
				}
			}
		}
	}
}

int main() {
	int T, n, m, a, b, len;
	scanf("%d ", &T);
	for (int caseIdx = 1;caseIdx <= T;caseIdx++) {
		scanf("%d", &n);
		for (int i = 0;i <= n;i++) {
			nodes[i].clear();
			querys[i].clear();
		}
		for (int i = 1;i < n;i++) {
			scanf("%d%d%d", &a, &b, &len);
			nodes[a].push_back(Node(b, len));
			nodes[b].push_back(Node(a, len));
		}
		scanf("%d", &m);
		for (int i = 0;i < m;i++) {
			scanf("%d%d", &a, &b);
			// 查询要添加两遍,以防止出现遗漏
			querys[a].push_back(Node(b, i));
			querys[b].push_back(Node(a, i));
		}
		printf("Case #%d:\n", caseIdx);
		init_uset();
		memset(visit, 0, sizeof(visit));
		memset(result, 0, sizeof(result));
		length[1] = parent[1] = 0;
		tarjan(1);
		for (int i = 0;i < m;i++) {
			if (result[i]) {
				puts("Yes");
			} else {
				puts("No");
			}
		}
	}
	return 0;
}

这两个算法应该是没问题的,但大数据的时候都 TLE 了,看来 list 真不能随便用,动态开辟内存还是太慢了。离线算法的内存使用大概只有在线算法的 70%。

后来我翻代码的时候(所有人的代码都可以看到,这点挺给力),看到有人没用上面的 LCA 算法,而是在用 DFS 建好树后,使要判断的两个节点 $u$ 和 $v$ 分别沿着父节点链向上遍历,同时保持 $u$ 和 $v$ 的深度是相同的,这样同样能得到最短路径和 LCA,只不过时间复杂度要高一些。但在这道题中也没有关系,因为在找三角形时还是需要把路径遍历一编才可以,LCA 的计算反而会带来额外的复杂性,看来的确是自己想复杂了。

这段遍历算法大概类似于下面这样:

while (deep(u) > deep(v)){
	// 记录路径 u 到 parent(u) 的路径
	u = parent(u);
}
while (deep(v) > deep(u)){
	// 记录路径 v 到 parent(v) 的路径
	v = parent(v);
}
while (u != v){
	// 记录路径 u 到 parent(u) 的路径
	u = parent(u);
	// 记录路径 v 到 parent(v) 的路径
	v = parent(v);
}

完整的代码见这里,ID 是 mochavic(排名第一),果然是大神。

还是把 mochavic 的代码也贴这里吧:

 

#include <stdio.h>
#include <vector>
#include <algorithm>
using namespace std;
int deep[100010], f[100010][2];
vector<int> e[100010];
int c[60], cn;
void dfs(int fa, int x, int d){
	int i, y;
	deep[x] = d;
	for (i = 0; i < (int)e[x].size(); i += 2){
		y = e[x][i];
		if (y == fa) continue;
		f[y][0] = x;
		f[y][1] = e[x][i + 1];
		dfs(x, y, d + 1);
	}
}
void pd(int x, int y){
	while (deep[x] > deep[y]){
		c[cn++] = f[x][1];
		x = f[x][0];
		if (cn == 50) return;
	}
	while (deep[y] > deep[x]){
		c[cn++] = f[y][1];
		y = f[y][0];
		if (cn == 50) return;
	}
	while (x != y){
		c[cn++] = f[x][1];
		x = f[x][0];
		c[cn++] = f[y][1];
		y = f[y][0];
		if (cn >= 50) return;
	}
}
int main(){
	int T, ri = 1, n, m, x, y, z, i;
	scanf("%d", &T);
	while (T--){
		scanf("%d", &n);
		for (i = 1; i <= n; i++) e[i].clear();
		for (i = 1; i < n; i++){
			scanf("%d%d%d", &x, &y, &z);
			e[x].push_back(y);
			e[x].push_back(z);
			e[y].push_back(x);
			e[y].push_back(z);
		}
		dfs(0, 1, 0);
		printf("Case #%d:\n", ri++);
		scanf("%d", &m);
		while (m--){
			scanf("%d%d", &x, &y);
			cn = 0;
			pd(x, y);
			sort(c, c + cn);
			for (i = 0; i + 2 < cn; i++){
				if (c[i] + c[i + 1] > c[i + 2]) break;
			}
			if (i + 2 < cn) printf("Yes\n");
			else printf("No\n");
		}
	}
	return 0;
}
posted @ 2013-04-09 00:27  CYJB  阅读(2784)  评论(0编辑  收藏  举报
Fork me on GitHub