bzoj 1912 tree_dp

  这道题我们加一条路可以减少的代价为这条路两端点到lca的路径的长度,相当于一条链,那么如果加了两条链的话,这两条链重复的部分还是要走两遍,反而对答案没有了贡献(其实这个可以由任意两条链都可以看成两条不重叠的链来证明),那么这道题k=2的时候就转化为了求出树上两条链,使得两条链不重叠的长度最大,那么答案就是(n-1)<<1-SumLen+2.当k=1的时候我们直接求出来树的最长链然后减去就好了,这个在此不再赘述。

  对于树上两链不重复部分最大我们是可以tree_dp的,设w[i][0..4]来表示当前以i为根的子树中选取了0/1/2条链的最大值,同时我们保留了一个3,4来记录以i为一端点的最长链,同时选取了0/1条最长链的最大值,这样直接转移就好了。

  我写的是另外一种方法,先找出最长链,然后将最长链上的边长设为-1,然后再找一次最长链,这样求出来的就是答案。

  反思:开始没意识到第二次最长链不能用两边bfs,所以果断的写了bfs,后来才发现的,又临时加了一个tree_dp,因为加的路必须选,所以我们要将每个点的最长和次长链设为-inf,叶子节点的为0,然后用非叶子节点更新答案,然后竟然1A,真是感动= =。

/**************************************************************
    Problem: 1912
    User: BLADEVIL
    Language: C++
    Result: Accepted
    Time:1268 ms
    Memory:5884 kb
****************************************************************/
 
//By BLADEVIL
#include <cstdio>
#include <cstring>
#include <algorithm>
#define maxn 100010
#define maxm 200020
#define inf (~0U>>1)
 
using namespace std;
 
int n,k,l;
int pre[maxm],other[maxm],last[maxn],len[maxm];
int que[maxn],dis[maxn],father[maxn],flag[maxn],max_1[maxn],max_2[maxn];
 
void connect(int x,int y) {
    pre[++l]=last[x];
    last[x]=l;
    other[l]=y;
    len[l]=1;
}
 
void bfs(int x) {
    memset(que,0,sizeof que);
    memset(dis,0,sizeof dis);
    memset(father,0,sizeof father);
    memset(flag,0,sizeof flag);
    int h=0,t=1;
    que[1]=x; dis[x]=1; flag[x]=1;
    while (h<t) {
        int cur=que[++h];
        for (int p=last[cur];p;p=pre[p]) {
            if (flag[other[p]]) continue;
            father[other[p]]=p;
            dis[other[p]]=dis[cur]+len[p];
            flag[other[p]]=1;
            que[++t]=other[p];
        }
    }
}
 
int tree_dp() {
    int ans=-inf;
    memset(que,0,sizeof que);
    memset(flag,0,sizeof flag);
    memset(dis,0,sizeof dis);
    memset(max_1,-128,sizeof max_1);
    memset(max_2,-128,sizeof max_2);
    int h=0,t=1;
    que[1]=1; flag[1]=1; dis[1]=1;
    while (h<t) {
        int cur=que[++h];
        for (int p=last[cur];p;p=pre[p]) {
            if (flag[other[p]]) continue;
            que[++t]=other[p]; flag[other[p]]=1; dis[other[p]]=dis[cur]+1;
        }
    }
    //for (int i=1;i<=n;i++) printf("%d ",que[i]); printf("\n");
    for (int i=n;i;i--) {
        int cur=que[i];
        for (int p=last[cur];p;p=pre[p]) {
            if (dis[other[p]]<dis[cur]) continue;
            if (max_1[other[p]]+len[p]>max_1[cur])
                max_2[cur]=max_1[cur],max_1[cur]=max_1[other[p]]+len[p]; else
            if (max_1[other[p]]+len[p]>max_2[cur])
                max_2[cur]=max_1[other[p]]+len[p];
        }
        if (max_1[cur]<-100000000) max_1[cur]=max_2[cur]=0; else ans=max(ans,max(max_1[cur]+max_2[cur],max_1[cur]));
    }
    //for (int i=1;i<=n;i++) printf("|%d %d\n",max_1[i],max_2[i]);
    return ans;
}
 
int getmax() {
    int s=0;
    for (int i=1;i<=n;i++) if (dis[i]>dis[s]) s=i;
    return s;
}
 
int main() {
    scanf("%d%d",&n,&k); l=1;
    for (int i=1;i<n;i++) {
        int x,y; scanf("%d%d",&x,&y);
        connect(x,y); connect(y,x);
    }
    bfs(1); bfs(getmax());
    if (k==1) {
        printf("%d\n",2*n-dis[getmax()]);
        return 0;
    }
    int cur=getmax(),ans=dis[cur]-2;
    while (father[cur]) len[father[cur]]=len[father[cur]^1]=-1,cur=other[father[cur]^1];
    ans+=tree_dp()-1;
    printf("%d\n",2*n-2-ans);
    return 0;
}

 

posted on 2014-04-08 18:27  BLADEVIL  阅读(242)  评论(0编辑  收藏  举报