[学习笔记]动态dp

其实就过了模板。

感觉就是带修改的dp

 

【模板】动态dp

给定一棵n个点的树,点带点权。

m次操作,每次操作给定x,y表示修改点x的权值为y

你需要在每次操作之后求出这棵树的最大权独立集的权值大小。

 

n,m<=1e5

 

参考题解:shadowice1984

n^2 DP简单又自然。

但是对于1e5次修改就不行了。

 

每一次修改会影响整个到根的链上的值。

采用树剖。

ldp[i][0/1]表示i选不选,对于所有的轻儿子dp值。

dp[i][0/1]表示i选不选,对于总共的所有儿子的dp值。

ldp[i][0]=∑max(ldp[lightson][1],ldp[lightson][0])

ldp[i][1]=∑ldp[lightson][0]

dp[i][0]=ldp[i][0]+max(dp[heavyson][1],dp[heavyson][0])

dp[i][1]=ldp[i][1]+dp[heavyson][0]

可以先把这个dp都求出来。

 

然后怎么维护?自然要用线段树维护dfs序。

采用矩阵。

a*b定义为:

c[i][j]=max(a[i][k]+b[k][j])

有结合律。

线段树维护区间矩阵乘积。(注意从右往左乘,自下而上)

只要在最前面乘上一个初始矩阵

第一行是0,第二行是-inf的矩阵。

就可以求出某个点的最终dp值了。

 

修改的时候,暴力修改这个 点的ldp0,ldp1

但是还会影响这个fa[top[x]]的ldp0,ldp1

所以要求出dp[top[x]],dp[top[y]]为了避免常数过大,

用一个数组记录dp值,然后把前后两次最大值的差值来修改fa[top[x]]的ldp0,ldp1

然后跳一条链,到fa[top[x]]

这样单次修改log^2n

 

每次返回max(dp[1][0],dp[1][1])

普通线段树版:(3000ms)

