【BZOJ1468】Tree [点分治]

Tree

Time Limit: 10 Sec  Memory Limit: 64 MB
[Submit][Status][Discuss]

Description

  给你一棵TREE,以及这棵树上边的距离,问有多少对点它们两者间的距离小于等于K。

Input

  第一行一个n,接下来n-1行边描述管道,按照题目中写的输入,接下来是一个k。

Output

  仅包括一个整数,表示有多少对点之间的距离小于等于k。

Sample Input

  7
  1 6 13
  6 3 9
  3 5 7
  4 1 3
  2 4 20
  4 7 2
  10

Sample Output

  5

HINT

  n<=40000

Solution

  树上路径统计问题,运用点分
  每次处理与重心相关的路径,发现如果直接处理两点之间比较困难,我们想到了将所有点加入一个数组,用指针判断加起来<=K的个数,这样的话不一定全是简单路径,但是我们只要减去每个子树中这样操作的条数就一定只剩下简单路径了。
  点分大概的步骤:
  1.找出重心;
  2.计算经过该重心的路径相关需要求的;
  3.去掉重心对于每棵子树继续做以上过程。

Code

  1 #include<iostream>  
  2 #include<algorithm>  
  3 #include<cstdio>  
  4 #include<cstring>  
  5 #include<cstdlib>  
  6 #include<cmath>  
  7 using namespace std;  
  8       
  9 const int ONE=80001;
 10   
 11 int n,K;
 12 int x,y,z;
 13 int next[ONE],first[ONE],go[ONE],w[ONE],tot;
 14 int center_vis[ONE];
 15 int Max,dist[ONE],num;
 16 int d[ONE];
 17 int center,Ans;
 18   
 19 int get()
 20 { 
 21         int res,Q=1;    char c;
 22         while( (c=getchar())<48 || c>57)
 23         if(c=='-')Q=-1;
 24         if(Q) res=c-48; 
 25         while((c=getchar())>=48 && c<=57) 
 26         res=res*10+c-48; 
 27         return res*Q; 
 28 }
 29   
 30 int Add(int u,int v,int z)
 31 {
 32         next[++tot]=first[u];   first[u]=tot;   go[tot]=v;  w[tot]=z;
 33         next[++tot]=first[v];   first[v]=tot;   go[tot]=u;  w[tot]=z;
 34 }
 35   
 36 namespace PointF
 37 {
 38         struct power
 39         {
 40             int maxx,size;
 41         }S[ONE];
 42           
 43         void Getsize(int u,int father)
 44         {
 45                 S[u].size=1;
 46                 S[u].maxx=0;
 47                 for(int e=first[u];e;e=next[e])
 48                 {
 49                     int v=go[e];
 50                     if(v==father || center_vis[v]) continue;
 51                     Getsize(v,u);
 52                     S[u].size+=S[v].size;
 53                     S[u].maxx=max(S[u].maxx,S[v].size);
 54                 }
 55         }
 56           
 57         void Getcenter(int u,int father,int total)
 58         {
 59                 S[u].maxx=max(S[u].maxx,total-S[u].size);
 60                 if(Max>S[u].maxx)
 61                 {
 62                     Max=S[u].maxx;
 63                     center=u;
 64                 }
 65                   
 66                 for(int e=first[u];e;e=next[e])
 67                 {
 68                     int v=go[e];
 69                     if(v==father || center_vis[v]) continue;
 70                     Getcenter(v,u,total);
 71                 }
 72         }
 73           
 74         void Getdist(int u,int father,int value)
 75         {
 76                 dist[++num]=value;
 77                 for(int e=first[u];e;e=next[e])
 78                 {
 79                     int v=go[e];
 80                     if(v==father || center_vis[v]) continue;
 81                     Getdist(v,u,value+w[e]);
 82                 }
 83         }
 84           
 85         int Calc(int u,int value)
 86         {
 87             int res=0;
 88             num=0;
 89               
 90             Getdist(u,0,value);
 91             sort(dist+1,dist+num+1);
 92               
 93             int l=1,r=num;
 94               
 95             while(l<r)
 96             {
 97                 while(dist[l]+dist[r]>K && l<r) r--;
 98                 res+=r-l;
 99                 l++;
100             }
101             return res;
102         }
103           
104         void Dfs(int u)
105         {
106                 Max=n;  
107                 Getsize(u,0);
108                   
109                 center=u;
110                 Getcenter(u,0,S[u].size);
111                 center_vis[center]=true;
112                  
113                 Ans+=Calc(center,0);
114                 for(int e=first[center];e;e=next[e])
115                 {
116                     int v=go[e];
117                     if(center_vis[v]) continue;
118                     Ans-=Calc(v,w[e]);
119                     Dfs(v);
120                 }
121         }
122 }
123   
124   
125   
126 int main()
127 {
128         n=get();
129         for(int i=1;i<n;i++)
130         {
131             x=get();    y=get();    z=get();
132             Add(x,y,z);
133         }
134           
135         K=get();
136         PointF::Dfs(1);
137           
138         printf("%d",Ans);
139 }
View Code

 

posted @ 2017-02-25 20:42  BearChild  阅读(251)  评论(0编辑  收藏  举报