洛谷4643:【模板】动态dp——题解

https://www.luogu.org/problemnew/show/P4643

很妙……让我重新又看了一遍猫锟的WC课件。

推荐一个有markdown神犇题解:https://www.cnblogs.com/RabbitHu/p/9112811.html

本文的代码和就是在此基础上改动与细化(更符合我这种蒟蒻的阅读体验)

————————————————————————

这道题是课件的模板题。

首先需要明白这个最大独立集是指取了u结点则不能取与u相连的点v。

不带修改的话能够看出这就是“没有上司的舞会”,于是先把静态的dp敲出来。

f[i][0/1]为节点i当i不取/取的时候其子树产生的最大价值。

方程f[u][0]=sigma(max(f[v][0],f[v][1]))

f[u][1]=w[u]+sigma(f[v][0])

接下来让它“动”起来,按照一般套路修改应当在线段树上做,于是先码一个树链剖分再说。

我们发现:重链的信息好储存,但是重链的侧链(轻链)没有办法只靠f就能够将信息合并到轻链上。

于是思考可以再开一个数组来压缩一些信息使其能够放到重链上。

g[i][0/1]表示节点i当i不取/取时,i不在这条链上的子孙的答案(即最大独立集)。

不难用g来更新f数组。

f[u][0]=g[u][0]+max(f[v][0],f[v][1])

f[u][1]=g[u][1]+f[v][0]

(u,v在一条重链上,且fa[v]=u)

为了去除冗杂,我们采用矩阵的方法来表示这个式子。

g[i][0],g[i][0]    (运算->)f[v][0]  (等于) f[u][0]

g[i][1],   0                         f[v][1]                 f[u][1]

运算定义如下(就直接拿代码来说了,反正您们看得懂):

matrix operator *(const matrix &b)const{
        matrix c;
        for(int i=0;i<2;i++)
            for(int j=0;j<2;j++)
                for(int k=0;k<2;k++)
                    c.g[i][j]=max(c.g[i][j],g[i][k]+b.g[k][j]);
        return c;
    }

用线段树维护矩阵,则1所在的重链的所有节点的矩阵运算在一起即为1结点不取/取的答案。

可能你会有疑问,我们只维护了g数组,怎么就得出了f数组的功能呢?

别忘了链的底端u是没有v的啊!所以我们只用g数组往前推就行了啊。

那么修改u,就需要将u到1的路径上所有的重链的信息全部修改一遍。

为了优化时间,不至于每次修改都要重新搜一遍该点所连接的所有非链上的点(TLE警告),我们开一个val矩阵,其功能可以理解为线段树上的lazy。初始时val就等于对应结点的矩阵。

实际上就是修改一条重链i,对于它的父亲重链j的最后一个结点要根据i所得到的f值来更新这个结点的g矩阵。

细节讲起来也是很麻烦的,直接看代码吧(反正您们看得懂)。

void path_modify(int u,int c){
    val[pos[u]].g[1][0]+=c-w[u];w[u]=c;
    while(u){
        matrix od=query(1,1,n,pos[top[u]],pos[ed[u]]);
        modify(1,1,n,pos[u]);//将会用val矩阵替换掉对应位置的矩阵
        matrix nw=query(1,1,n,pos[top[u]],pos[ed[u]]);
        u=fa[top[u]];
        val[pos[u]].g[0][0]+=max(nw.g[0][0],nw.g[1][0])-max(od.g[0][0],od.g[1][0]);
        val[pos[u]].g[0][1]=val[pos[u]].g[0][0];
        val[pos[u]].g[1][0]+=nw.g[0][0]-od.g[0][0];
    }
}

于是我们成功地AC了这道题(但愿这种题永远不要出出来。)

