树的直径
浅总结一下树的直径。
题目描述
我们将一棵树 T=(V,E) 的直径定义为 max(u,v) (u,v∈V),也就是说,树中所有最短路径距离的最大值即为树的直径。
给定一棵树,树有 N 个结点、N−1 条边,请你求出该树的直径。
输入格式
第一行输入一个整数 N(2≤N≤50000),表示树的结点数量。
接下来有 N−1 行,每行有三个整数 U,V,W,表示结点 U 和节点 V 连有一条边,边的权值为 W 。其中 1≤U,V≤N;1≤W≤100。
输出格式
输出一个整数,即树的直径。
样例
输入数据 1
8
1 2 2
1 3 1
1 4 2
2 5 1
2 6 2
6 7 1
6 8 3
输出数据 1
9
求树的直径,有两种方法:
- 第一种方法,是用两次DFS;
- 第二种方法,是用树型DP。
如果边权是正的,只求直径,则两种方法都可以使用;
如果边权是正的,既要求直径,又要求输出路径,则只能用两次DFS的方法。
如果边权是负的,求直径,只能用树型DP的方法(见下面说明);
关于边是负权求直径的说明:
如果边有负权,用两次 DFS 求树直径的方法是要出错的,因为如果搜索起点的选择不同,那么得到的直径也会不同。
比如:特殊样例如下:

