点分治学习记录

学习了一下点分治

POJ 1741(由于poj不支持c++11和bits,改一下即可ac)

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define ll long long
  4 struct edge
  5 {
  6     int to,val;
  7     edge(int a,int b)
  8     {to=a,val=b;}
  9 };
 10 const int maxn=10005;
 11 vector<edge>G[maxn];
 12 int sz[maxn],f[maxn],used[maxn],rt;
 13 int n,size,k;
 14 ll ans;
 15 void getrt(int u,int fa)
 16 {
 17     f[u]=0;
 18     sz[u]=1;
 19     for(auto i:G[u])
 20     {
 21         int v=i.to;
 22         if(v==fa || used[v]) continue;
 23         getrt(v,u);
 24         sz[u]+=sz[v];
 25         f[u]=max(f[u],sz[v]);
 26     }
 27     f[u]=max(f[u],size-sz[u]);
 28     if(f[u]<f[rt]) rt=u;
 29 }
 30 int d[maxn];
 31 vector<int>tmp;
 32 void dfs(int u,int fa)
 33 {
 34     for(auto i:G[u])
 35     {
 36         int v=i.to,val=i.val;
 37         if(v==fa || used[v]) continue;
 38         d[v]=d[u]+val;
 39         tmp.push_back(d[v]);
 40         dfs(v,u);
 41     }
 42 }
 43 int calc(int u,int dep)
 44 {
 45     tmp.clear();
 46     d[u]=dep;
 47     tmp.push_back(d[u]);
 48     dfs(u,0);
 49     sort(tmp.begin(),tmp.end());
 50     int ret=0;
 51     for(int l=0,r=tmp.size()-1;l<tmp.size() && l<r;l++)
 52     {
 53         if(2*tmp[l]>k) break;
 54         while(tmp[l]+tmp[r]>k) r--;
 55         if(l>=r) break;
 56         ret+=r-l;
 57     }
 58     return ret;
 59 }
 60 void divide(int now)
 61 {
 62     ans+=calc(now,0);
 63     used[now]=1;
 64     for(auto i:G[now])
 65     {
 66         int v=i.to,val=i.val;
 67         if(used[v]) continue;
 68         ans-=calc(v,val);
 69         rt=0;
 70         size=sz[v];
 71         getrt(v,0);
 72         divide(rt);
 73     }
 74 }
 75 void init()
 76 {
 77     f[0]=1e9+7;
 78     for(int i=1;i<=n;i++) used[i]=0,G[i].clear();
 79     rt=ans=0;
 80     size=n;
 81 }
 82 int main()
 83 {
 84     #ifdef amori
 85     freopen("in.txt","r",stdin);
 86     #endif // amori
 87     while(~scanf("%d%d",&n,&k))
 88     {
 89         if(!n && !k) break;
 90         init();
 91         for(int i=1;i<n;i++)
 92         {
 93             int u,v,w;
 94             scanf("%d%d%d",&u,&v,&w);
 95             G[u].push_back(edge(v,w));
 96             G[v].push_back(edge(u,w));
 97         }
 98         getrt(1,0);
 99         divide(rt);
100         printf("%d\n",ans);
101     }
102 }
View Code

注意事项:初始化时设f[0]为inf,size=n,每次重新获取重心(getrt)的时候,将rt初始化为0,size初始化为子树的大小(sz[v])。

洛谷P3806

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define ll long long
  4 struct edge
  5 {
  6     int to,val;
  7     edge(int a,int b)
  8     {to=a,val=b;}
  9 };
 10 const int maxn=10005;
 11 vector<edge>G[maxn];
 12 int sz[maxn],f[maxn],used[maxn],rt;
 13 int n,size,k;
 14 int ans;
 15 void getrt(int u,int fa)
 16 {
 17     f[u]=0;
 18     sz[u]=1;
 19     for(auto i:G[u])
 20     {
 21         int v=i.to;
 22         if(v==fa || used[v]) continue;
 23         getrt(v,u);
 24         sz[u]+=sz[v];
 25         f[u]=max(f[u],sz[v]);
 26     }
 27     f[u]=max(f[u],size-sz[u]);
 28     if(f[u]<f[rt]) rt=u;
 29 }
 30 int d[maxn];
 31 vector<int>tmp;
 32 int cnt[10000005];
 33 void dfs(int u,int fa)
 34 {
 35     for(auto i:G[u])
 36     {
 37         int v=i.to,val=i.val;
 38         if(v==fa || used[v]) continue;
 39         d[v]=d[u]+val;
 40         if(!cnt[d[v]])
 41             tmp.push_back(d[v]);
 42         cnt[d[v]]++;
 43         dfs(v,u);
 44     }
 45 }
 46 int calc(int u,int dep)
 47 {
 48     d[u]=dep;
 49     tmp.push_back(d[u]);
 50     cnt[d[u]]++;
 51     dfs(u,0);
 52     int ret=0;
 53     for(auto i:tmp)
 54     {
 55         if(i<=k) ret+=cnt[k-i]*cnt[i];
 56         cnt[i]=0;
 57     }
 58     tmp.clear();
 59     return ret;
 60 }
 61 void divide(int now)
 62 {
 63     ans+=calc(now,0);
 64     used[now]=1;
 65     for(auto i:G[now])
 66     {
 67         int v=i.to,val=i.val;
 68         if(used[v]) continue;
 69         ans-=calc(v,val);
 70         rt=0;
 71         size=sz[v];
 72         getrt(v,0);
 73         divide(rt);
 74     }
 75 }
 76 void init()
 77 {
 78     f[0]=1e9+7;
 79     for(int i=1;i<=n;i++) used[i]=0;
 80     rt=ans=0;
 81     size=n;
 82 }
 83 int main()
 84 {
 85     #ifdef amori
 86     freopen("in.txt","r",stdin);
 87     #endif // amori
 88     int q;
 89     scanf("%d%d",&n,&q);
 90     for(int i=1;i<n;i++)
 91     {
 92         int u,v,w;
 93         scanf("%d%d%d",&u,&v,&w);
 94         G[u].push_back(edge(v,w));
 95         G[v].push_back(edge(u,w));
 96     }
 97     while(q--)
 98     {
 99         scanf("%d",&k);
100         init();
101         getrt(1,0);
102         divide(rt);
103         if(ans) puts("AYE");
104         else puts("NAY");
105     }
106 }
View Code

 

posted @ 2019-05-13 18:42  Amori  阅读(191)  评论(0编辑  收藏  举报