#include<bits/stdc++.h>
#define reg register int
#define il inline
#define numb (ch^'0')
#define mid ((l+r)>>1)
using namespace std;
typedef long long ll;
il void rd(int &x){
    char ch;x=0;bool fl=false;
    while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true);
    for(x=numb;isdigit(ch=getchar());x=x*10+numb);
    (fl==true)&&(x=-x);
}
namespace Miracle{
const int N=1e5+5;
const int inf=0x3f3f3f3f;
int n,m;
struct node{
    int nxt,to;
}e[2*N];
int hd[N],cnt;
void add(int x,int y){
    e[++cnt].nxt=hd[x];
    e[cnt].to=y;
    hd[x]=cnt;
}
struct tr{
    int a[3][3];
    void init(int x,int y){//x:ldp0 y:ldp1
        a[1][1]=x,a[2][1]=x;
        a[1][2]=y,a[2][2]=-inf;
    }
    void pre(){
        memset(a,-inf,sizeof a);
    }
    void st(){
        a[1][1]=0,a[1][2]=-inf,a[2][1]=-inf,a[2][2]=0;
    }
    tr operator *(const tr& b){
        tr c;c.pre();
        for(reg i=1;i<=2;++i){
            for(reg k=1;k<=2;++k){
                for(reg j=1;j<=2;++j){
                    c.a[i][j]=max(c.a[i][j],a[i][k]+b.a[k][j]);
                }
            }
        }return c;
    }
    void op(){
        cout<<left<<setw(10)<<a[1][1]<<" "<<left<<setw(10)<<a[1][2]<<endl;
        cout<<left<<setw(10)<<a[2][1]<<" "<<left<<setw(10)<<a[2][2]<<endl;
        cout<<endl;
    }
}s[N],t[4*N],A;
int dfn[N],top[N],dfn2[N],fdfn[N],sz[N],dep[N],son[N];
int nd[N];//tot;//num of heavy chain
int fa[N];
int df;
int ldp[N][2],dp[N][2];
int w[N];
void dfs1(int x,int d){
    dep[x]=d;
    sz[x]=1;
    for(reg i=hd[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==fa[x]) continue;
        fa[y]=x;
        dfs1(y,d+1);
        if(sz[y]>sz[son[x]]){
            son[x]=y;
        }
    }
}
void dfs2(int x){
    dfn[x]=++df;fdfn[df]=x;
    if(!top[x]) {
        top[x]=x;nd[top[x]]=x;
    }
    if(son[x]) top[son[x]]=top[x],nd[top[x]]=son[x],dfs2(son[x]);
    
    dp[x][1]=w[x];
    ldp[x][1]=w[x];
    for(reg i=hd[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==son[x]||y==fa[x]) continue;
        dfs2(y);
        ldp[x][0]+=max(dp[y][0],dp[y][1]);    
        ldp[x][1]+=dp[y][0];
    }
    if(son[x]){
        dp[x][1]=ldp[x][1]+dp[son[x]][0];
        dp[x][0]=ldp[x][0]+max(dp[son[x]][1],dp[son[x]][0]);
    }
    s[x].init(ldp[x][0],ldp[x][1]);
}
void pushup(int x){
    t[x]=t[x<<1|1]*t[x<<1];
}
void build(int x,int l,int r){
    if(l==r){
        t[x]=s[fdfn[l]];return;
    }
    build(x<<1,l,mid);build(x<<1|1,mid+1,r);
    pushup(x);
}
tr query(int x,int l,int r,int L,int R){
    if(L<=l&&r<=R){
        return t[x];
    }
    tr ret;ret.st();
    if(mid<R) ret=ret*query(x<<1|1,mid+1,r,L,R);
    if(L<=mid) ret=ret*query(x<<1,l,mid,L,R);
    return ret;
}
void add(int x,int l,int r,int to,int p,int c){
    if(l==r){
        if(p) t[x].a[1][2]+=c;
        else t[x].a[1][1]+=c,t[x].a[2][1]+=c;
        return;
    }
    if(to<=mid) add(x<<1,l,mid,to,p,c);
    else if(mid<to) add(x<<1|1,mid+1,r,to,p,c);
    pushup(x);
}
int tmp[2];
int to[2];
int upda(int x,int y){
    tmp[0]=tmp[1]=0;
    to[0]=to[1]=0;
    tmp[1]=y-w[x];
    w[x]=y;
    while(x){
        tr anc=A*query(1,1,n,dfn[top[x]],dfn[nd[top[x]]]);
        to[0]=anc.a[1][1],to[1]=anc.a[1][2];
        add(1,1,n,dfn[x],0,tmp[0]);
        add(1,1,n,dfn[x],1,tmp[1]);
        anc=A*query(1,1,n,dfn[top[x]],dfn[nd[top[x]]]);
        tmp[0]=max(anc.a[1][1],anc.a[1][2])-max(to[0],to[1]);
        tmp[1]=anc.a[1][1]-to[0];
        x=fa[top[x]];
    }
    tr ans=A*query(1,1,n,dfn[top[1]],dfn[nd[top[1]]]);
    return max(ans.a[1][1],ans.a[1][2]);
}
int main(){
    scanf("%d%d",&n,&m);
    for(reg i=1;i<=n;++i)rd(w[i]);
    int x,y;
    for(reg i=1;i<=n-1;++i){
        rd(x);rd(y);add(x,y);add(y,x);
    }
    dfs1(1,1);
    dfs2(1);
    build(1,1,n);
    A.a[1][1]=0,A.a[1][2]=0;
    A.a[2][1]=-inf,A.a[2][2]=-inf;
    while(m--){
        rd(x);rd(y);
        printf("%d\n",upda(x,y));
    }
    return 0;
}

}
int main(){
    Miracle::main();
    return 0;
}

/*
   Author: *Miracle*
   Date: 2018/11/12 16:29:49
*/

 

zkw线段树版:(1500ms)

