【模板】动态 DP

【模板】动态 DP

动态 dp 入门题。

题意

给定一棵 \(n\) 个节点的树,第 \(i\) 个点的点权为 \(a_i\)

接下来有 \(m\) 次操作。每次操作给定 \(x\)\(y\),把 \(a_x\) 修改为 \(y\)

你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

\(1 \leq n,m \leq 10^5\),任何时刻 \(|a_i| \leq 10^2\)

思路

遇到最大权独立集问题可以考虑 dp。

我们令 \(f_{u,0}\) 表示以 \(u\) 为根的子树中不选 \(u\) 的最大权独立集,\(f_{u,1}\) 表示以 \(u\) 为根的子树中选 \(u\) 的最大权独立集,则有:

\[\begin{cases} f_{u,0}=\sum_{v \in son_u} \max(f_{v,0},f_{v,1}) \\ f_{u,1}=a_u + \sum_{v \in son_u} f_{v,0} \end{cases} \]

接下来考虑维护修改操作。

考虑重链剖分的性质,每个点到根的路径上只会有 \(O(\log n)\) 条重链。

我们令 \(g_{u,0}\) 表示以 \(u\) 为根的子树中除去重儿子子树且不选 \(u\) 的最大权独立集,\(g_{u,1}\) 表示以 \(u\) 为根的子树中除去重儿子子树且选 \(u\) 的最大权独立集,则有:

\[\begin{cases} f_{u,0}=g_{u,0}+\max(f_{heavy_u,0},f_{heavy_u,1}) \\ f_{u,1}=g_{u,1}+f_{heavy_u,0} \\ g_{u,0}=\sum_{v \in light_u} \max(f_{v,0},f_{v,1}) \\ g_{u,1}=a_u + \sum_{v \in light_u} f_{v,0} \end{cases} \]

动态 dp 的经典思想是用矩阵维护转移。所以我们定义广义矩阵乘法如下:

如果 \(n\)\(m\) 列的矩阵 \(A\) 和一个 \(m\)\(k\) 列的矩阵 \(B\) 得到的乘积是 \(n\)\(k\) 列的矩阵 \(C\),则 \(C_{x,y}=\max_{i=1}^{m}(A_{x,i}+B_{i,y})\)。容易发现这种矩阵乘法具有结合律。

所以可以得到下面的式子:

\[\begin{bmatrix} g_{u,0} & g_{u,0} \\ g_{u,1} & -\infty \end{bmatrix} \times \begin{bmatrix} f_{heavy_u,0} \\ f_{heavy_u,1} \end{bmatrix} = \begin{bmatrix} f_{u,0} \\ f_{u,1} \end{bmatrix} \]

我们考虑对于每个节点维护 \(\begin{bmatrix} g_{u,0} & g_{u,0} \\ g_{u,1} & -\infty \end{bmatrix}\)。我们发现,对于所有的叶子节点,\(f_{u,i}=g_{u,i}\)。对于一条重链,可以由叶子节点的转移矩阵反推每一点的 \(f\) 值。所以我们并不需要维护 \(f\)

我们计算单点的 \(f\) 值复杂度为 \(O(\log n)\),需要条 \(O(\log n)\) 次重链,总的时间复杂度应为 \(O(m \log^2 n)\)

代码

