动态dp(ddp)小记

posted on 2023-11-13 12:01:45 | under | source

简介

使用矩阵乘、重链剖分等,解决带修改的 \(\rm dp\) 问题。

引入

考虑一类可用 \(\rm dp\) 解的问题,但是需多次查询、且每次查询都重新 \(\rm dp\) 会超时。

此时可用矩阵的形式描述转移矩阵,那么每次查询的结果便为 初始矩阵 \(*\) 转移矩阵之积。

考虑构造矩阵。定义广义矩阵乘:

\[C_{i,j} = \max/\min A_{i,k}+B_{k,j} \]

这样是满足结合率的,具体见 这篇博客。总之,只需维护 转移矩阵之积 即可。

正文

若上述问题加上修改操作,怎么办?依然可做,用线段树维护矩阵乘积即可。

若上述问题放在树上,怎么办?结合重链剖分的思路,对于单点修改,只会影响其到根节点路径上的点的 \(\rm dp\) 值,又因为该路径的重链数量仅为 \(\log n\) 级别,于是时间复杂度得到保证。

具体的,将 \(\rm dp\) 分轻、重儿子两部分讨论,每次修改只需对上一条重链的轻儿子部分操作即可(因为跳了一条链,所以不可能是其重儿子,也就不可能对重儿子部分修改)。链顶的 \(\rm dp\) 值矩阵即为链上矩阵乘积。

代码设计流程如下:

  1. 写出朴素 \(\rm dp\),将分为轻重儿子两部分。
  2. 用矩阵描述转移方程。
  3. 将每条重链的转移矩阵放在线段树上。
  4. 修改操作:每次修改往上一条重链的 轻儿子矩阵 部分。
  5. 查询操作:即为链上矩阵乘积。

P4719 举例,先将朴素 \(\rm dp\) 分为轻重儿子两部分。定义 \(g_{u,0}\) 表示所有轻儿子可取或不取、\(u\) 不取时最大值,\(g_{u,1}\) 表示轻儿子都不取、\(u\) 取时最大值。这里的定义是为之后方便设计为主。

然后转移方程大家应该都会:

\[f_{u,0} = g_{u,0}+f_{wson,0/1} \]

\[f_{u,1} = g_{u,1}+f_{wson,0} \]

在转为矩阵前先想想怎么构造比较好。由于朴素 \(\rm dp\) 是从链底到链顶转移,为了线段树实现方便,所以将转移矩阵放前面,\(f\) 放后面,这样下来转移矩阵从左往右为链顶到链尾。

转化矩阵:

\[\left( \begin{matrix} g_{u,0} & g_{u,0} \\ g_{u,1} & -inf \end{matrix} \right) * \left( \begin{matrix} f_{wson,0} \\ f_{wson,1} \end{matrix} \right) = \left( \begin{matrix} f_{u,0} \\ f_{u,1} \end{matrix} \right) \]

可以发现,对于叶子节点:

\[f_{leaf,0/1}=g_{u,0/1} \]

所以对于本题而言,初始矩阵就是叶子的转移矩阵,不需再使用 \(1*2\) 的矩阵。

然后对于修改操作,假设前一条重链是 \(y\),现在要更新 \(x\)。只需减去 \(y\)\(x\) 的原贡献,修改完 \(x\) 后再加上其新贡献即可。再重复一遍:\(f_{x_{top}}\)\(x\) 上的转移矩阵乘积。

结束前讲下卡常技巧吧:

  1. 暴力展开矩阵乘。
  2. 对于每条重链都开一个线段树,那么便省去了 query 函数,可以 \(O(1)\) 查询每条重链的 \(f_{top}\)

代码

还是有点抽象,看看代码应该就懂了,主要看 \(dfs\)\(change\) 部分。

此处为卡了常的代码,可过加强版。

#include<bits/stdc++.h>
using namespace std;

const int N = 1e6 + 5, inf = -0x7f7f7f7f;
int n, m, a[N], head[N], tot, u, v, x, y;
int id[N], raw[N], top[N], siz[N], wson[N], fa[N], ed[N], df;
int f[N][2];
struct edge{int v, nxt;}e[N << 1];

inline void read(int &a){
   int s = 0, w = 1; char ch = getchar();
   while(ch < '0' || ch > '9') {if(ch == '-') w = -1; ch = getchar();}
   while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
   a = s * w;
}
inline void write(int a){
	if(a < 0) putchar('-'), a = -a;
    if(a >= 10) write(a / 10);
    putchar(a % 10 + '0');
}
inline void add(int u, int v) {e[++tot] = {v, head[u]}, head[u] = tot;}
struct Matrix{
	int mt[2][2];
	Matrix() {memset(mt, -0x3f, sizeof mt);}
	inline Matrix operator * (Matrix B){
		Matrix C;
		C.mt[0][0] = max(mt[0][0] + B.mt[0][0], mt[0][1] + B.mt[1][0]);
		C.mt[0][1] = max(mt[0][0] + B.mt[0][1], mt[0][1] + B.mt[1][1]);
		C.mt[1][0] = max(mt[1][0] + B.mt[0][0], mt[1][1] + B.mt[1][0]);
		C.mt[1][1] = max(mt[1][0] + B.mt[0][1], mt[1][1] + B.mt[1][1]);
		return C; 
	}
}G[N], bef, now;

