树的直径

树的直径
树的直径是一道非常经典的关于树的直径的例题,这道题需要求直径的长度、直径的条数和直径的必经边
直径的长度
直径的长度可以使用一次DP来求解,也可以使用两次BFS来求解
下面给出BFS的求解方法
int bfs(int root){
int node=root;
queue<int>q;
memset(vis,0,sizeof(vis));
memset(dis,0,sizeof(dis));
vis[root]=1;
q.push(root);
while(q.size()){
int x=q.front();q.pop();
vis[x]=1;
for(int i=head[x];~i;i=e[i].next){
int y=e[i].to;
if(vis[y])continue;
dis[y]=dis[x]+e[i].w;
vis[y]=1;
pre[y]=x;
if(dis[y]>dis[node])node=y;
q.push(y);
}
}
return node;
}
直径的条数
直径的条数需要使用dp来求解,它使用距离来进行求解,dp执行完后,既可以得到直径的长度,也可以得到直径的条数
下面给出代码
void dp(int x,int fa){
ll dist=0;
d[x]=0;node[x]=1;
for(int i=head[x];~i;i=e[i].next){
int y=e[i].to;
if(y==fa)continue;
dp(y,x);
dist=d[y]+e[i].w;
if(dist+d[x]>fx){
num=node[x]*node[y];
fx=dist+d[x];
}else if(dist+d[x]==fx)num+=node[x]*node[y];
if(d[x]<dist){
d[x]=dist;
node[x]=node[y];
}else if(d[x]==dist)node[x]+=node[y];
}
}
直径的必经边
直径的必经边需要使用一个pre数组和一个g数组,来分别存储前缀节点与后缀节点,先遍历一遍pre数组,再遍历一遍g数组,统计必经边的条数,在遍历g数组时,有终止条件,这个终止条件是:当dis[last]-dis[x]==d[x]时,终止
下面给出代码
void dfs(int x){
vis[x]=1;
for(int i=head[x];~i;i=e[i].next){
int y=e[i].to;
if(vis[y])continue;
dfs(y);
d[x]=max(d[y]+e[i].w,d[x]);
}
}
for(x=last;x!=-1;x=pre[x])vis[x]=1;
for(x=last;x!=-1;x=pre[x])dfs(x);
L=x;
for(x=last;x!=-1;x=pre[x]){
g[pre[x]]=x;
if(dis[x]==d[x]){
L=x;
break;
}
}
for(x=L;x!=last;x=g[x]){
if(d[x]==dis[last]-dis[x])break;
ans++;
}
最后给出总的代码
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N = 2e5+39+7,M = 4*N;
struct edge{
ll next,to,w;
}e[M];
ll mp[N],n,head[N],cnt=-1,num=0,fx=0,Max[N],dis[N],vis[N],node[N],pre[N],g[N],d[N],ans=0;
void add(int u,int v,int w){
e[++cnt]=(edge){head[u],v,w};
head[u]=cnt;
}
int bfs(int root){
int node=root;
queue<int>q;
memset(vis,0,sizeof(vis));
memset(dis,0,sizeof(dis));
vis[root]=1;
q.push(root);
while(q.size()){
int x=q.front();q.pop();
vis[x]=1;
for(int i=head[x];~i;i=e[i].next){
int y=e[i].to;
if(vis[y])continue;
dis[y]=dis[x]+e[i].w;
vis[y]=1;
pre[y]=x;
if(dis[y]>dis[node])node=y;
q.push(y);
}
}
return node;
}
void dfs(int x){
vis[x]=1;
for(int i=head[x];~i;i=e[i].next){
int y=e[i].to;
if(vis[y])continue;
dfs(y);
d[x]=max(d[y]+e[i].w,d[x]);
}
}
void dp(int x,int fa){
ll dist=0;
d[x]=0;node[x]=1;
for(int i=head[x];~i;i=e[i].next){
int y=e[i].to;
if(y==fa)continue;
dp(y,x);
dist=d[y]+e[i].w;
if(dist+d[x]>fx){
num=node[x]*node[y];
fx=dist+d[x];
}else if(dist+d[x]==fx)num+=node[x]*node[y];
if(d[x]<dist){
d[x]=dist;
node[x]=node[y];
}else if(d[x]==dist)node[x]+=node[y];
}
}
int main(){
memset(head,-1,sizeof(head));cnt=-1;
memset(node,1,sizeof(node));
memset(pre,-1,sizeof(pre));
memset(g,-1,sizeof(g));
cin>>n;
for(int i=1;i<n;i++){
ll a,b,c;cin>>a>>b>>c;
add(a,b,c);
add(b,a,c);
}
ll start=bfs(1);
ll last=bfs(start);
int x,L,ans=0;
pre[start]=-1;
dp(1,0);
memset(vis,0,sizeof(vis));
memset(d,0,sizeof(d));
for(x=last;x!=-1;x=pre[x])vis[x]=1;
for(x=last;x!=-1;x=pre[x])dfs(x);
L=x;
for(x=last;x!=-1;x=pre[x]){
g[pre[x]]=x;
if(dis[x]==d[x]){
L=x;
break;
}
}
for(x=L;x!=last;x=g[x]){
if(d[x]==dis[last]-dis[x])break;
ans++;
}
cout<<fx<<'\n'<<ans<<'\n'<<num;
return 0;
}

浙公网安备 33010602011771号