2019南昌邀请赛网络预选赛 J.Distance on the tree(树链剖分)

传送门

 

题意:

  给出一棵树,每条边都有权值;

  给出 m 次询问,每次询问有三个参数 u,v,w ,求节点 u 与节点 v 之间权值 ≤ w 的路径个数;

题解:

  昨天再打比赛的时候,中途,凯少和我说,这道题,一眼看去,就是树链剖分,然鹅,太久没写树链剖分的我一时也木有思路;

  今天上午把树链剖分温习了一遍,做了个模板题;

  下午再想了一下这道题,思路喷涌而出............

  首先,介绍一下相关变量:

1 int fa[maxn];//fa[u]:u的父节点
2 int son[maxn];//son[u]:u的重儿子
3 int dep[maxn];//dep[u]:u的深度
4 int siz[maxn];//siz[u]:以u为根的子树节点个数
5 int tid[maxn];//tid[u]:u在线段树中的位置
6 int top[maxn];//top[u]:u所在重链的祖先节点
7 int e[maxn][3];//e[i][0]与e[i][1]有条权值为e[i][2]的边
8 vector<int >v[maxn<<2];//v[i]:存储线段树中i号节点的所有边的权值

  (树链剖分,默认来看这篇博客的都会辽,逃)  

  下面重点介绍一下v[]的作用(将样例2中的权值改为了10):

  

  由树链剖分可知(图a,紫色部分代表重链)

    tid[1]=1,tid[3]=2,tid[5]=3;

    tid[2]=4,tid[4]=5;

  那么,线段树维护啥呢?

1 struct SegmentTree
2 {
3     int l,r;
4     int mid()
5     {
6         return l+((r-l)>>1);
7     }
8 }segTree[maxn<<2];
9 vector<int >v[maxn<<2];//v[i]:存储线段树中i号节点的所有边的权值

  对于我而言,此次线段树,主要维护节点 i 的左右区间[l,r],重点是 v[] 中维护的东西;

  首先将边权存到线段树中,如何存呢?

  对于边 u,v,w ,(假设 fa[v]=u),将 w 存在 v[ tid[ v ] ]中;

  看一下Update()函数:

 1 //将节点x在线段树中对应的pos位置的v中加入val
 2 void Update(int x,int val,int pos)
 3 {
 4     if(segTree[pos].l == segTree[pos].r)
 5     {
 6         v[pos].push_back(val);//val加入到v[pos]中
 7         return ;
 8     }
 9     int mid=segTree[pos].mid();
10     if(x <= mid)
11         Update(x,val,ls(pos));
12     else
13         Update(x,val,rs(pos));
14 }

  例如上图b:

  ①-② : 10 ,调用函数Update(tid[2],10,1)  v[tid[2]].push_back(10)

  ①- : 10 ,调用函数Update(tid[3],10,1)  v[tid[3]].push_back(10)

  -④ : 10 ,调用函数Update(tid[4],10,1)  v[tid[4]].push_back(10)

  ③- : 10 ,调用函数Update(tid[5],10,1)  v[tid[5]].push_back(10)

  线段树中的节点9中的v存储一个10

  线段树中的节点5中的v存储一个10

  线段树中的节点6中的v存储一个10

  线段树中的节点7中的v存储一个10

  这个就是Update()函数的作用;

  接下来的pushUp()函数很重要:

 1 void pushUp(int pos)
 2 {
 3     if(segTree[pos].l == segTree[pos].r)
 4         return;
 5 
 6     pushUp(ls(pos));
 7     pushUp(rs(pos));
 8 
 9     //将ls(pos),rs(pos)中的元素存储到pos中
10     for(int i=0;i < v[ls(pos)].size();++i)
11         v[pos].push_back(v[ls(pos)][i]);
12     for(int i=0;i < v[rs(pos)].size();++i)
13         v[pos].push_back(v[rs(pos)][i]);
14     sort(v[pos].begin(),v[pos].end());//升序排列
15 }

  调用pushUp(1),将所有的pos 的 ls(pos),rs(pos) 节点信息更新到pos节点;

  调用完这个函数后,你会发现:

  v[1]:10,10,10,10([1,5]中的所有节点到其父节点的权值,根节点为null)

  v[2]:10,10([1,3]中的所有节点到其父节点的权值)

  v[3]:10,10([4,5]中的所有节点到其父节点的权值)

  v[4]:10([1,2]中的所有节点到其父节点的权值)

  v[5]:10([3,3]中的所有节点到其父节点的权值)

  v[6]:10([4,4]中的所有节点到其父节点的权值)

  v[7]:10([5,5]中的所有节点到其父节点的权值)

  v[8]:null(根节点为null)

  v[9]:10([2,2]中的所有节点到其父节点的权值)

  你会发现,v[i]中存的值就是[ tree[i].l , tree[i].r ]中所有节点与其父节点的权值;

  接下来就是询问操作了:

 1 int BS(int pos,int w)
 2 {
 3     int l=-1,r=v[pos].size();
 4     while(r-l > 1)
 5     {
 6         int mid=l+((r-l)>>1);
 7         if(v[pos][mid] <= w)
 8             l=mid;
 9         else
10             r=mid;
11     }
12     return l+1;
13 }
14 int Query(int l,int r,int pos,int w)
15 {
16     if(v[pos][0] > w)//当前区间的如果最小的值要 > w,直接返回0
17         return 0;
18     if(segTree[pos].l == l && segTree[pos].r == r)
19         return BS(pos,w);//二分查找pos区间值 <= w 得个数(还记得pushUp()中的sort函数么?
20 
21     int mid=segTree[pos].mid();
22     if(r <= mid)
23         return Query(l,r,ls(pos),w);
24     else if(l > mid)
25         return Query(l,r,rs(pos),w);
26     else
27         return Query(l,mid,ls(pos),w)+Query(mid+1,r,rs(pos),w);
28 }

