CCPC-Wannafly Winter Camp Day1 流流流动 (树形dp)

题目描述

 

喜欢数学的wlswls最近被萎住了。

现在他一共有1...n1...n这么多数字,取数字ii会得到f[i]f[i]的收益。数字之间有些边,对于所有的i(i != 1)i(i!=1),若ii为奇数,则ii与3i+13i+1之间有边,否则ii与i/2i/2之间有边。如果一条边的两个顶点xyxy都被取了,那么会失去d[min(x, y)]d[min(x,y)]的价值。请问wlswls怎么取,才能使得收益最大?

 

 
 

输入描述

 

第一行一个整数nn。

接下来一行nn个整数表示ff。

接下来一行nn个整数表示dd。

1 \leq n \leq 1001n100

1 \leq f[i], d[i] \leq 10001f[i],d[i]1000

 

输出描述

 

输出一个整数表示答案。

 

样例输入 1 

2
10 10 
1 2

样例输出 1

19


思路:
根据题目给的建边条件,建边后会形成一个森林,然后把森林转化为一个0为根节点的树,随后进行树形dp。
定义状态:
dp[u][0/1] 0为 第u个节点的子树中不取第u个节点的最多利益值,
       1为第u个节点的子树中取第u个节点的最多利益值,
常规的树形dp套路,
dp[u][0]+=max(dp[v][0],dp[v][1]);
dp[u][1]+=max(dp[v][0],dp[v][1]-d[min(u,v)]);// 一个边上两个节点都取的话,要减去对应的值。
最后max(dp[0][0],dp[0][1])就是我们的答案值。
细节见代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#define ALL(x) (x).begin(), (x).end()
#define rt return
#define dll(x) scanf("%I64d",&x)
#define xll(x) printf("%I64d\n",x)
#define sz(a) int(a.size())
#define all(a) a.begin(), a.end()
#define rep(i,x,n) for(int i=x;i<n;i++)
#define repd(i,x,n) for(int i=x;i<=n;i++)
#define pii pair<int,int>
#define pll pair<long long ,long long>
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define MS0(X) memset((X), 0, sizeof((X)))
#define MSC0(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define eps 1e-6
#define gg(x) getInt(&x)
#define db(x) cout<<"== [ "<<x<<" ] =="<<endl;
using namespace std;
typedef long long ll;
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
ll powmod(ll a, ll b, ll MOD) {ll ans = 1; while (b) {if (b % 2)ans = ans * a % MOD; a = a * a % MOD; b /= 2;} return ans;}
inline void getInt(int* p);
const int maxn = 1010;
const int inf = 0x3f3f3f3f;
/*** TEMPLATE CODE * * STARTS HERE ***/

int pre[maxn];
int f[maxn];
int d[maxn];
int dp[maxn][2];
int n;
int findpar(int x)
{
    return pre[x] == 0 ? x : pre[x] = findpar(pre[x]);
}
void mer(int x, int y)
{
    x = findpar(x);
    y = findpar(y);
    if (x != y)
    {
        pre[x] = y;
    }
}
std::vector<int> v[maxn];
int w;
void dfs(int u, int pre)
{
    // cout<<u<<" "<<pre<<endl;
    dp[u][0] = 0;
    dp[u][1] = f[u];
    for (auto x : v[u])
    {
        if (x != pre)
        {
            dfs(x, u);
            dp[u][0] += max(dp[x][0], dp[x][1]);
            dp[u][1] += max(dp[x][0], dp[x][1] - d[min(u, x)]);
        }
    }
}
int main()
{
    //freopen("D:\\common_text\\code_stream\\in.txt","r",stdin);
    //freopen("D:\\common_text\\code_stream\\out.txt","w",stdout);
    gbtb;
    cin >> n;
    repd(i, 1, n)
    {
        cin >> f[i];
    }
    repd(i, 1, n)
    {
        cin >> d[i];
    }
    repd(i, 2, n)
    {
        if (i & 1)
        {
            if (3 * i + 1 <= n)
            {
                v[i].push_back(3 * i + 1);
                v[3 * i + 1].push_back(i);
                mer(i, 3 * i + 1);
            }
        } else
        {
            v[i].push_back(i / 2);
            v[i / 2].push_back(i);
            mer(i, i / 2);
        }
    }
    repd(i, 1, n)
    {
        if (pre[i] == 0)
        {
            v[0].push_back(i);
            v[i].push_back(0);
        }
    }
    dfs(0, 0);
    cout << max(dp[0][0], dp[0][1]) << endl;

    return 0;
}

inline void getInt(int* p) {
    char ch;
    do {
        ch = getchar();
    } while (ch == ' ' || ch == '\n');
    if (ch == '-') {
        *p = -(getchar() - '0');
        while ((ch = getchar()) >= '0' && ch <= '9') {
            *p = *p * 10 - ch + '0';
        }
    }
    else {
        *p = ch - '0';
        while ((ch = getchar()) >= '0' && ch <= '9') {
            *p = *p * 10 + ch - '0';
        }
    }
}

 





posted @ 2019-05-22 17:06  茄子Min  阅读(380)  评论(0编辑  收藏  举报