动态DP(DDP)

动态DP是树上的、带修改的DP。修改操作一般而言用树剖加线段树加广义矩阵乘法来维护,复杂度可以达到 \(n\log^2 n\)
叫DDP是不知从哪里延续下来的一种神秘简称。

P4719 【模板】动态 DP

给定一颗树,每个点有权值,维护最大独立集。最大独立集指没有两个集合中的点被一条边直接相连。

如果不带修,那就是一道很简单的 DP 题。
\(f_{u,0/1}\) 表示某个点选或者不选,它自己和其子树中选的点的最大权值。
方程也是比较直观的:

\[\begin{aligned} \left\{\begin{matrix} f_{u,0} =\sum _v \max (f_{v,0},f_{v,1}) \\ f_{u,1} =\sum _v f_{v,0} \end{matrix}\right. \end{aligned} \]

现在考虑带修改怎么做。
显然树剖是必要的。然后问题就变成如何快速的维护轻重链间的答案,同时一条重链内的答案可以被线段树合并维护。

考虑这么一个另外的变量 \(g_{u,0/1}\) 表示对于某一个点,如果不考虑其重儿子,这个点选或者不选,这棵子树的最大权值。
这样我们就可以将轻重儿子间的信息分开单独维护。新的通过 \(g\) 来转移的方程式:

\[\begin{aligned} \left\{\begin{matrix} f_{u,0} =\max (g_{u,0}+f_{son_u,1},g_{u,0}+f_{son_u,0}) \\ f_{u,1} = g_{u,1}+f_{son_u,0} \end{matrix}\right. \end{aligned} \]

这里 \(son_u\) 表示 \(u\) 的重儿子。

看起来好像没什么大用,反而还要多维护一个信息。但是这恰恰是线段树可以快速合并重链信息的关键。
考虑这种转移方程可以转化为一种广义矩阵乘法的形式。

具体的,设矩阵间的“*”积表示:

\[a*b=c \Leftrightarrow c_{i,j}=\max_k(a_{i,k}+b_{k,j}) \]

这时方程就可以写为:

\[\begin{vmatrix} g_{u,0} & g_{u,0}\\ g_{u,1} & -\infty \end{vmatrix}* \begin{vmatrix} f_{son_u,0} \\ f_{son_u,1} \end{vmatrix}= \begin{vmatrix} f_{u,0} \\ f_{u,1} \end{vmatrix} \]

如果我们知道了一条重链上的所有的点的 \(g\) 数组和重链底的那个点的 \(f\),那就可以通过线段树合并出重链顶的 \(f\)
由于我们只会查询树根的 \(f\) 值,而树根一定是一个重链顶。
也就是现在只需要用线段树维护前面只与 \(g\) 有关的转移矩阵,将其合并起来就做完了。

首先我们考虑单点修改的时候会造成什么影响。修改首先会影响当前重链的答案,然后传到更上面的一条重链,一直传递下去。
假设当前传递到的这条重链的顶为 \(u\) ,其父亲也就是更上面的那一条重链的底是 \(v\)。显然的,\(u\) 一定是 \(v\) 的轻儿子。那就说明 \(u\)\(f\) 值只会影响到 \(v\)\(g\) 值。

由于 \(g_{v,0}\) 的计算是所有轻儿子的 \(\sum \max{f_{u,0},f_{u,1}}\),因此只需要将修改之前的 \(\max{f_{u,0},f_{u,1}}\) 减掉,将新的加回来就可以了。\(g_{v,1}\) 同理。

由于按照矩阵的转移式,我们前面一堆转移矩阵后面还要跟一个 \(f\) 的矩阵。那这需要特判吗?
这里有一个比较巧妙的设计,可以发现在矩阵中 \(g_{u,0},g{u,1}\)\(f_{u,0},f_{u,1}\) 的位置是对应的。也就是说对于没有重儿子的叶子节点,我们也可以将本应是 \(f\) 的值直接塞到转移矩阵的第一列对应位置。这样矩乘的时候自然也就将其包含,放在最后面“*”乘起来了。

code

仿照了一下题解的写法。但是比题解“略有”压行。

#include<bits/stdc++.h>
using namespace std;
const int N=4e5+7,inf=1e9+7;
int n,m,a[N],tot[N],dep[N],siz[N],son[N],dfn[N],loc[N],top[N],dfncnt=0,len[N],f[N][2],fat[N];
vector <int> q[N];
struct node{int mp[2][2];};
void init2(node &x){for(int i=0;i<2;i++)for(int j=0;j<2;j++) x.mp[i][j]=0;}void init3(node &x){for(int i=0;i<2;i++)for(int j=0;j<2;j++) x.mp[i][j]=-inf;}
node operator * (const node &x,const node &y){
	node res;init3(res);
	for(int i=0;i<2;i++)for(int j=0;j<2;j++)for(int k=0;k<2;k++) res.mp[i][j]=max(res.mp[i][j],x.mp[i][k]+y.mp[k][j]);
	return res;
}
node tra[N],tr[N];
void dfs1(int u,int fa){
	dep[u]=dep[fa]+1,siz[u]=1;fat[u]=fa;
	for(int i=0;i<tot[u];i++){
		int v=q[u][i];if(v==fa) continue;
		dfs1(v,u);son[u]=siz[son[u]]<siz[v]?v:son[u];siz[u]+=siz[v];
	}
}
void dfs2(int u,int fa,int t){
	dfn[u]=++dfncnt,loc[dfncnt]=u;top[u]=t,len[t]++;
	init2(tra[u]);f[u][0]=0,f[u][1]=tra[u].mp[1][0]=a[u];tra[u].mp[1][1]=-inf;
	if(son[u]) dfs2(son[u],u,t),f[u][0]+=max(f[son[u]][0],f[son[u]][1]),f[u][1]+=f[son[u]][0];
	for(int i=0;i<tot[u];i++){
		int v=q[u][i];if(v==fa||v==son[u]) continue;
		dfs2(v,u,v);int tmp1=max(f[v][0],f[v][1]),tmp2=f[v][0];
		f[u][0]+=tmp1,f[u][1]+=tmp2;
		tra[u].mp[0][0]+=tmp1,tra[u].mp[1][0]+=tmp2;
	}
	tra[u].mp[0][1]=tra[u].mp[0][0];
}
#define ls (u<<1)
#define rs (u<<1|1)
void push_up(int u){tr[u]=tr[ls]*tr[rs];}
void build(int u,int l,int r){
	if(l==r) {tr[u]=tra[loc[l]];return;}
	int mid=(l+r)>>1;build(ls,l,mid),build(rs,mid+1,r);
	push_up(u);
}
void modify(int u,int l,int r,int x){
	if(l==r) {tr[u]=tra[loc[l]];return;}
	int mid=(l+r)>>1;if(x<=mid)modify(ls,l,mid,x);else modify(rs,mid+1,r,x);
	push_up(u);
}
node query(int u,int l,int r,int ql,int qr){
	if(ql<=l&&r<=qr){return tr[u];}
	int mid=(l+r)>>1;
	if(qr<=mid) return query(ls,l,mid,ql,qr);if(ql>mid) return query(rs,mid+1,r,ql,qr);
	return query(ls,l,mid,ql,qr)*query(rs,mid+1,r,ql,qr);
}
void update(int u,int val){
	tra[u].mp[1][0]+=val-a[u],a[u]=val;
	node x,y;
	while(u){
		x=query(1,1,n,dfn[top[u]],dfn[top[u]]+len[top[u]]-1);modify(1,1,n,dfn[u]);y=query(1,1,n,dfn[top[u]],dfn[top[u]]+len[top[u]]-1);
		u=fat[top[u]];
		tra[u].mp[0][0]+=max(y.mp[0][0],y.mp[1][0])-max(x.mp[0][0],x.mp[1][0]);
		tra[u].mp[1][0]+=y.mp[0][0]-x.mp[0][0];tra[u].mp[0][1]=tra[u].mp[0][0];
	}
}
signed main(){
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n>>m;for(int i=1;i<=n;i++) cin>>a[i];
	for(int i=1,u,v;i<=n-1;i++) cin>>u>>v,q[u].push_back(v),tot[u]++,q[v].push_back(u),tot[v]++;
	dfs1(1,0),dfs2(1,0,1);build(1,1,n);
	for(int i=1,u,x;i<=m;i++){
		cin>>u>>x;update(u,x);node res=query(1,1,n,dfn[top[1]],dfn[top[1]]+len[top[1]]-1);
		cout<<max(res.mp[0][0],res.mp[1][0])<<'\n';
	}
	return 0;
}
posted @ 2025-04-01 23:13  all_for_god  阅读(62)  评论(0)    收藏  举报