namespace Matrix_Sg_tree{
	#define lt (ls[u])
	#define rt (rs[u])
	#define mid (l + r >> 1)
	Matrix t[N << 2];
	int ls[N << 2], rs[N << 2], dc, rot[N];
	
	inline void psup(int u) {t[u] = t[lt] * t[rt];}
	inline void build(int &u, int l, int r){
		u = ++dc;
		if(l == r) {t[u] = G[raw[l]]; return ;}
		build(lt, l, mid), build(rt, mid + 1, r);
		psup(u);
	}
	inline void upd(int u, int l, int r, int k){
		if(l == r) {t[u] = G[raw[k]]; return ;}
		if(k <= mid) upd(lt, l, mid, k);
		else upd(rt, mid + 1, r, k);
		psup(u);
	}
} 
using namespace Matrix_Sg_tree;
inline void idfs(int u, int from){
	siz[u] = 1, fa[u] = from, wson[u] = 0;
	for(int i = head[u], v; v = e[i].v, i; i = e[i].nxt)
		if(v ^ from){
			idfs(v, u), siz[u] += siz[v];
			if(siz[v] > siz[wson[u]]) wson[u] = v;
		}
}
inline void dfs1(int u, int from, int tp){
	id[u] = ++df, raw[df] = u, top[u] = tp, ed[tp] = df;
	if(wson[u]) dfs1(wson[u], u, tp);
	for(int i = head[u], v; v = e[i].v, i; i = e[i].nxt)
		if(v ^ from && v ^ wson[u])
			dfs1(v, u, v);
	
}
inline void dfs2(int u, int from){
	G[u].mt[0][0] = G[u].mt[0][1] = 0, G[u].mt[1][0] = a[u];
	f[u][1] = a[u];
	for(int i = head[u], v; v = e[i].v, i; i = e[i].nxt)
		if(v ^ from){
			dfs2(v, u);
			f[u][0] += max(f[v][0], f[v][1]), f[u][1] += f[v][0];
			if(v ^ wson[u]){
				G[u].mt[0][0] += max(f[v][0], f[v][1]), G[u].mt[0][1] = G[u].mt[0][0];
				G[u].mt[1][0] += f[v][0]; 
			}
		}
}
inline void change(int u, int k){
	G[u].mt[1][0] += k - a[u], a[u] = k;
	while(u != 0){
		bef = t[rot[top[u]]];
		upd(rot[top[u]], id[top[u]], ed[top[u]], id[u]), now = t[rot[top[u]]];
		u = fa[top[u]];
		if(u){
			G[u].mt[0][0] += max(now.mt[0][0], now.mt[1][0]) - max(bef.mt[0][0], bef.mt[1][0]);
			G[u].mt[0][1] = G[u].mt[0][0];
			G[u].mt[1][0] += now.mt[0][0] - bef.mt[0][0];
		}
	}
}
inline void init(){
	cin >> n >> m;
	for(int i = 1; i <= n; ++i) read(a[i]);
	for(int i = 1; i < n; ++i) read(u), read(v), add(u, v), add(v, u);
	idfs(1, 0), dfs1(1, 1, 1), dfs2(1, 0);
	for(int i = 1; i <= n;++i) if(top[i] == i) build(rot[i], id[i], ed[i]);
}
inline void solve(){
	int lst = 0;
	while(m--){
		read(x), read(y);
		change(x ^ lst, y);
		now = t[rot[1]];
		write(lst = max(now.mt[0][0], now.mt[1][0])), putchar('\n');
	}
}
int main(){
	init();
	solve();
	return 0;
}

全局平衡二叉树

对上述做法稍稍优化,我们考虑对每条重链建立一棵平衡树,然后轻儿子连接虚边,即认父不认子。

这样子修改就直接向上跳,遇到虚边就暴力修改父亲的 \(g\);查询直接在对应的平衡树上查一个后缀即可。本质和树剖并无区别。

为了体现“全局平衡”,建立平衡树时我们令一个点的权重为它的轻子树大小 \(+1\),然后找带权中点做根,不断递归。

可以证明,这样做树高 \(O(\log n)\)。考虑虚边,等于祖先链上重链数量显然 \(O(\log n)\);考虑实边,每次向上跳大小至少翻倍,所以也 \(O(\log n)\)

代码

来自模板加强版题解。

// Problem: P4751 【模板】"动态DP"&动态树分治(加强版)
// From: Luogu
// URL: https://www.luogu.com.cn/problem/P4751
// Time: 2022-05-19 21:21
// Author: lingfunny

#include <bits/stdc++.h>
#define LL long long
using namespace std;
const int mxn = 1e6+10, inf = 1e9;

int n, m, lst, a[mxn], rt, f[mxn][2], g[mxn][2];
vector <int> G[mxn];

