[学习笔记]整体DP

问题:

有一些问题,通常见于二维的DP,另一维记录当前x的信息,但是这一维过大无法开下,O(nm)也无法通过。

但是如果发现,对于x,在第二维的一些区间内,取值都是相同的,并且这样的区间是有限个,就可以批量处理

 

思想:

通过动态开点线段树维护第二维,

如果某个节点没有儿子,那么这个节点区间都是同一个权值。

也即,一个节点是空节点,那么这个节点所有的值和父亲的值都一致。(其实它的兄弟也是空节点的)

对于序列的问题,

可以直接扫过去,修改某些位置的点。

或者线段树合并。

对于树上的问题,

线段树合并。

 

实现:

主要考虑什么时候线段树合并停止。以及pushdown的标记问题。

当x都没有儿子或者y都没有儿子时候,整个x的区间或整个y的区间都是同一个值,可以直接计算贡献转移过来(这个必须支持,否则不能整体DP)。

否则,pushdown,进行递归

pushdown时候建立新的儿子(如果之前没有)。

空间复杂度和时间复杂度基本一致。O(nlogn)

 

只要满足,在x都没有儿子或者y都没有儿子时候,可以快速合并然后return,那么就可以整体DP了。

 

例题1:[九省联考2018]秘密袭击coat

例题2:

 

$dp[x][c]=\Pi (sumy-dp[y][c])$sumy表示y的所有dp[y][*]的和

在x都没有儿子或者y都没有儿子时候,我们要么知道每个x的值,要么知道每个y的值。

在x都没有儿子时候,把y的节点内每个数乘-1再加sumy,再乘上x区间的值。

y都没有儿子时候,直接用(sumy-val)乘给x即可。

code:

#include<bits/stdc++.h>
#define reg register int
#define il inline
#define fi first
#define se second
#define mk(a,b) make_pair(a,b)
#define numb (ch^'0')
#define pb push_back
#define solid const auto &
#define enter cout<<endl
#define pii pair<int,int>
using namespace std;
typedef long long ll;
template<class T>il void rd(T &x){
    char ch;x=0;bool fl=false;while(!isdigit(ch=getchar()))(ch=='-')&&(fl=true);
    for(x=numb;isdigit(ch=getchar());x=x*10+numb);(fl==true)&&(x=-x);}
template<class T>il void output(T x){if(x/10)output(x/10);putchar(x%10+'0');}
template<class T>il void ot(T x){if(x<0) putchar('-'),x=-x;output(x);putchar(' ');}
template<class T>il void prt(T a[],int st,int nd){for(reg i=st;i<=nd;++i) ot(a[i]);putchar('\n');}
namespace Modulo{
const int mod=998244353;
int ad(int x,int y){return (x+y)>=mod?x+y-mod:x+y;}
void inc(int &x,int y){x=ad(x,y);}
int mul(int x,int y){return (ll)x*y%mod;}
void inc2(int &x,int y){x=mul(x,y);}
int qm(int x,int y=mod-2){int ret=1;while(y){if(y&1) ret=mul(x,ret);x=mul(x,x);y>>=1;}return ret;}
}
using namespace Modulo;
namespace Miracle{
const int N=2e5+5;
int n,m,k;
struct node{
    int nxt,to;
}e[2*N];
int hd[N],cnt;
void add(int x,int y){
    e[++cnt].nxt=hd[x];
    e[cnt].to=y;
    hd[x]=cnt;
}
#define mid ((l+r)>>1)
struct tr{
    int sum,mul,ad;
    int ls,rs,val;
    void op(){
        cout<<"SUM "<<sum<<" MUL "<<mul<<" AD "<<ad<<endl;
    }
}t[20000000+3];
int tot,S;
vector<int>no[N];
int rt[N];
int nc(){
    ++tot;
    t[tot].sum=0;t[tot].mul=1;t[tot].ad=0;
    t[tot].ls=t[tot].rs=0;t[tot].val=0;
    return tot;
}
void tag(int x,int l,int r,int ml,int aa){
    // cout<<" tag "<<x<<" l "<<l<<" r "<<r<<" ml "<<ml<<" ad "<<aa<<endl;
    // t[x].op();
    t[x].sum=mul(t[x].sum,ml);
    t[x].sum=ad(t[x].sum,mul(r-l+1,aa));
    t[x].val=ad(mul(t[x].val,ml),aa);
    t[x].mul=mul(t[x].mul,ml);
    t[x].ad=ad(mul(t[x].ad,ml),aa);
}
void pushup(int x){
    t[x].sum=ad(t[t[x].ls].sum,t[t[x].rs].sum);
}
void pushdown(int x,int l,int r){
    if(!t[x].ls) t[x].ls=nc();
    if(!t[x].rs) t[x].rs=nc();
    tag(t[x].ls,l,mid,t[x].mul,t[x].ad);
    tag(t[x].rs,mid+1,r,t[x].mul,t[x].ad);
    t[x].mul=1;t[x].ad=0;
}
void upda(int &x,int l,int r,int p){
    // cout<<" pp "<<p<<" x "<<x<<" l "<<l<<" r "<<r<<" sm "<<t[x].sum<<" mul "<<t[x].mul<<" ad "<<t[x].ad<<endl;
    // cout<<" ls "<<t[x].ls<<" rs "<<t[x].rs<<endl; 
    if(!x) x=nc();
    if(l==r){
        // cout<<" ss "<<t[x].sum<<endl;
        t[x].sum=0;
        t[x].val=0;
        return;
    }
    pushdown(x,l,r);
    if(p<=mid) upda(t[x].ls,l,mid,p);
    else upda(t[x].rs,mid+1,r,p);
    pushup(x);
}
int merge(int x,int y,int l,int r){
    if(!t[x].ls&&!t[x].rs){
        swap(x,y);
        int v=t[y].val;
        tag(x,l,r,mod-1,S);
        tag(x,l,r,v,0);
    }else if(!t[y].ls&&!t[y].rs){
        int v=t[y].val;
        tag(x,l,r,ad(S,mod-v),0);
    }else{
        pushdown(x,l,r);pushdown(y,l,r);
        t[x].ls=merge(t[x].ls,t[y].ls,l,mid);
        t[x].rs=merge(t[x].rs,t[y].rs,mid+1,r);
        pushup(x);
    }
    return x;//warining!!
}
void dfs(int x,int fa){
    rt[x]=nc();
    tag(rt[x],1,m,1,1);
    for(reg i=hd[x];i;i=e[i].nxt){
        int y=e[i].to;
        if(y==fa) continue;
        dfs(y,x);
        S=t[rt[y]].sum;
        rt[x]=merge(rt[x],rt[y],1,m);
        // cout<<y<<" back "<<x<<" : "<<" sum "<<t[rt[x]].sum<<endl;
    }
    for(solid c:no[x]){
        upda(rt[x],1,m,c);
    }
    // cout<<x<<" : "<<" sum "<<t[rt[x]].sum<<endl;
}
int main(){
    rd(n);rd(m);rd(k);
    int x,y;
    for(reg i=1;i<n;++i){
        rd(x);rd(y);
        add(x,y);add(y,x);
    }
    for(reg i=1;i<=k;++i){
        rd(x);rd(y);
        no[x].push_back(y);
    }
    dfs(1,0);
    printf("%d",t[rt[1]].sum);
    return 0;
}

}
signed main(){
    Miracle::main();
    return 0;
}

/*
   Author: *Miracle*
*/
View Code

 

posted @ 2019-05-29 11:30  *Miracle*  阅读(1386)  评论(0编辑  收藏  举报