CodeForces 592D Super M DP

Super M

题解:

定义 dp[u][0] 为遍历完u中的所有节点, 但不回到u点的路径花费值。

定义 dp[u][1] 为遍历完u中的所有节点, 且要回到u点的路径花费值。

 

转移方程.

dp[u][1] = sum(dp[v][1] + 2).

dp[u][0] = max(dp[v][1] + 2 - dp[v][0] - 1).

 

需要注意的是,不要把不需要走的路径值传递上来。

只有这个路径会遍历一个需要清除的点的时候,才可以转移状态。

 

这样从1dfs完之后,我们就可以计算出上面定义的状态的值。

 

然后我们反着dfs一遍,就可以算出从每个点出发清除完所有点的值。

 

代码:

#include<bits/stdc++.h>
using namespace std;
#define Fopen freopen("_in.txt","r",stdin); freopen("_out.txt","w",stdout);
#define LL long long
#define ULL unsigned LL
#define fi first
#define se second
#define pb push_back
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define lch(x) tr[x].son[0]
#define rch(x) tr[x].son[1]
#define max3(a,b,c) max(a,max(b,c))
#define min3(a,b,c) min(a,min(b,c))
typedef pair<int,int> pll;
const int inf = 0x3f3f3f3f;
const int _inf = 0xc0c0c0c0;
const LL INF = 0x3f3f3f3f3f3f3f3f;
const LL _INF = 0xc0c0c0c0c0c0c0c0;
const LL mod =  (int)1e9+7;
const int N = 2e5 + 100;
vector<int> vc[N];
LL dp[N][2];
/// 1->back   0-no-back
LL dif[N][2];
/// dif[0] > dif[1]
int vis[N];
int dfs(int o, int u){
    for(int v : vc[u]){
        if(o == v) continue;
        int f = dfs(u, v);
        if(!f) continue;
        LL t1 = dp[v][1] + 2;
        LL t2 = dp[v][0] + 1;
        t2 = t1 - t2;
        dp[u][1] += t1;
        if(dif[u][0] < t2) swap(t2, dif[u][0]);
        if(dif[u][1] < t2) swap(t2, dif[u][1]);
    }
    dp[u][0] = dp[u][1] - dif[u][0];
    if(dp[u][1] == 0 && !vis[u]) return 0;
    return 1;
}
LL ans = INF, ansid;
int fdfs(int o, int u){
    if(o){
        LL t1 = dp[o][1];
        if(dp[u][1] || vis[u]) t1 -= (dp[u][1] + 2);
        if(t1 == 0 && !vis[o]) ;
        else {
            LL t2;
            LL now_dif = 0;
            if(dp[u][1] || vis[u]) now_dif= (dp[u][1]+2) - (dp[u][0]+1);
            if(now_dif == dif[o][0]) t2 = t1 - dif[o][1];
            else t2 = t1 - dif[o][0];
            t2 += 1; t1 += 2;
            dp[u][1] += t1;
            now_dif = t1 - t2;
            if(dif[u][0] < now_dif) swap(dif[u][0], now_dif);
            if(dif[u][1] < now_dif) swap(dif[u][1], now_dif);
            dp[u][0] = dp[u][1] - dif[u][0];
        }
    }
    if(dp[u][0] < ans){
        ans = dp[u][0];
        ansid = u;
    }
    else if(dp[u][0] == ans && ansid > u) ansid = u;
    for(int v : vc[u]){
        if(o == v) continue;
        fdfs(u, v);
    }
    return 0;
}
int main(){
    int n, m;
    scanf("%d%d", &n, &m);
    int u, v;
    for(int i = 1; i < n; ++i){
        scanf("%d%d", &u, &v);
        vc[u].pb(v); vc[v].pb(u);
    }
    for(int i = 1; i <= m; ++i){
        scanf("%d", &u);
        vis[u] = 1;
    }
    dfs(0,1);
    fdfs(0,1);
    cout << ansid << "\n"<< ans << endl;
    return 0;
}
View Code

 

posted @ 2019-05-11 15:26  Schenker  阅读(200)  评论(0编辑  收藏  举报