Fork me on GitHub
编程之美 2013 全国挑战赛 资格赛

编程之美 2013 全国挑战赛 资格赛 题目三 树上的三角形

题目三 树上的三角形

时间限制: 2000ms 内存限制: 256MB

描述

有一棵树,树上有只毛毛虫。它在这棵树上生活了很久,对它的构造了如指掌。所以它在树上从来都是走最短路,不会绕路。它还还特别喜欢三角形,所以当它在树上爬来爬去的时候总会在想,如果把刚才爬过的那几根树枝/树干锯下来,能不能从中选三根出来拼成一个三角形呢?

输入

输入数据的第一行包含一个整数 T,表示数据组数。

接下来有 T 组数据,每组数据中:

第一行包含一个整数 N,表示树上节点的个数(从 1 到 N 标号)。

接下来的 N-1 行包含三个整数 a, b, len,表示有一根长度为 len 的树枝/树干在节点 a 和节点 b 之间。

接下来一行包含一个整数 M,表示询问数。

接下来M行每行两个整数 S, T,表示毛毛虫从 S 爬行到了 T,询问这段路程中的树枝/树干是否能拼成三角形。

输出

对于每组数据,先输出一行"Case #X:",其中X为数据组数编号,从 1 开始。

接下来对于每个询问输出一行,包含"Yes"或“No”,表示是否可以拼成三角形。

数据范围

1 ≤ T ≤ 5

小数据:1 ≤ N ≤ 100, 1 ≤ M ≤ 100, 1 ≤ len ≤ 10000

大数据:1 ≤ N ≤ 100000, 1 ≤ M ≤ 100000, 1 ≤ len ≤ 1000000000

样例输入

2
5
1 2 5
1 3 20
2 4 30
4 5 15
2
3 4
3 5
5
1 4 32
2 3 100
3 5 45
4 5 60
2
1 4
1 3

样例输出

Case #1:
No
Yes
Case #2:
No
Yes

解题思路

这道题如果直接按照题意去写,那么可以利用广度优先搜索得到最短路径(因为这是一颗树,而不是图,所以不必使用最短路算法),然后判断路径上的边是否能组成一个三角形(先对路径排序,然后用两边之和大于第三边进行判断)。不过搜索的时间复杂度是 O(N),判断三角形的时间复杂度为 O(llgl)(其中 l 是最短路径的长度),小数据没问题,但大数据肯定会挂。

判断三角形是否存在,我并没有更好的办法,那么只能在求最短路径上下手了,以下面的树作为例子(题目没说是几叉树,不过没有关系):

图 1 一颗树的示例

求一棵树上两个节点的最短路径,其实就是求两个节点的最近公共祖先(Least Common Ancestors,LCA)。最近公共祖先指的是在一颗有根树中,找到两个节点 u 和 v 最近的公共祖先。这个概念很容易理解,例如上面节点 5 和 10 的 LCA 就是 1,3 和 11 的 LCA 是 3,7 和 9 的 LCA 是 3。

显然,两个节点与它们的最近公共祖先之间的路径(可以不断向上查找父节点得到)加起来,就是两个节点间的最短路径。上面节点 5 和 10 的最短路径就为 5、2、1、3、8、10;节点 3 和 11 的最短路径就为 3、8、9、11。

求 LCA 有两种算法,一种是离线的 Tarjan 算法,计算出所有 M 个询问所需的时间复杂度是 O(N+M);另一种是基于区间最值查询(Range Minimum/Maximum Query,RMQ)的在线算法,预处理时间是 O(NlgN),每次询问的时间复杂度为 O(1),总得时间复杂度就是 O(NlgN+M)。两个算法使用那个都可以,不过感觉还是用 Tarjan 更好点,占用内存更少,速度也更快。关于这两个算法的详细解释,可以参见算法之LCA与RMQ问题,这里就不详细说明了。

在线算法的代码