#include<bits/stdc++.h>
#define reg register int
#define il inline
#define numb (ch^'0')
#define mid ((l+r)>>1)
using namespace std;
typedef long long ll;
il void rd(int &x){
    char ch;x=0;bool fl=false;
    while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true);
    for(x=numb;isdigit(ch=getchar());x=x*10+numb);
    (fl==true)&&(x=-x);
}
namespace Miracle{
const int N=1e5+5;
const int inf=0x3f3f3f3f;
int n,m;
struct node{
    int nxt,to;
}e[2*N];
int hd[N],cnt;
il void add(int x,int y){
    e[++cnt].nxt=hd[x];
    e[cnt].to=y;
    hd[x]=cnt;
}
struct tr{
    int a[3][3];
    void init(int x,int y){//x:ldp0 y:ldp1
        a[1][1]=x,a[2][1]=x;
        a[1][2]=y,a[2][2]=-inf;
    }
    void pre(){
        memset(a,-inf,sizeof a);
    }
    void st(){
        a[1][1]=0,a[1][2]=-inf,a[2][1]=-inf,a[2][2]=0;
    }
    tr operator *(const tr& b) const{
        tr c;c.pre();
        for(reg i=1;i<=2;++i){
            for(reg k=1;k<=2;++k){
                for(reg j=1;j<=2;++j){
                    c.a[i][j]=max(c.a[i][j],a[i][k]+b.a[k][j]);
                }
            }
        }return c;
    }
    void op(){
        cout<<left<<setw(10)<<a[1][1]<<" "<<left<<setw(10)<<a[1][2]<<endl;
        cout<<left<<setw(10)<<a[2][1]<<" "<<left<<setw(10)<<a[2][2]<<endl;
        cout<<endl;
    }
}s[N],t[4*N],A;
int dfn[N],top[N],dfn2[N],fdfn[N],sz[N],dep[N],son[N];
int nd[N];//tot;//num of heavy chain
int fa[N];
int df;
int ldp[N][2],dp[N][2];
int w[N];
il void dfs1(int x,int d){
    dep[x]=d;
    sz[x]=1;
    for(reg i=hd[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==fa[x]) continue;
        fa[y]=x;
        dfs1(y,d+1);
        if(sz[y]>sz[son[x]]){
            son[x]=y;
        }
    }
}
il void dfs2(int x){
    dfn[x]=++df;fdfn[df]=x;
    if(!top[x]) {
        top[x]=x;nd[top[x]]=x;
    }
    if(son[x]) top[son[x]]=top[x],nd[top[x]]=son[x],dfs2(son[x]);
    
    dp[x][1]=w[x];
    ldp[x][1]=w[x];
    for(reg i=hd[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==son[x]||y==fa[x]) continue;
        dfs2(y);
        ldp[x][0]+=max(dp[y][0],dp[y][1]);    
        ldp[x][1]+=dp[y][0];
    }
    if(son[x]){
        dp[x][1]=ldp[x][1]+dp[son[x]][0];
        dp[x][0]=ldp[x][0]+max(dp[son[x]][1],dp[son[x]][0]);
    }
    s[x].init(ldp[x][0],ldp[x][1]);
}
int up;
il void build(){
    up=1;
    for(;up<=n+1;up<<=1);
    for(reg i=up;i<=up+up-1;++i){
        if(i>=up+1&&i<=up+n) t[i]=s[fdfn[i-up]];
        else t[i]=A;
    }
    for(reg i=up-1;i;--i) t[i]=t[i<<1|1]*t[i<<1];
} 
il void chan(int to,int c0,int c1){
    reg i=up+to;
    t[i].a[1][1]+=c0;t[i].a[2][1]+=c0;
    t[i].a[1][2]+=c1;
    for(i>>=1;i;i>>=1){
        t[i]=t[i<<1|1]*t[i<<1];
    }
//    cout<<" after chan "<<endl;
}
il tr query(int l,int r){
    tr le,ri;le.st();ri.st();
    for(reg s=up+l-1,e=up+r+1;s^e^1;s>>=1,e>>=1){
//        cout<<s<<" "<<e<<endl;
        if(!(s&1)) le=t[s^1]*le;
        if(e&1) ri=ri*t[e^1];
    }
    return ri*le;
}
int tmp[2];
int to[2];
il int upda(int x,int y){
    tmp[0]=tmp[1]=0;
    to[0]=to[1]=0;
    tmp[1]=y-w[x];
    w[x]=y;
    while(x){
        //tr anc=A*query(1,1,n,dfn[top[x]],dfn[nd[top[x]]]);
        to[0]=dp[top[x]][0],to[1]=dp[top[x]][1];
        chan(dfn[x],tmp[0],tmp[1]);
        tr anc=A*query(dfn[top[x]],dfn[nd[top[x]]]);
        tmp[0]=max(anc.a[1][1],anc.a[1][2])-max(to[0],to[1]);
        tmp[1]=anc.a[1][1]-to[0];
        dp[top[x]][0]=anc.a[1][1],dp[top[x]][1]=anc.a[1][2];
        x=fa[top[x]];
    }
    return  max(dp[1][0],dp[1][1]);
}
int main(){
    scanf("%d%d",&n,&m);
    for(reg i=1;i<=n;++i)rd(w[i]);
    int x,y;
    for(reg i=1;i<=n-1;++i){
        rd(x);rd(y);add(x,y);add(y,x);
    }
    dfs1(1,1);
    dfs2(1);
    A.a[1][1]=0,A.a[1][2]=0;
    A.a[2][1]=-inf,A.a[2][2]=-inf;
    build();
    while(m--){
        rd(x);rd(y);
        printf("%d\n",upda(x,y));
    }
    return 0;
}

}
int main(){
//    freopen("data.in","r",stdin);
//    freopen("my.out","w",stdout);
    Miracle::main();
    return 0;
}

/*
   Author: *Miracle*
   Date: 2018/11/12 16:29:49
*/

 

posted @ 2018-11-12 21:44  *Miracle*  阅读(643)  评论(0编辑  收藏  举报