AC代码:

  1 #include<bits/stdc++.h>
  2 using namespace std;
  3 #define ls(x) (x<<1)
  4 #define rs(x) (x<<1|1)
  5 #define INF 0x3f3f3f3f
  6 #define mem(a,b) memset(a,b,sizeof(a))
  7 const int maxn=1e5+50;
  8 
  9 int n,m;
 10 int fa[maxn];//fa[u]:u的父节点
 11 int son[maxn];//son[u]:u的重儿子
 12 int dep[maxn];//dep[u]:u的深度
 13 int siz[maxn];//siz[u]:以u为根的子树节点个数
 14 int tid[maxn];//tid[u]:u在线段树中的位置
 15 int top[maxn];//top[u]:u所在重链的祖先节点
 16 int e[maxn][3];//e[i][0]与e[i][1]有条权值为e[i][2]的边
 17 vector<int >v[maxn<<2];//v[i]:存储线段树中i号节点的所有边的权值
 18 int num;
 19 int head[maxn];
 20 struct Edge
 21 {
 22     int to;
 23     int w;
 24     int next;
 25 }G[maxn<<1];
 26 void addEdge(int u,int v,int w)
 27 {
 28     G[num].to=v;
 29     G[num].w=w;
 30     G[num].next=head[u];
 31     head[u]=num++;
 32 }
 33 struct SegmentTree
 34 {
 35     int l,r;
 36     int mid()
 37     {
 38         return l+((r-l)>>1);
 39     }
 40 }segTree[maxn<<2];
 41 void DFS1(int u,int f,int depth)
 42 {
 43     fa[u]=f;
 44     son[u]=-1;
 45     siz[u]=1;
 46     dep[u]=depth;
 47     for(int i=head[u];~i;i=G[i].next)
 48     {
 49         int v=G[i].to;
 50         if(v == f)
 51             continue;
 52         DFS1(v,u,depth+1);
 53 
 54         siz[u] += siz[v];
 55 
 56         if(son[u] == -1 || siz[v] > siz[son[u]])
 57             son[u]=v;
 58     }
 59 }
 60 void DFS2(int u,int anc,int &k)
 61 {
 62     top[u]=anc;
 63     tid[u]=++k;
 64     if(son[u] == -1)
 65         return ;
 66     DFS2(son[u],anc,k);
 67 
 68     for(int i=head[u];~i;i=G[i].next)
 69     {
 70         int v=G[i].to;
 71         if(v != fa[u] && v != son[u])
 72             DFS2(v,v,k);
 73     }
 74 }
 75 void pushUp(int pos)
 76 {
 77     if(segTree[pos].l == segTree[pos].r)
 78         return;
 79 
 80     pushUp(ls(pos));
 81     pushUp(rs(pos));
 82 
 83     //将ls(pos),rs(pos)中的元素存储到pos中
 84     for(int i=0;i < v[ls(pos)].size();++i)
 85         v[pos].push_back(v[ls(pos)][i]);
 86     for(int i=0;i < v[rs(pos)].size();++i)
 87         v[pos].push_back(v[rs(pos)][i]);
 88     sort(v[pos].begin(),v[pos].end());//升序排列
 89 }
 90 void buildSegTree(int l,int r,int pos)
 91 {
 92     segTree[pos].l=l;
 93     segTree[pos].r=r;
 94     if(l == r)
 95         return ;
 96 
 97     int mid=l+((r-l)>>1);
 98     buildSegTree(l,mid,ls(pos));
 99     buildSegTree(mid+1,r,rs(pos));
100 }
101 //将节点x在线段树中对应的pos位置的v中加入val
102 void Update(int x,int val,int pos)
103 {
104     if(segTree[pos].l == segTree[pos].r)
105     {
106         v[pos].push_back(val);//val加入到v[pos]中
107         return ;
108     }
109     int mid=segTree[pos].mid();
110     if(x <= mid)
111         Update(x,val,ls(pos));
112     else
113         Update(x,val,rs(pos));
114 }
115 int BS(int pos,int w)
116 {
117     int l=-1,r=v[pos].size();
118     while(r-l > 1)
119     {
120         int mid=l+((r-l)>>1);
121         if(v[pos][mid] <= w)
122             l=mid;
123         else
124             r=mid;
125     }
126     return l+1;
127 }
128 int Query(int l,int r,int pos,int w)
129 {
130     if(v[pos][0] > w)//当前区间的如果最小的值要 > w,直接返回0
131         return 0;
132     if(segTree[pos].l == l && segTree[pos].r == r)
133         return BS(pos,w);//二分查找pos区间值 <= w 得个数(还记得pushUp()中的sort函数么?
134 
135     int mid=segTree[pos].mid();
136     if(r <= mid)
137         return Query(l,r,ls(pos),w);
138     else if(l > mid)
139         return Query(l,r,rs(pos),w);
140     else
141         return Query(l,mid,ls(pos),w)+Query(mid+1,r,rs(pos),w);
142 }
143 int Find(int u,int v,int w)//查询节点u到节点v之间权值小于等于w得路径个数
144 {
145     int ans=0;
146     int topU=top[u];
147     int topV=top[v];
148     while(topU != topV)
149     {
150         if(dep[topU] > dep[topV])
151         {
152             swap(u,v);
153             swap(topU,topV);
154         }
155         ans += Query(tid[top[v]],tid[v],1,w);
156         v=fa[topV];
157         topV=top[v];
158     }
159     if(u == v)
160         return ans;
161     if(dep[u] > dep[v])
162         swap(u,v);
163     return ans+Query(tid[son[u]],tid[v],1,w);
164 }
165 void Solve()
166 {
167     DFS1(1,1,1);
168     int k=0;
169     DFS2(1,1,k);
170 
171     buildSegTree(1,k,1);
172 
173     for(int i=1;i < n;++i)
174     {
175         if(dep[e[i][0]] > dep[e[i][1]])
176             swap(e[i][0],e[i][1]);//令fa[e[i][1]] = e[i][0],方便更新操作
177         Update(tid[e[i][1]],e[i][2],1);//将e[i][2]加入到tid[e[i][1]]中
178     }
179     pushUp(1);//更新线段树中所有的pos
180 
181     for(int i=1;i <= m;++i)
182     {
183         int u,v,w;
184         scanf("%d%d%d",&u,&v,&w);
185         printf("%d\n",Find(u,v,w));
186     }
187 }
188 void Init()
189 {
190     num=0;
191     mem(head,-1);
192     for(int i=0;i < 4*maxn;++i)
193         v[i].clear();
194 }
195 int main()
196 {
197 //    freopen("C:\\Users\\hyacinthLJP\\Desktop\\in&&out\\contest","r",stdin);
198     while(~scanf("%d%d",&n,&m))
199     {
200         Init();
201         for(int i=1;i < n;++i)
202         {
203             scanf("%d%d%d",e[i]+0,e[i]+1,e[i]+2);
204             addEdge(e[i][0],e[i][1],e[i][2]);
205             addEdge(e[i][1],e[i][0],e[i][2]);
206         }
207         Solve();
208     }
209     return 0;
210 }
View Code

posted @ 2019-04-21 18:55  HHHyacinth  阅读(306)  评论(0编辑  收藏  举报