#include<iostream>
#include<cstdio>
#include<vector>
using namespace std;
const int INF=0x3f3f3f3f;
int num[100010];
int tot,son[100010],pa[100010],dfn[100010],top[100010],rev[100010],bottom[100010],child[100010];
int dp_f[100010][2],dp_g[100010][2];
vector<int> G[100010],T[100010];
void dfs1(int u,int fa){
	child[u]=1;
	for(int i=0;i<G[u].size();i++){
		int v=G[u][i];
		if(v!=fa){
			T[u].push_back(v);
			pa[v]=u;
			dfs1(v,u);
			child[u]+=child[v];
			if(child[v]>=child[son[u]]){
				son[u]=v;
			}
		}
	}
}
void dfs2(int u,int fa){
	rev[++tot]=u;
	dfn[u]=tot;
	top[u]=fa;
	bottom[fa]=u;
	if(son[u]){
		dfs2(son[u],fa);
	}
	for(int i=0;i<T[u].size();i++){
		int v=T[u][i];
		if(v!=son[u]){
			dfs2(v,v);
		}
	}
}
void dfs3(int u){
	dp_f[u][1]=dp_g[u][1]=num[u];
	for(int i=0;i<T[u].size();i++){
		int v=T[u][i];
		dfs3(v);
		dp_f[u][0]+=max(dp_f[v][0],dp_f[v][1]);
		dp_f[u][1]+=dp_f[v][0];
		if(v!=son[u]){
			dp_g[u][0]+=max(dp_f[v][0],dp_f[v][1]);
			dp_g[u][1]+=dp_f[v][0];
		}
	}
}
struct Matrix{
	int matrix[2][2];
};
const Matrix operator *(const Matrix &x,const Matrix &y){
	Matrix z;
	for(int i=0;i<2;i++){
		for(int j=0;j<2;j++){
			z.matrix[i][j]=max(x.matrix[i][0]+y.matrix[0][j],x.matrix[i][1]+y.matrix[1][j]);
		}
	}
	return z;
}
struct Node{
	int l,r;
	Matrix g;
}a[400010];
void pushup(int id){
	a[id].g=a[id*2].g*a[id*2+1].g;
}
void build(int id,int l,int r){
	a[id].l=l;
	a[id].r=r;
	if(a[id].l==a[id].r){
		a[id].g.matrix[0][0]=a[id].g.matrix[0][1]=dp_g[rev[l]][0];
		a[id].g.matrix[1][0]=dp_g[rev[l]][1];
		a[id].g.matrix[1][1]=-INF;
	}
	else{
		int mid=(l+r)>>1;
		build(id*2,l,mid);
		build(id*2+1,mid+1,r);
		pushup(id);
	}
}
Matrix query(int id,int l,int r){
	if(l<=a[id].l  &&  a[id].r<=r){
		return a[id].g;
	}
	bool flag=false;
	Matrix ans;
	if(l<=a[id*2].r){
		flag=true;
		ans=query(id*2,l,r);
	}
	if(a[id*2+1].l<=r){
		if(flag==false){
			ans=query(id*2+1,l,r);
		}
		else{
			ans=ans*query(id*2+1,l,r);
		}
	}
	return ans;
}
void modify(int id,int pos,Matrix dif){
	if(a[id].l==a[id].r){
		a[id].g=dif;
		return ;
	}
	if(pos<=a[id*2].r){
		modify(id*2,pos,dif);
	}
	else{
		modify(id*2+1,pos,dif);
	}
	pushup(id);
}
struct Query{
	int dp0,dp1;
};
Query query_f(int u){
	int id_u=dfn[u],id_bottom=dfn[bottom[u]];
	Matrix tmp=query(1,id_bottom,id_bottom),tmp2=query(1,id_u,id_bottom-1);
	Query tmp3=(Query){tmp.matrix[0][0],tmp.matrix[1][0]};
	if(id_bottom==id_u) return tmp3;
	Query ans=(Query){max(tmp2.matrix[0][0]+tmp3.dp0,tmp2.matrix[0][1]+tmp3.dp1),max(tmp2.matrix[1][0]+tmp3.dp0,tmp2.matrix[1][1]+tmp3.dp1)};
	return ans;
} 
int main(){
	int n,m;
	scanf("%d %d",&n,&m);
	for(int i=1;i<=n;i++){
		scanf("%d",&num[i]);
	}
	for(int i=1;i<n;i++){
		int u,v;
		scanf("%d %d",&u,&v);
		G[u].push_back(v);
		G[v].push_back(u);
	}
	dfs1(1,-1);
	dfs2(1,1);
	for(int i=1;i<=n;i++){
		bottom[i]=bottom[top[i]];
	}
	dfs3(1);
	build(1,1,n);
	while(m--){
		int u,dif;
		scanf("%d %d",&u,&dif);
		Query lst=query_f(top[u]); 
		Matrix tmp=query(1,dfn[u],dfn[u]);
		tmp.matrix[1][0]+=dif-num[u];
		num[u]=dif;
		modify(1,dfn[u],tmp);
		while(top[u]!=1){
			int pre=top[u];
			Query f=query_f(pre);
			int fa=pa[pre];
			tmp=query(1,dfn[fa],dfn[fa]);
			tmp.matrix[0][0]+=max(f.dp0,f.dp1)-max(lst.dp0,lst.dp1);
			tmp.matrix[0][1]+=max(f.dp0,f.dp1)-max(lst.dp0,lst.dp1);
			tmp.matrix[1][0]+=f.dp0-lst.dp0;
			lst=query_f(top[fa]);
			modify(1,dfn[fa],tmp);
			u=fa;
		}
		lst=query_f(1);
		printf("%d\n",max(lst.dp0,lst.dp1));
	}
	return 0;
}
posted @ 2025-05-18 09:00  Oken喵~  阅读(1)  评论(0)    收藏  举报