View Code
 1 #include <stdio.h>
  2 #include <cmath>
  3 #include <algorithm>
  4 #include <list>
  5 #include <string.h>
  6 using namespace std;
  7 // 树的节点
  8 struct Node {
  9     int next, len;
 10     Node (int n, int l):next(n), len(l) {}
 11 };
 12 int pow2[20];
 13 list<Node> nodes[100010];
 14 bool visit[100010];
 15 int ns[200010];
 16 int nIdx;
 17 int length[100010];
 18 int parent[100010];
 19 int depth[200010];
 20 int first[100010];
 21 int mmin[20][200010];
 22 int edges[100010];
 23 // DFS 对树进行预处理
 24 void dfs(int u, int dep)
 25 {
 26     ns[++nIdx] = u; depth[nIdx] = dep;
 27     visit[u] = true;
 28     if (first[u] == -1) first[u] = nIdx;
 29     list<Node>::iterator it = nodes[u].begin(), end = nodes[u].end();
 30     for (;it != end; it++)
 31     {
 32         int v = it->next;
 33         if(!visit[v])
 34         {
 35             length[v] = it->len;
 36             parent[v] = u;
 37             dfs(v, dep + 1);
 38             ns[++nIdx] = u;
 39             depth[nIdx] = dep;
 40         }
 41     }
 42 }
 43 // 初始化 RMQ
 44 void init_rmq()
 45 {
 46     nIdx = 0;
 47     memset(visit, 0, sizeof(visit));
 48     memset(first, -1, sizeof(first));
 49     depth[0] = 0;
 50     length[1] = parent[1] = 0;
 51     dfs(1, 1);
 52     memset(mmin, 0, sizeof(mmin));
 53     for(int i = 1; i <= nIdx; i++) {
 54         mmin[0][i] = i;
 55     }
 56     int t1 = (int)(log((double)nIdx) / log(2.0));
 57     for(int i = 1; i <= t1; i++) {
 58         for(int j = 1; j + pow2[i] - 1 <= nIdx; j++) {
 59             int a = mmin[i-1][j], b = mmin[i-1][j+pow2[i-1]];
 60             if(depth[a] <= depth[b]) {
 61                 mmin[i][j] = a;
 62             } else {
 63                 mmin[i][j] = b;
 64             }
 65         }
 66     }
 67 }
 68 // RMQ 询问
 69 int rmq(int u, int v)
 70 {
 71     int i = first[u], j = first[v];
 72     if(i > j) swap(i, j);
 73     int t1 = (int)(log((double)j - i + 1) / log(2.0));
 74     int a = mmin[t1][i], b = mmin[t1][j - pow2[t1] + 1];
 75     if(depth[a] <= depth[b]) {
 76         return ns[a];
 77     } else {
 78         return ns[b];
 79     }
 80 }
 81 
 82 int main() {
 83     for(int i = 0; i < 20; i++) {
 84          pow2[i] = 1 << i;
 85     }
 86     int T, n, m, a, b, len;
 87     scanf("%d ", &T);
 88     for (int caseIdx = 1;caseIdx <= T;caseIdx++) {
 89         scanf("%d", &n);
 90         for (int i = 0;i <= n;i++) {
 91             nodes[i].clear();
 92         }
 93         for (int i = 1;i < n;i++) {
 94             scanf("%d%d%d", &a, &b, &len);
 95             nodes[a].push_back(Node(b, len));
 96             nodes[b].push_back(Node(a, len));
 97         }
 98         init_rmq();
 99         scanf("%d", &m);
100         printf("Case #%d:\n", caseIdx);
101         for (int i = 0;i < m;i++) {
102             scanf("%d%d", &a, &b);
103             // 利用 RMQ 得到 LCA
104             int root = rmq(a, b);
105             bool success = false;
106             int l = 0;
107             while (a != root) {
108                 edges[l++] = length[a];
109                 a = parent[a];
110             }
111             while (b != root) {
112                 edges[l++] = length[b];
113                 b = parent[b];
114             }
115             if (l >= 3) {
116                 sort(edges, edges + l);
117                 for (int j = 2;j < l;j++) {
118                     if (edges[j - 2] + edges[j - 1] > edges[j]) {
119                         success = true;
120                         break;
121                     }
122                 }
123             }
124             if (success) {
125                 puts("Yes");
126             } else {
127                 puts("No");
128             }
129         }
130     }
131     return 0;
132 }

