[bzoj3697]采药人的路径

首先将0变为-1,平衡的路径即该路径边权和为0,然后统计过重心且满足条件的路径数。

不断将统计当前这颗子树到之前所有子树的路径并合并,统计时,需要处理出两个桶a[i][j]b[i][j],表示之前的子树/当前子树深度为i,到重心的路径中有(j=1)/没有(j=0)深度为0的节点(不能是根或本身)的节点数。

可以发现$ans+=(a[0][0]-1)\cdot b[0][0]+\sum\limits_{x+y>0}a[i][x]\cdot b[-i][y]$,这样统计的复杂度就是树的深度,因为点分治深度都是log,所以直接暴力统计即可(要开long long

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 #define N 100005
 4 struct ji{
 5     int nex,to,len;
 6 }edge[N<<1];
 7 int E,r,n,m,x,y,z,t[N<<1],a[N<<1][2],b[N<<1][2],head[N],vis[N],sz[N];
 8 long long ans;
 9 void add(int x,int y,int z){
10     edge[E].nex=head[x];
11     edge[E].to=y;
12     edge[E].len=z;
13     head[x]=E++;
14 }
15 void tot(int k,int fa,int sh,int sh2){
16     x=max(x,sh2);
17     b[n+sh][(t[n+sh]>0)]++;
18     t[n+sh]++;
19     for(int i=head[k];i!=-1;i=edge[i].nex)
20         if ((!vis[edge[i].to])&&(edge[i].to!=fa))
21             tot(edge[i].to,k,sh+edge[i].len,sh2+1);
22     t[n+sh]--;
23 }
24 void merge(){
25     ans+=1LL*(a[n][0]-1)*b[n][0];
26     for(int i=-x;i<=x;i++)
27         for(int j=0;j<2;j++)
28             for(int k=j^1;k<2;k++)ans+=1LL*a[n-i][j]*b[n+i][k];
29     for(int i=-x;i<=x;i++)
30         for(int j=0;j<2;j++){
31             a[n+i][j]+=b[n+i][j];
32             b[n+i][j]=0;
33         }
34 }
35 void get(int k,int fa){
36     int ma=0;
37     sz[k]=1;
38     for(int i=head[k];i!=-1;i=edge[i].nex)
39         if ((!vis[edge[i].to])&&(edge[i].to!=fa)){
40             get(edge[i].to,k);
41             sz[k]+=sz[edge[i].to];
42             ma=max(ma,sz[edge[i].to]);
43         }
44     ma=max(ma,sz[0]-sz[k]);
45     if (ma<=sz[0]/2)r=k;
46 }
47 void dfs(int k){
48     int ms=0;
49     vis[k]=a[n][0]=1;
50     for(int i=head[k];i!=-1;i=edge[i].nex)
51         if (!vis[edge[i].to]){
52             tot(edge[i].to,0,edge[i].len,x=1);
53             ms=max(ms,x);
54             merge();
55         }
56     for(int i=-ms;i<=ms;i++)a[n+i][0]=a[n+i][1]=0;
57     get(k,0);
58     for(int i=head[k];i!=-1;i=edge[i].nex)
59         if (!vis[edge[i].to]){
60             sz[0]=sz[edge[i].to];
61             get(edge[i].to,0);
62             dfs(r);
63         }
64 }
65 int main(){
66     scanf("%d",&n);
67     memset(head,-1,sizeof(head));
68     for(int i=1;i<n;i++){
69         scanf("%d%d%d",&x,&y,&z);
70         if (!z)z--;
71         add(x,y,z);
72         add(y,x,z);
73     }
74     sz[0]=n;
75     get(1,0);
76     dfs(r);
77     printf("%lld",ans);
78 }
View Code

 

posted @ 2019-07-28 10:36  PYWBKTDA  阅读(118)  评论(0编辑  收藏  举报