【模板】动态 DP
【模板】动态 DP
动态 dp 入门题。
题意
给定一棵 \(n\) 个节点的树,第 \(i\) 个点的点权为 \(a_i\)。
接下来有 \(m\) 次操作。每次操作给定 \(x\) 和 \(y\),把 \(a_x\) 修改为 \(y\)。
你需要在每次操作之后求出这棵树的最大权独立集的权值大小。
\(1 \leq n,m \leq 10^5\),任何时刻 \(|a_i| \leq 10^2\)。
思路
遇到最大权独立集问题可以考虑 dp。
我们令 \(f_{u,0}\) 表示以 \(u\) 为根的子树中不选 \(u\) 的最大权独立集,\(f_{u,1}\) 表示以 \(u\) 为根的子树中选 \(u\) 的最大权独立集,则有:
接下来考虑维护修改操作。
考虑重链剖分的性质,每个点到根的路径上只会有 \(O(\log n)\) 条重链。
我们令 \(g_{u,0}\) 表示以 \(u\) 为根的子树中除去重儿子子树且不选 \(u\) 的最大权独立集,\(g_{u,1}\) 表示以 \(u\) 为根的子树中除去重儿子子树且选 \(u\) 的最大权独立集,则有:
动态 dp 的经典思想是用矩阵维护转移。所以我们定义广义矩阵乘法如下:
如果 \(n\) 行 \(m\) 列的矩阵 \(A\) 和一个 \(m\) 行 \(k\) 列的矩阵 \(B\) 得到的乘积是 \(n\) 行 \(k\) 列的矩阵 \(C\),则 \(C_{x,y}=\max_{i=1}^{m}(A_{x,i}+B_{i,y})\)。容易发现这种矩阵乘法具有结合律。
所以可以得到下面的式子:
我们考虑对于每个节点维护 \(\begin{bmatrix} g_{u,0} & g_{u,0} \\ g_{u,1} & -\infty \end{bmatrix}\)。我们发现,对于所有的叶子节点,\(f_{u,i}=g_{u,i}\)。对于一条重链,可以由叶子节点的转移矩阵反推每一点的 \(f\) 值。所以我们并不需要维护 \(f\)。
我们计算单点的 \(f\) 值复杂度为 \(O(\log n)\),需要条 \(O(\log n)\) 次重链,总的时间复杂度应为 \(O(m \log^2 n)\)。
代码
#include<iostream>
#include<cstdio>
#include<vector>
using namespace std;
const int INF=0x3f3f3f3f;
int num[100010];
int tot,son[100010],pa[100010],dfn[100010],top[100010],rev[100010],bottom[100010],child[100010];
int dp_f[100010][2],dp_g[100010][2];
vector<int> G[100010],T[100010];
void dfs1(int u,int fa){
child[u]=1;
for(int i=0;i<G[u].size();i++){
int v=G[u][i];
if(v!=fa){
T[u].push_back(v);
pa[v]=u;
dfs1(v,u);
child[u]+=child[v];
if(child[v]>=child[son[u]]){
son[u]=v;
}
}
}
}
void dfs2(int u,int fa){
rev[++tot]=u;
dfn[u]=tot;
top[u]=fa;
bottom[fa]=u;
if(son[u]){
dfs2(son[u],fa);
}
for(int i=0;i<T[u].size();i++){
int v=T[u][i];
if(v!=son[u]){
dfs2(v,v);
}
}
}
void dfs3(int u){
dp_f[u][1]=dp_g[u][1]=num[u];
for(int i=0;i<T[u].size();i++){
int v=T[u][i];
dfs3(v);
dp_f[u][0]+=max(dp_f[v][0],dp_f[v][1]);
dp_f[u][1]+=dp_f[v][0];
if(v!=son[u]){
dp_g[u][0]+=max(dp_f[v][0],dp_f[v][1]);
dp_g[u][1]+=dp_f[v][0];
}
}
}
struct Matrix{
int matrix[2][2];
};
const Matrix operator *(const Matrix &x,const Matrix &y){
Matrix z;
for(int i=0;i<2;i++){
for(int j=0;j<2;j++){
z.matrix[i][j]=max(x.matrix[i][0]+y.matrix[0][j],x.matrix[i][1]+y.matrix[1][j]);
}
}
return z;
}
struct Node{
int l,r;
Matrix g;
}a[400010];
void pushup(int id){
a[id].g=a[id*2].g*a[id*2+1].g;
}
void build(int id,int l,int r){
a[id].l=l;
a[id].r=r;
if(a[id].l==a[id].r){
a[id].g.matrix[0][0]=a[id].g.matrix[0][1]=dp_g[rev[l]][0];
a[id].g.matrix[1][0]=dp_g[rev[l]][1];
a[id].g.matrix[1][1]=-INF;
}
else{
int mid=(l+r)>>1;
build(id*2,l,mid);
build(id*2+1,mid+1,r);
pushup(id);
}
}
Matrix query(int id,int l,int r){
if(l<=a[id].l && a[id].r<=r){
return a[id].g;
}
bool flag=false;
Matrix ans;
if(l<=a[id*2].r){
flag=true;
ans=query(id*2,l,r);
}
if(a[id*2+1].l<=r){
if(flag==false){
ans=query(id*2+1,l,r);
}
else{
ans=ans*query(id*2+1,l,r);
}
}
return ans;
}
void modify(int id,int pos,Matrix dif){
if(a[id].l==a[id].r){
a[id].g=dif;
return ;
}
if(pos<=a[id*2].r){
modify(id*2,pos,dif);
}
else{
modify(id*2+1,pos,dif);
}
pushup(id);
}
struct Query{
int dp0,dp1;
};
Query query_f(int u){
int id_u=dfn[u],id_bottom=dfn[bottom[u]];
Matrix tmp=query(1,id_bottom,id_bottom),tmp2=query(1,id_u,id_bottom-1);
Query tmp3=(Query){tmp.matrix[0][0],tmp.matrix[1][0]};
if(id_bottom==id_u) return tmp3;
Query ans=(Query){max(tmp2.matrix[0][0]+tmp3.dp0,tmp2.matrix[0][1]+tmp3.dp1),max(tmp2.matrix[1][0]+tmp3.dp0,tmp2.matrix[1][1]+tmp3.dp1)};
return ans;
}
int main(){
int n,m;
scanf("%d %d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%d",&num[i]);
}
for(int i=1;i<n;i++){
int u,v;
scanf("%d %d",&u,&v);
G[u].push_back(v);
G[v].push_back(u);
}
dfs1(1,-1);
dfs2(1,1);
for(int i=1;i<=n;i++){
bottom[i]=bottom[top[i]];
}
dfs3(1);
build(1,1,n);
while(m--){
int u,dif;
scanf("%d %d",&u,&dif);
Query lst=query_f(top[u]);
Matrix tmp=query(1,dfn[u],dfn[u]);
tmp.matrix[1][0]+=dif-num[u];
num[u]=dif;
modify(1,dfn[u],tmp);
while(top[u]!=1){
int pre=top[u];
Query f=query_f(pre);
int fa=pa[pre];
tmp=query(1,dfn[fa],dfn[fa]);
tmp.matrix[0][0]+=max(f.dp0,f.dp1)-max(lst.dp0,lst.dp1);
tmp.matrix[0][1]+=max(f.dp0,f.dp1)-max(lst.dp0,lst.dp1);
tmp.matrix[1][0]+=f.dp0-lst.dp0;
lst=query_f(top[fa]);
modify(1,dfn[fa],tmp);
u=fa;
}
lst=query_f(1);
printf("%d\n",max(lst.dp0,lst.dp1));
}
return 0;
}

浙公网安备 33010602011771号