如果以1为起点会先搜索到 5,然后 5 到 3(或4)的路径长度即为树的直径,此时求得的直径长度为 1;
但是如果我们以 2 为起点开始搜索,会先搜索到 3 (或4),然后到 3 到 4(或4到3)的路径长度即为树的直径,此时求得的直径长度为 4 。
所以边有负权值的情况下建议不用两次 DFS 求直径的方式,而是用 DP 方法求树的直径
【求树的直径-模板1-两次DFS实现-不含路径-参考代码】:
#include<bits/stdc++.h> using namespace std; const int N=50005,M=50005*2; int n,m,t,p,ed; int d[N],First[N],to[M],w[M],Next[M]; void add(int x,int y,int z) //邻接表 { Next[++t]=First[x]; First[x]=t; to[t]=y; w[t]=z; } void dfs1(int u,int father) //从x出发,向下走(父结点是father) { if(d[u]>d[p]) p=u; //更新从起点1能走到的最远点p for(int i=First[u];i;i=Next[i]) //枚举从u出发的所有边i { int v=to[i]; //v是边i另外一个端点 if(v==father) continue; //不能往回走(即:不能回到父节点) d[v]=d[u]+w[i]; //起点1到v的距离 dfs1(v,u); //从v出发向下走 } } void dfs2(int u,int father) //从u出发,向下走(父结点是father) { if(d[u]>d[ed]) ed=u; //更新从起点p能走到的最远点ed for(int i=First[u];i;i=Next[i]) //枚举从u出发的所有边i { int v=to[i]; //v是边i另外一个端点 if(v==father) continue; //不能往回走(即:不能回到父节点) d[v]=d[u]+w[i]; //起点p到v的距离 dfs2(v,u); //从v出发向下走 } } int main() { //freopen("diameter.in","r",stdin); //freopen("diameter.out","w",stdout); int x,y,z; scanf("%d",&n); for(int i=1;i<=n-1;i++) { scanf("%d%d%d",&x,&y,&z); add(x,y,z); add(y,x,z); } dfs1(1,0); //第一次dfs,找到从结点1能到的最远结点p memset(d,0,sizeof(d)); dfs2(p,0); //第二次dfs,找到从结点p能到的最远结点ed,则p->ed就是直径 printf("%d\n",d[ed]); return 0; }
【求树的直径-模板1-两次DFS实现-输出路径-参考代码】:
#include<bits/stdc++.h> using namespace std; const int N=50005,M=50005*2; int n,m,t,p,ed; int d[N],First[N],to[M],w[M],Next[M],step[N]; //step[]存储路径 void add(int x,int y,int z)//邻接表 { Next[++t]=First[x]; First[x]=t; to[t]=y; w[t]=z; } void dfs1(int u,int father) //从u出发,向下走(父结点是father) { if(d[u]>d[p]) p=u; //更新从起点1能走到的最远点p for(int i=First[u];i;i=Next[i]) //枚举从u出发的所有边i { int v=to[i]; //v是边i另外一个端点 if(v==father) continue; //不能往回走(即:不能回到父节点) d[v]=d[u]+w[i]; //起点1到v的距离 dfs1(v,u); //从v出发向下走 } } void dfs2(int u,int father) //从u出发,向下走(父结点是father) { if(d[u]>d[ed]) ed=u; //更新从起点p能走到的最远点ed for(int i=First[u];i;i=Next[i]) //枚举从u出发的所有边i { int v=to[i]; //v是边i另外一个端点 if(v==father) continue; //不能往回走(即:不能回到父节点) step[v]=u; d[v]=d[u]+w[i]; //起点到v的距离 dfs2(v,u); //从v出发向下走 } } void PrintPath() //输出直径的路径 { int now=ed; stack<int>stk; do //倒着入栈 { stk.push(step[now]); now=step[now]; }while(now!=p); while(stk.size()) //出栈,输出路径 { cout<<stk.top()<<"->"; stk.pop(); } cout<<ed<<endl; } int main() { freopen("diameter.in","r",stdin); freopen("diameter.out","w",stdout); int x,y,z; scanf("%d",&n); for(int i=1;i<=n-1;i++) { scanf("%d%d%d",&x,&y,&z); add(x,y,z); add(y,x,z); } dfs1(1,0); //第一次dfs,找到从结点1能到的最远结点p memset(d,0,sizeof(d)); dfs2(p,0); //第二次dfs,找到从结点p能到的最远结点ed,则p->ed就是直径 printf("%d\n",d[ed]); PrintPath(); //输出路径 return 0; }
【求树的直径-模板2-树型DP实现-参考代码】:
#include<bits/stdc++.h> using namespace std; const int N=50005,M=50005; int n,m,t,ans; int f1[N],f2[N]; int h[N],to[M*2],w[M*2],Next[M*2]; void add(int x,int y,int z) //邻接表存储 { Next[++t]=h[x]; h[x]=t; to[t]=y; w[t]=z; } void dp(int u,int father) //树型DP(从结点u出发,其父结点为father) { for(int i=h[u];i;i=Next[i]) //枚举从u出发的所有边i { int v=to[i]; //to是u的子结点 if(v==father) continue; //如果往回枚举到u的父结点,放弃 dp(v,u); //从to继续向下走 if(f1[u]<f1[v]+w[i]) //如果起点到u的最大距离能被to更新 { f2[u]=f1[u]; //把起点到u的次大距离更新为当前的最大距离 f1[u]=f1[v]+w[i]; //更新起点到u的最大距离 } else if(f2[u]<f1[v]+w[i]) //如果起点到u的最大距离不能被to更新 f2[u]=f1[v]+w[i]; //更新起点到u的次大距离 ans=max(ans,f1[u]+f2[u]); //更新直径 } } int main() { //freopen("diameter.in","r",stdin); //freopen("diameter.out","w",stdout); scanf("%d",&n); int x,y,z; for(int i=1;i<=n-1;i++) { scanf("%d%d%d",&x,&y,&z); add(x,y,z); add(y,x,z); } dp(1,0); //树型DP入口(从结点1出发,其父结点为0) printf("%d\n",ans); return 0; }
自己的代码--Code:
#include<bits/stdc++.h> using namespace std; const int N=1e5+5; int n,m,tot,root,num; int fi[N],ne[N],to[N],w[N]; void add(int x,int y, int z) { ne[++tot]=fi[x]; fi[x]=tot; to[tot]=y; w[tot]=z; } void dfs(int x,int y,int z) { for(int i=fi[x];i;i=ne[i]) { int v=to[i]; if(v==y) continue; if(num<z+w[i]) { root=v; num=z+w[i]; } dfs(v,x,z+w[i]); } } int main() { cin>>n; for(int i=1;i<n;i++) { int a,b,c; cin>>a>>b>>c; add(a,b,c),add(b,a,c); } dfs(1,0,0); num=0; dfs(root,0,0); cout<<num<<'\n'; }

浙公网安备 33010602011771号