#include<cmath>
#include<queue>
#include<cstdio>
#include<cctype>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N=1e5+5;
inline int read(){
    int X=0,w=0;char ch=0;
    while(!isdigit(ch)){w|=ch=='-';ch=getchar();}
    while(isdigit(ch))X=(X<<3)+(X<<1)+(ch^48),ch=getchar();
    return w?-X:X;
}
struct matrix{
    ll g[2][2];
    matrix(){
        memset(g,0,sizeof(g));
    }
    matrix operator *(const matrix &b)const{
        matrix c;
        for(int i=0;i<2;i++)
            for(int j=0;j<2;j++)
                for(int k=0;k<2;k++)
                    c.g[i][j]=max(c.g[i][j],g[i][k]+b.g[k][j]);
        return c;
    }
}val[N],tr[N*4];
struct node{
    int to,nxt;
}e[N*2];
ll w[N],f[N][2];
int n,m,cnt,tot,head[N];
int dep[N],fa[N],size[N],son[N],top[N],pos[N],idx[N],ed[N];
inline void add(int u,int v){
    e[++cnt].to=v;e[cnt].nxt=head[u];head[u]=cnt;
}
void dfs1(int u){
    int sum=0;size[u]=1;
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==fa[u])continue;
        fa[v]=u;dfs1(v);
        size[u]+=size[v];
        if(!son[u]||size[son[u]]<size[v])son[u]=v;
        f[u][0]+=max(f[v][0],f[v][1]);
        sum+=f[v][0];
    }
    f[u][1]=sum+w[u];
}
void dfs2(int u,int anc){
    pos[u]=++tot;idx[tot]=u;top[u]=anc;
    if(!son[u]){ed[u]=u;return;}
    dfs2(son[u],anc);ed[u]=ed[son[u]];
    for(int i=head[u];i;i=e[i].nxt){
        int v=e[i].to;
        if(v==fa[u]||v==son[u])continue;
        dfs2(v,v);
    }
}
void init(){
    dep[1]=1;
    dfs1(1);
    dfs2(1,1);
}
void build(int a,int l,int r){
    if(l==r){
        int u=idx[l];
        ll g0=0,g1=w[u];
        for(int i=head[u];i;i=e[i].nxt){
            int v=e[i].to;
            if(v==fa[u]||v==son[u])continue;
            g0+=max(f[v][0],f[v][1]);g1+=f[v][0];
        }
        tr[a].g[0][0]=tr[a].g[0][1]=g0;
        tr[a].g[1][0]=g1;
        val[l]=tr[a];
        return;
    }
    int mid=(l+r)>>1;
    build(a<<1,l,mid);build(a<<1|1,mid+1,r);
    tr[a]=tr[a<<1]*tr[a<<1|1];
}
matrix query(int a,int l,int r,int l1,int r1){
    if(l1<=l&&r<=r1)return tr[a];
    int mid=(l+r)>>1;
    if(r1<=mid)return query(a<<1,l,mid,l1,r1);
    if(l1>mid)return query(a<<1|1,mid+1,r,l1,r1);
    return query(a<<1,l,mid,l1,mid)*query(a<<1|1,mid+1,r,mid+1,r1);
}
void modify(int a,int l,int r,int k){
    if(l==r){
        tr[a]=val[l];
        return;
    }
    int mid=(l+r)>>1;
    if(k<=mid)modify(a<<1,l,mid,k);
    else modify(a<<1|1,mid+1,r,k);
    tr[a]=tr[a<<1]*tr[a<<1|1];
}
void path_modify(int u,int c){
    val[pos[u]].g[1][0]+=c-w[u];w[u]=c;
    while(u){
        matrix od=query(1,1,n,pos[top[u]],pos[ed[u]]);
        modify(1,1,n,pos[u]);
        matrix nw=query(1,1,n,pos[top[u]],pos[ed[u]]);
        u=fa[top[u]];
        val[pos[u]].g[0][0]+=max(nw.g[0][0],nw.g[1][0])-max(od.g[0][0],od.g[1][0]);
        val[pos[u]].g[0][1]=val[pos[u]].g[0][0];
        val[pos[u]].g[1][0]+=nw.g[0][0]-od.g[0][0];
    }
}
int main(){
    n=read(),m=read();
    for(int i=1;i<=n;i++)w[i]=read();
    for(int i=1;i<n;i++){
        int u=read(),v=read();
        add(u,v);add(v,u);
    }
    init();
    build(1,1,n);
    for(int i=1;i<=m;i++){
        int u=read(),x=read();
        path_modify(u,x);
        matrix ans=query(1,1,n,pos[top[1]],pos[ed[1]]);
        printf("%lld\n",max(ans.g[0][0],ans.g[1][0]));
    }
    return 0;
}

 

posted @ 2018-05-31 21:57  luyouqi233  阅读(319)  评论(0编辑  收藏  举报