hdu4679(树形dp)

 

题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4679

题意:给一棵树,每条边上都有一个权值,去掉树上任意一条边之后,分成两个子树,两个子树的最长路与这条边上的权值相乘的到一个乘积。问去掉那一条边可以使这个乘积最小。

分析:求出树的直径,然后判断边是否树的直径上,如果是的话,ans=w*mx_len(mx_len为树的直径),否则从直径的两个端点出发dfs处理以每个节点为根节点时的直径及根节点到子树中的最长路和次长路。

#pragma comment(linker,"/STACK:102400000,102400000")
#include <cstdio>
#include <cstring>
#include <cmath>
#include <iostream>
#include <algorithm>
#include <queue>
#include <cstdlib>
#include <stack>
#include <vector>
#include <set>
#include <map>
#define LL long long
#define mod 1000000007
#define inf 0x3f3f3f3f
#define N 100010
#define FILL(a,b) (memset(a,b,sizeof(a)))
using namespace std;

struct edge
{
    int v,w,id,next;
    edge(){}
    edge(int v,int w,int id,int next):v(v),w(w),id(id),next(next){}
}e[2*N];
int head[N*2],tot,n,mx_len;
int dp[N][3];//dp[u][0]表示该节点到子树中的最长路,dp[u][1]表示次长路,dp[u][2]表示这颗树的直径
int path[N];//保存树的直径上的路径
int mark[N];//直径上边的标志
int ans[N],st,ed;//st,ed为树的直径的两个端点
void addedge(int u,int v,int w,int id)
{
    e[tot]=edge(v,w,id,head[u]);
    head[u]=tot++;
}
void dfs_len(int u,int fa,int len)
{
    if(len>=mx_len)mx_len=len,ed=u;
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].v;
        if(v==fa)continue;
        path[v]=u;
        dfs_len(v,u,len+1);
    }
}
void dfs(int u,int fa)
{
    dp[u][0]=dp[u][1]=dp[u][2]=0;
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].v;
        if(v==fa)continue;
        dfs(v,u);
        int tmp=dp[v][0]+1;
        if(tmp>dp[u][0])
        {
            dp[u][1]=dp[u][0];
            dp[u][0]=tmp;
        }
        else if(tmp>dp[u][1])
        {
            dp[u][1]=tmp;
        }
        dp[u][2]=max(dp[u][2],dp[v][2]);
    }
    dp[u][2]=max(dp[u][2],dp[u][0]+dp[u][1]);
}
void solve(int u,int fa)
{
    for(int i=head[u];~i;i=e[i].next)
    {
        int v=e[i].v,w=e[i].w,id=e[i].id;
        if(v==fa)continue;
        solve(v,u);
        if(mark[u]&&mark[v])ans[id]=max(ans[id],w*dp[v][2]);
        else ans[id]=w*mx_len;
    }
}
int main()
{
    int u,v,w,T,cas=1;
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d",&n);
        FILL(head,-1);FILL(mark,0);
        FILL(ans,0);tot=0;
        for(int i=1;i<n;i++)
        {
            scanf("%d%d%d",&u,&v,&w);
            addedge(u,v,w,i);
            addedge(v,u,w,i);
        }
        mx_len=0;dfs_len(1,-1,0);
        st=ed;dfs_len(st,-1,0);
        mark[st]=1;path[st]=-1;
        int tmp=ed;
        while(path[tmp]!=-1)
        {
            mark[tmp]=1;
            tmp=path[tmp];
        }
        dfs(st,-1);solve(st,-1);
        dfs(ed,-1);solve(ed,-1);
        int ret=1;
        for(int i=1;i<n;i++)
            if(ans[i]<ans[ret])ret=i;
        printf("Case #%d: %d\n",cas++,ret);
    }
}
View Code


后来发现直接预处理好每个点到达树的直径的两个端点的距离,就可以判断边是否在直径上了,且如果边在直径上,那么两个子树中的最大值那条路,一定是以整棵树的最长路的两个端点为起始点的。好像使用vector会超时。。。

#pragma comment(linker,"/STACK:102400000,102400000")
#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
using namespace std;
#define N 100010
#define inf 0x7fffffff

int n, s, t, len, id, ans;
int ds[N], dt[N];//记录每个点到树的直径的两个端点的距离
struct node {
    int v, w, id,next;
    node() {}
    node(int _v, int _w, int _id,int _next) : v(_v), w(_w), id(_id) ,next(_next){}
}e[N<<1];
int head[N<<1],tot;
void addedge(int u,int v,int w,int id)
{
    e[tot]=node(v,w,id,head[u]);
    head[u]=tot++;
}
void dfs(int now, int fa) {
    int u;
    for (int i=head[now];~i;i=e[i].next)
        if ((u = e[i].v) != fa) {
            ds[u] = ds[now] + 1;
            dfs(u, now);
        }
}
void dfs2(int now, int fa) {
    int u;
    for (int i=head[now];~i;i=e[i].next)
        if ((u = e[i].v) != fa) {
            dt[u] = dt[now] + 1;
            dfs2(u, now);
        }
}
void work(int now, int fa) {
    int u, w;
    for(int i=head[now];~i;i=e[i].next)
        if ((u = e[i].v) != fa) {
            if (ds[now] + 1 + dt[u] == len)
                w = e[i].w * max(ds[now], dt[u]);
            else w = e[i].w * len;
            if (w < ans) { ans = w, id = e[i].id; }
            else if (w == ans && e[i].id < id)
                id = e[i].id;
            work(u, now);
        }
}
int main() {

    int T,a,b,c;
    scanf("%d", &T);
    for (int cas=1; cas<=T; cas++) {
        scanf("%d", &n);
        memset(head,-1,sizeof(head));tot=0;
        for (int i=1; i<n; i++) {
            scanf("%d%d%d", &a, &b, &c);
            addedge(a,b,c,i);
            addedge(b,a,c,i);
        }
        ds[1] = 0;
        dfs(1, 0);
        ds[s=0] = 0; for (int i=1; i<=n; i++) if (ds[i] > ds[s]) s = i;

        ds[s] = 0; dfs(s, 0);
        t = 0; for (int i=1; i<=n; i++) if (ds[i] > ds[t]) t = i;
        len = ds[t];

        dt[t] = 0;
        dfs2(t, 0);

        id = ans = inf;
        work(s, 0);
        printf("Case #%d: %d\n", cas, id);
    }
    return 0;
}
View Code

 

posted on 2015-01-07 19:37  lienus  阅读(208)  评论(0编辑  收藏  举报

导航