【做题笔记】树形 dp

Solve

设计状态 \(dp[i]\) 表示子树 \(i\) 的最大点权和,则有:

  1. \(dp[son[i]] > 0\) 时,选以 \(son[i]\) 为根的子树肯定优;
  2. \(dp[son[i]] < 0\) 时,选以 \(son[i]\) 为根的子树肯定不优;

因此,转移方程为:

$dp[i] = \sum\limits_{a[son[i]]>0} dp[son[i]] + a[i]$

时间复杂度 \(O(n)\)

答案为 \(\max\limits_{1\le i \le n} dp[i]\)

Code

#include <bits/stdc++.h>
#define ll long long
#define H 19260817
#define rint register int
#define For(i,l,r) for(rint i=l;i<=r;++i)
#define FOR(i,r,l) for(rint i=r;i>=l;--i)
#define MOD 1000003
#define mod 1000000007

using namespace std;

inline int read() {
  rint x=0,f=1;char ch=getchar();
  while(ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
  while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
  return x*f;
}

void print(int x){
  if(x<0){putchar('-');x=-x;}
  if(x>9){print(x/10);putchar(x%10+'0');}
  else putchar(x+'0');
  return;
}

const int N = 16100;

int n, a[N], f[N], ans = -0x3f3f3f3f;

vector<int> e[N];

void dfs(int x, int fa) {
  f[x] = a[x];
  for (int i = 0; i < e[x].size(); i++) { 
    int y = e[x][i];
    if(y == fa) continue;
    dfs(y, x);
    if(f[y] > 0) f[x] += f[y];
  } 
}

signed main() {
  n = read();
  For(i,1,n) a[i] = read();
  For(i,1,n-1) {
    int u = read(), v = read();
    e[u].push_back(v);
    e[v].push_back(u);
  }
  dfs(1, 0);
  For(i,1,n) ans = max(ans, f[i]);
  cout << ans << '\n';
  return 0;
}

Solve

设计状态 \(dp[i][0/1]\) 表示在 \(i\) 子树内, 放/不放\(i\) 个节点使其合法所需的最少的士兵数目。则有:

  1. 不选 \(i\) 节点,则 \(i\) 的儿子必须选;
  2. \(i\) 节点,则 \(i\) 的儿子可选可不选;

因此,转移方程为:

$dp[i][0] = \sum dp[son[i]][1]$

\(dp[i][1] = \sum \min(dp[son[i]][0], dp[son[i]][1])\)

时间复杂度 \(O(n)\)

答案为 \(min(dp[0][0], dp[0][1])\)。(以 \(0\) 为根)

Code

#include <bits/stdc++.h>
#define int long long
#define H 19260817
#define rint register int
#define For(i,l,r) for(rint i=l;i<=r;++i)
#define FOR(i,r,l) for(rint i=r;i>=l;--i)
#define MOD 1000003
#define mod 1000000007

using namespace std;

inline int read() {
  rint x=0,f=1;char ch=getchar();
  while(ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
  while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
  return x*f;
}

void print(int x){
  if(x<0){putchar('-');x=-x;}
  if(x>9){print(x/10);putchar(x%10+'0');}
  else putchar(x+'0');
  return;
}

const int N = 1e5;

vector<int> e[N];

int n, f[N][2]; 

void dfs(int x, int fa) {
  f[x][0] = f[x][1] = 0;
  for (int i = 0; i < e[x].size(); i++) {
    int y = e[x][i];
    if(y == fa) continue;
    dfs(y, x);
    f[x][0] += f[y][1];
    f[x][1] += min(f[y][0], f[y][1]);
  }
  f[x][1]++; 
}

signed main() {
  n = read();
  For(i,1,n) {
    int x = read(), k = read();
    For(j,1,k) {
      int y = read();
      e[x].push_back(y);
      e[y].push_back(x);
    }
  }
  dfs(0, -1);
  cout << min(f[0][0], f[0][1]) << '\n';
  return 0; 
}
/*
8
0 2 1 2
1 2 3 4
2 0
3 0
4 1 5
5 2 6 7
6 0
7 0
*/
posted @ 2023-03-31 17:32  Daniel_yzy  阅读(30)  评论(0编辑  收藏  举报
Title