离线算法的代码

View Code
 1 #include <stdio.h>
  2 #include <string.h>
  3 #include <list>
  4 #include <algorithm>
  5 using namespace std;
  6 // 树和查询的节点
  7 struct Node {
  8     int next, len;
  9     Node (int n, int l):next(n), len(l) {}
 10 };
 11 list<Node> nodes[100010];
 12 list<Node> querys[100010];
 13 bool visit[100010];
 14 int ancestor[100010];
 15 int parent[100010];
 16 int length[100010];
 17 int edges[100010];
 18 // 查询的结果
 19 bool result[100010];
 20 // 并查集
 21 int uset[100010];
 22 int find(int x) {
 23     int p = x, t;
 24     while (uset[p] >= 0) p = uset[p];
 25     while (x != p) { t = uset[x]; uset[x] = p; x = t; }
 26     return x;
 27 }
 28 void un_ion(int a, int b) {
 29     if ((a = find(a)) == (b = find(b))) return;
 30     if (uset[a] < uset[b]) { uset[a] += uset[b]; uset[b] = a; }
 31     else { uset[b] += uset[a]; uset[a] = b; }
 32 }
 33 void init_uset() {
 34     memset(uset, -1, sizeof(uset));
 35 }
 36 
 37 void tarjan(int u) {
 38     visit[u] = true;
 39     ancestor[find(u)] = u;
 40     list<Node>::iterator it = nodes[u].begin(), end = nodes[u].end();
 41     for (;it != end; it++)
 42     {
 43         int v = it->next;
 44         if(!visit[v])
 45         {
 46             length[v] = it->len;
 47             parent[v] = u;
 48             tarjan(v);
 49             un_ion(u, v);
 50             ancestor[find(u)] = u;
 51         }
 52     }
 53     it = querys[u].begin(); end = querys[u].end();
 54     for (;it != end; it++)
 55     {
 56         int v = it->next;
 57         if(visit[v])
 58         {
 59             // 处理从 u 起始的查询
 60             int root = ancestor[find(v)];
 61             int l = 0;
 62             int a = u;
 63             while (a != root) {
 64                 edges[l++] = length[a];
 65                 a = parent[a];
 66             }
 67             while (v != root) {
 68                 edges[l++] = length[v];
 69                 v = parent[v];
 70             }
 71             sort(edges, edges + l);
 72             for (int j = 2;j < l;j++) {
 73                 if (edges[j - 2] + edges[j - 1] > edges[j]) {
 74                     result[it->len] = true;
 75                     break;
 76                 }
 77             }
 78         }
 79     }
 80 }
 81 
 82 int main() {
 83     int T, n, m, a, b, len;
 84     scanf("%d ", &T);
 85     for (int caseIdx = 1;caseIdx <= T;caseIdx++) {
 86         scanf("%d", &n);
 87         for (int i = 0;i <= n;i++) {
 88             nodes[i].clear();
 89             querys[i].clear();
 90         }
 91         for (int i = 1;i < n;i++) {
 92             scanf("%d%d%d", &a, &b, &len);
 93             nodes[a].push_back(Node(b, len));
 94             nodes[b].push_back(Node(a, len));
 95         }
 96         scanf("%d", &m);
 97         for (int i = 0;i < m;i++) {
 98             scanf("%d%d", &a, &b);
 99             // 查询要添加两遍,以防止出现遗漏
100             querys[a].push_back(Node(b, i));
101             querys[b].push_back(Node(a, i));
102         }
103         printf("Case #%d:\n", caseIdx);
104         init_uset();
105         memset(visit, 0, sizeof(visit));
106         memset(result, 0, sizeof(result));
107         length[1] = parent[1] = 0;
108         tarjan(1);
109         for (int i = 0;i < m;i++) {
110             if (result[i]) {
111                 puts("Yes");
112             } else {
113                 puts("No");
114             }
115         }
116     }
117     return 0;
118 }