struct mat {
	static const int V = 2;
	int a[V][V];
	mat(int a00 = 0, int a01 = -inf, int a10 = -inf, int a11 = 0) { a[0][0] = a00, a[0][1] = a01, a[1][0] = a10, a[1][1] = a11; }
	inline mat operator * (const mat &rhs) const {
		mat res;
		for(int i = 0; i < V; ++i) for(int j = 0; j < V; ++j) {
			res.a[i][j] = -inf;
			for(int k = 0; k < V; ++k) res.a[i][j] = max(res.a[i][j], a[i][k] + rhs.a[k][j]);
		}
		return res;
	}
	inline void show() {
		for(int i = 0; i < V; ++i) for(int j = 0; j < V; ++j) printf("%d%c", a[i][j], " \n"[j==V-1]);
	}
} gm[mxn];

struct node { int lc, rc, anc; mat u, s; } nd[mxn];
inline void psup(int u) {
	node &o = nd[u];
	o.s = nd[o.rc].s * o.u * nd[o.lc].s;	// 反着乘,原因见上文 P4719
}

int sz[mxn], lsz[mxn], dep[mxn], fa[mxn], son[mxn], top[mxn], End[mxn], dfn[mxn], mp[mxn], dfc;
// lsz[u]: Lsize[u]
// top/End: 重链顶/底
// mp: mp[dfn[u]] = u

void dfs(int u, int f) {
	lsz[u] = sz[u] = 1, dep[u] = dep[f] + 1, fa[u] = f;
	for(int v: G[u]) if(v != f) {
		dfs(v, u), sz[u] += sz[v];
		if(sz[v] > sz[son[u]]) son[u] = v;
	}
	lsz[u] = sz[u] - sz[son[u]];
}
void dfs2(int u) {
	End[u] = mp[dfn[u] = ++dfc] = u; g[u][0] = a[u];
	if(son[u]) top[son[u]] = top[u], dfs2(son[u]), End[u] = End[son[u]];
	for(int v: G[u]) if(v != fa[u] && v != son[u]) top[v] = v, dfs2(v), g[u][0] += f[v][0], g[u][1] += f[v][1];
	f[u][0] = f[son[u]][1] + g[u][1];
	f[u][1] = max(f[u][0], f[son[u]][0] + g[u][0]);
	gm[u] = mat(-inf, g[u][0], g[u][1], g[u][1]);
}

int sbuild(int L, int R) {
	if(L > R) return 0;
	LL sum = 0, qsum = 0;
	for(int i = L; i <= R; ++i) sum += lsz[mp[i]];
	for(int i = L, o; i <= R; ++i) {
		qsum += lsz[mp[i]];
		if(qsum * 2 > sum) {
			o = mp[i];
			node &u = nd[o];
			u.u = gm[o];
			u.lc = sbuild(L, i-1), nd[u.lc].anc = o;
			u.rc = sbuild(i+1, R), nd[u.rc].anc = o;
			psup(o);
			return o;
		}
	}
	return -114514;
}

int build(int Tp) {
	int Ed = End[Tp], X = sbuild(dfn[Tp], dfn[Ed]);
	for(int i = dfn[Tp]; i <= dfn[Ed]; ++i) {
		const int &u = mp[i];
		for(int v: G[u]) if(v != son[u] && v != fa[u]) nd[build(v)].anc = u;
	}
	return X;
}

inline void modify(int u, int x) {
	gm[u].a[0][1] += x - a[u]; a[u] = x;
	nd[u].u = gm[u];
	while(u) {
		int F = nd[u].anc;
		if(nd[F].lc != u && nd[F].rc != u && F) { // 如果当前节点到父亲的边是虚边
			mat bef = nd[u].s;
			psup(u);
			mat aft = nd[u].s;
			int _f0 = max(bef.a[0][0], bef.a[1][0]), _f1 = max(bef.a[0][1], bef.a[1][1]),
			f0 = max(aft.a[0][0], aft.a[1][0]), f1 = max(aft.a[0][1], aft.a[1][1]);
			mat &r = nd[F].u;
			r.a[0][1] += f0 - _f0;
			r.a[1][0] += f1 - _f1, r.a[1][1] += f1 - _f1;
			gm[F] = r;
		} else psup(u);
		u = F;
	}
}

signed main() {
	scanf("%d%d", &n, &m);
	for(int i = 1; i <= n; ++i) scanf("%d", a+i);
	for(int i = 1, u, v; i < n; ++i) scanf("%d%d", &u, &v), G[u].push_back(v), G[v].push_back(u);
	dfs(1, 0), top[1] = 1, dfs2(1), rt  = build(1);
	while(m--) {
		int x, y; scanf("%d%d", &x, &y); x ^= lst;
		modify(x, y);
		printf("%d\n", lst=max(nd[rt].s.a[0][1], nd[rt].s.a[1][1]));
	}
	return 0;
}
posted @ 2026-01-14 17:57  Zwi  阅读(1)  评论(0)    收藏  举报