动态 DP 学习笔记

动态 DP

P4719 动态 DP

给定一棵 \(n\) (\(n \leqslant 10^5\)) 个点的树,点带点权。

\(m\) (\(m \leqslant 10^5\)) 次操作,每次操作给定 \(x,y\),表示修改点 \(x\) 的权值为 \(y\)

你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

首先考虑 \(m=0\) 时的做法,可以简单地设计出 dp 转移方程式 \(f_{x, 0 / 1}\) 表示 \(x\) 选或不选的答案。

\[\begin{cases} f_{x, 0} = \sum\limits_{v \in x.\text{son}} \max\{f_{v, 1}, f_{v, 0}\} \\ f_{x, 1} = p_x + \sum\limits_{v \in x.\text{son}} f_{v, 0} \end{cases} \]

考虑用矩阵乘法维护这个东西,这个东西显然不能写成矩阵乘法的形式,故考虑广义矩阵乘法。

广义矩阵乘法

比如此题中的矩阵惩罚可以长这个样子:

\[C_{i, j} = \max\limits_{k = 0}^{size - 1} \left(A_{i, k} + B_{k, j}\right) \]

可以发现,就是将普通矩阵乘法中的加换成了 \(\max\),乘法换成了加法。

对应到上面的递推式,可以添加辅助数组 \(g\)\(g_{x, 0 / 1}\) 表示不考虑 \(x\) 的实儿子的情况下,选或不选 \(x\) 的答案:

\[\begin{cases} g_{x, 0} = \sum\limits_{v \in x.\text{son} \land v \ne son_x} \max\{f_{v, 0}, f_{v, 1}\}\\ g_{x, 1} = p_x + \sum\limits_{v \in x.\text{son} \land v \ne son_x} f_{v, 0} \\ f_{x, 0} = g_{x, 0} + \max\{f_{son_x, 0}, f_{son_x, 1}\}\\ f_{x, 1} = g_{x, 1} + f_{son_x, 0} \end{cases}\\ \begin{bmatrix} g_{x, 0} & g_{x, 0}\\ g_{x, 1} & -\infty \end{bmatrix} \times \begin{bmatrix} f_{son_x, 0} \\ f_{son_x, 1} \end{bmatrix} = \begin{bmatrix} f_{x, 0} \\ f_{x, 1} \end{bmatrix} \]

然后下面的这坨东西是可以用线段树维护的,接着就可以树剖了。

代码
#include <bits/stdc++.h>
using namespace std;
using i64 = long long;
#define int i64
const int N = 1E5 + 5;
const int inf = 0x3f3f3f3f;
int n, m, p[N];
vector <int> G[N];
int sz[N], son[N], top[N], dfn[N], tot, fa[N], dep[N], wt[N];
int f[N][2], g[N][2], epos[N];
void dfs1(int x) {
  sz[x] = 1; pair <int, int> mx = make_pair(0, 0);
  f[x][1] = p[x]; f[x][0] = 0;
  for (auto v : G[x]) {
    if (v == fa[x]) continue;
    fa[v] = x; dep[v] = dep[x] + 1;
    dfs1(v); sz[x] += sz[v];
    mx = max(mx, make_pair(sz[v], v));
    f[x][0] += max(f[v][0], f[v][1]);
    f[x][1] += f[v][0];
  } son[x] = mx.second;
}
void dfs2(int x, int topf) {
  top[x] = topf; wt[dfn[x] = ++tot] = x;
  g[x][1] = p[x];
  if (!son[x]) return void(epos[topf] = tot);
  dfs2(son[x], topf);
  for (auto v : G[x]) {
    if (v == son[x] || v == fa[x]) continue;
    dfs2(v, v);
    g[x][0] += max(f[v][0], f[v][1]);
    g[x][1] += f[v][0];
  }
}
struct mat {
  int a[2][2];
  mat () {
    a[0][0] = a[1][0] = 
    a[0][1] = a[1][1] = 0;
  }
  void set(int x, int y, int val = 1) {a[x][y] = val;}
} ;
mat bp(int x) {
  mat ans; ans.set(1, 1, -inf);
  ans.set(0, 0, g[x][0]); ans.set(0, 1, g[x][0]);
  ans.set(1, 0, g[x][1]); return ans;
}
mat mul(mat x, mat y) {
  mat ans; 
  for (int i = 0; i < 2; ++i) for (int j = 0; j < 2; ++j)
    for (int k = 0; k < 2; ++k)
      ans.a[i][j] = max(ans.a[i][j], x.a[i][k] + y.a[k][j]);
  return ans;
}
struct segt {
  struct node {
    int l, r;
    mat p;
  } t[N << 2];
  int lson(int x) {return x << 1;}
  int rson(int x) {return x << 1 | 1;}
  void pushup(int x) {t[x].p = mul(t[lson(x)].p, t[rson(x)].p);}
  void build(int x, int l, int r) {
    t[x].l = l; t[x].r = r;
    if (l == r) {
      t[x].p = bp(wt[l]);
      return ;
    } int mid = (l + r) >> 1;
    build(lson(x), l, mid);
    build(rson(x), mid + 1, r);
    pushup(x);
  }
  void update(int x, int pos) {
    if (t[x].l == t[x].r) {
      t[x].p = bp(wt[pos]);
      return ;
    } int mid = (t[x].l + t[x].r) >> 1;
    if (pos <= mid) update(lson(x), pos);
    else update(rson(x), pos);
    pushup(x);
  }
  node query(int x, int L, int R) {
    if (t[x].l >= L && t[x].r <= R)
      return t[x];
    int mid = (t[x].l + t[x].r) >> 1;
    if (L > mid) return query(rson(x), L, R);
    if (R <= mid) return query(lson(x), L, R);
    node ans, ls = query(lson(x), L, R), rs = query(rson(x), L, R);
    ans.l = ls.l; ans.r = rs.r;
    ans.p = mul(ls.p, rs.p);
    return ans;
  }
} t;
void update(int x, int val) {
  g[x][1] = g[x][1] - p[x] + val;
  p[x] = val; 
  while (x) {
    mat pre = t.query(1, dfn[top[x]], epos[top[x]]).p;
    t.update(1, dfn[x]);
    mat now = t.query(1, dfn[top[x]], epos[top[x]]).p;
    x = fa[top[x]];
    g[x][0] += max(now.a[0][0], now.a[1][0]) - max(pre.a[0][0], pre.a[1][0]);
    g[x][1] += now.a[0][0] - pre.a[0][0];
  }
}
signed main(void) {
  ios :: sync_with_stdio(false);
  cin.tie(0); cout.tie(0);
  cin >> n >> m;
  for (int i = 1; i <= n; ++i) cin >> p[i];
  for (int i = 1; i < n; ++i) {
    int u, v; cin >> u >> v;
    G[u].emplace_back(v);
    G[v].emplace_back(u);
  } dfs1(1); dfs2(1, 1);
  t.build(1, 1, n);
  for (int i = 1; i <= m; ++i) {
    int x, v; cin >> x >> v;
    update(x, v);
    mat p = t.query(1, 1, epos[1]).p;
    cout << max(p.a[0][0], p.a[1][0]) << '\n';
  }
  return 0;
}
posted @ 2024-01-30 14:33  CTHOOH  阅读(31)  评论(0)    收藏  举报