这两个算法应该是没问题的,但大数据的时候都 TLE 了,看来 list 真不能随便用,动态开辟内存还是太慢了。离线算法的内存使用大概只有在线算法的 70%。

后来我翻代码的时候(所有人的代码都可以看到,这点挺给力),看到有人没用上面的 LCA 算法,而是在用 DFS 建好树后,使要判断的两个节点 u 和 v 分别沿着父节点链向上遍历,同时保持 u 和 v 的深度是相同的,这样同样能得到最短路径和 LCA,只不过时间复杂度要高一些。但在这道题中也没有关系,因为在找三角形时还是需要把路径遍历一编才可以,LCA 的计算反而会带来额外的复杂性,看来的确是自己想复杂了。

这段遍历算法大概类似于下面这样:

复制代码
 1 while (deep(u) > deep(v)){
 2     //

算法:Eratosthenes 筛选求质数

说明:
   除了自身之外,无法被其它整数整除的数称之为质数,要求质数很简单,但如何快速的求出质数则一直是程式设计人员与数学家努力的课题,在这边介绍一个着名的Eratosthenes求质数方法。
 
解法:
 首先知道这个问题可以使用回圈来求解,将一个指定的数除以所有小于它的数,若可以整除就不是质数,然而如何减少回圈的检查次数?如何求出小于N的所有质数?
 
 首先假设要检查的数是N好了,则事实上只要检查至N的开根号就可以了,道理很简单,假设A*B = N,如果A大于N的开根号,则事实上在小于A之前的检查就可以先检查到B这个数可以整除N。不过在程式中使用开根号会精确度的问题,所以可以使用i*i <= N进行检查,且执行更快。
 
再来假设有一个筛子存放1~N,例如:
 
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 ........ N
 
先将2的倍数筛去: 2 3 5 7 9 11 13 15 17 19 21 ........ N
 
再将3的倍数筛去: 2 3 5 7 11 13 17 19 ........ N
 
再来将5的倍数筛去,再来将7的质数筛去,再来将11的倍数筛去........,如此进行到最后留下的数就都是质数,这就是Eratosthenes筛选方法(Eratosthenes Sieve Method)。
 
*/
复制代码
public class Eratosthenes {

    /**
     * @param args
     */
    public static void main(String[] args) {
        int N = 100;
        int i = 0, j = 0 , count = 0;
        int prime[] = new int[N + 1];

        //初始化数据
        for (i = 2; i <= N; i++) {
            prime[i] = 1;
        }
        //循环1(N 开方 次)
        for (i = 2; i * i <= N; i++) {
            if (prime[i] == 0) {
                count++;
                continue;
            }
            //循环2(N/i 次)  筛选被i整除的数 
            for (j = i * i; j <= N; j = j + i) {
                prime[j] = 0;
                count++;
            }
        }

        System.out.println("Times of calculation : " + count);
        j=0;
        for (i = 2; i <= N; i++) {
            if (prime[i] == 1) {
                System.out.print("\t");
                System.out.print(i);
                j++;
                if(j % 10 == 0){
                    System.out.println();
                }
            }

        }

    }

}
复制代码

循环次数 O(N):

N 进入循环的次数 循环次数/N
100 109 1.09
1000 1430 1.43
10000 17055 1.70
100000 193328 1.93
1000000 2122879 2.12
10000000 22852766 2.28

 

 

 

 

 

质数可以去http://www.rapidtables.com/math/algebra/Ln.htm进行校验。

 
 
 
posted on 2013-04-09 10:58  HackerVirus  阅读(174)  评论(0编辑  收藏  举报