树上启发式合并

树上启发式合并,一种美妙的黑科技,可以用普通的优化让你$n^2$变成严格$n log$,解决一些类似(树上数颜色,树上查众数)这样的问题

首先你要知道暴力为什么是$n^2$的

以这个图为例

 

每次你从一个节点开始向下搜,你从1节点搜到3,搜完这个子树然后你需要把3存的col等信息删去再遍历另一个子树才是正确的

那么我们每次遍历这个节点一个子树,每次搜完这棵子树都要清空当前子树储存信息这样(最差)复杂度$n^2$

我们可以发现清空最后一个遍历的子树是没有意义的,那么我们人为把最后一个子树放到最后不就是最优的吗

所以,首先我们先找出来重链,轻链,对于轻链我们求出子树答案,再清除子树贡献,.然后求出重链上子树答案,不清除贡献.最后我们再算一遍子树对当前节点贡献即可

你可能会认为,这不就是一个简单的优化吗,怎么就是$n log$了

我不知道

它并没有优化最优复杂度而是避免了最差复杂度

以给一棵根为1的树,每次询问子树颜色种类数为例

代码大致如下

#include<bits/stdc++.h>
using namespace std;
#define ll int
#define r register 
#define A 1001010
ll head[A],nxt[A],ver[A],size[A],col[A],cnt[A],ans[A],son[A];
ll tot=0,num,sum,nowson,n,m,xx,yy;
inline void add(ll x,ll y){
    nxt[++tot]=head[x],head[x]=tot,ver[tot]=y;
}
inline ll read(){
    ll f=1,x=0;char c=getchar();
    while(!isdigit(c)){
        if(c=='-') f=-1;
        c=getchar();
    }
    while(isdigit(c))
        x=(x<<1)+(x<<3)+(c^48),c=getchar();
    return f*x;
}
void dfs(ll x,ll fa){
    size[x]=1;
    for(ll i=head[x];i;i=nxt[i]){
        ll y=ver[i];
        if(y==fa) continue;
        dfs(y,x);
        size[x]+=size[y];
        if(size[son[x]]<size[y])
            son[x]=y;
    }
}
void cal(ll x,ll fa,ll val){
    if(!cnt[col[x]]) ++sum;
    cnt[col[x]]+=val;
    for(ll i=head[x];i;i=nxt[i]){
        ll y=ver[i];
        if(y==fa||y==nowson) continue;
        cal(y,x,val); 
    }
}
void dsu(ll x,ll fa,bool op){
    for(ll i=head[x];i;i=nxt[i]){
        ll y=ver[i];
        if(y==fa||y==son[x])
            continue;
        dsu(y,x,0);
        //从轻儿子出发
    }
    if(son[x])
        dsu(son[x],x,1),nowson=son[x];
    cal(x,fa,1);nowson=0;
    ans[x]=sum;
    if(!op){
        cal(x,fa,-1);
        sum=0;
    }
}
int main(){
    n=read();
    for(ll i=1;i<=n-1;i++){
        xx=read(),yy=read();
        add(xx,yy),add(yy,xx);
    }
    for(ll i=1;i<=n;i++)
        col[i]=read();
    dfs(1,0);
    dsu(1,0,1);
    m=read();
    for(ll i=1;i<=m;i++){
        xx=read();
        printf("%d\n",ans[xx]);
    }
}

另一种打法

#include<iostream>
#include<cstdio>
#include<cstring>
#include<cmath>
using namespace std;
#define R register
#define ll long long
inline ll read(){
    ll aa=0;R int bb=1;char cc=getchar();
    while(cc<'0'||cc>'9')
        {if(cc=='-')bb=-1;cc=getchar();}
    while(cc>='0'&&cc<='9')
        {aa=(aa<<1)+(aa<<3)+(cc^48);cc=getchar();}
    return aa*bb;
}
const int N=1e5+3;
struct edge{
    int v,last;
}ed[N<<1];
int first[N],tot;
inline void add(int x,int y)
{
    ed[++tot].v=y;
    ed[tot].last=first[x];
    first[x]=tot;
}
int n,m,c[N],son[N],cnt[N],ans[N],siz[N];
void dfsi(int x,int fa)
{
    siz[x]=1;
    for(R int i=first[x],v;i;i=ed[i].last){
        v=ed[i].v;
        if(v==fa)continue;
        dfsi(v,x);
        siz[x]+=siz[v];
        if(siz[v]>siz[son[x]])son[x]=v;
    }
    return;
}
int dfsj(int x,int fa,int bs,int kep)
{
    if(kep){
        for(R int i=first[x],v;i;i=ed[i].last){
            v=ed[i].v;
            if(v!=fa&&v!=son[x])
                dfsj(v,x,0,1);
        }
    }
    int res=0;
    if(son[x])res+=dfsj(son[x],x,1,kep);
    for(R int i=first[x],v;i;i=ed[i].last){
        v=ed[i].v;
        if(v!=fa&&v!=son[x])
            res+=dfsj(v,x,0,0);
    }
    if(!cnt[c[x]])res++;
    cnt[c[x]]++;
    if(kep){
        ans[x]=res;
        if(!bs)memset(cnt,0,sizeof(cnt));
    }
    return res;
}
int main()
{
    n=read();
    for(R int i=1,x,y;i<n;++i){
        x=read();y=read();
        add(x,y);add(y,x);
    }
    for(R int i=1;i<=n;++i)c[i]=read();
    dfsi(1,0); dfsj(1,0,1,1);
    m=read();
    for(R int i=1,x;i<=m;++i){
        x=read();
        printf("%d\n",ans[x]);
    }
    return 0;
}

虽然好像没什么区别

 

然后再看一道例题

有一棵 n 个节点的以 1 号节点为根的树,每个节点上有一个小桶,节点u上的小桶可以容纳${k_u}$ 个小球,ljh每次可以给一个节点到根路径上的所有节点的小桶内放一个小球,如果这个节点的小桶满了则不能放进这个节点,最后多次询问某个节点值

首先暴力不能过

直接权值线段树+线段树合并很难维护,树链剖分也难以维护,但我们直接树上启发式合并+线段树暴力修改可以维护。

首先单纯线段树暴力修改可以维护,但这会超时。于是我们用启发式合并作为时间复杂度保证,莫名奇妙AC了这个题

#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define A 1001010
ll head[A],nxt[A],ver[A],size[A],son[A],tong[A],col[A],getfa[A],isbigson[A],ans[A],al[A];
vector<pair<ll,ll> >v[A];
map<ll,ll>mp;
ll n,m,tot=0,Q,wwb=0;
struct tree{
    ll l,r,f,x,t,c;
}tr[A];
void add(ll x,ll y){
    nxt[++tot]=head[x],head[x]=tot,ver[tot]=y;
}
void prdfs(ll x,ll fa){
    size[x]=v[x].size()+1;
    for(ll i=head[x];i;i=nxt[i]){
        ll y=ver[i];
        if(y==fa) continue;
        prdfs(y,x);
        size[x]+=size[y];
        if(size[son[x]]<size[y])
            isbigson[son[x]]=0,son[x]=y,isbigson[y]=1;
    }
}
void built(ll p,ll l,ll r){
    tr[p].l=l,tr[p].r=r;
    if(tr[p].l==tr[p].r){
        return ;
    }
    ll mid=(l+r)>>1;
    built(p<<1,l,mid);
    built(p<<1|1,mid+1,r);
}
ll ask(ll p,ll pos){
    if(pos>=tr[p].t) return tr[p].c;
    return (pos>=tr[p<<1].t?tr[p<<1].c+ask(p<<1|1,pos-tr[p<<1].t):ask(p<<1,pos));
}
void insert(ll p,ll pos,ll t,ll c){
    if(tr[p].l==tr[p].r)
        {tr[p].t+=t;tr[p].c+=c;return;}
    if(pos<=tr[p<<1].r)
        insert(p<<1,pos,t,c);
    else 
        insert(p<<1|1,pos,t,c);
    tr[p].t=tr[p<<1].t+tr[p<<1|1].t;
    tr[p].c=tr[p<<1].c+tr[p<<1|1].c;
}
void up(ll x,ll fa){
    if(v[getfa[x]].size()<v[getfa[fa]].size()){
        for(ll i=0;i<v[getfa[x]].size();i++)
            v[getfa[fa]].push_back(v[getfa[x]][i]);
        v[getfa[x]].clear();
        getfa[x]=getfa[fa];
    }
    else{
        for(ll i=0;i<v[getfa[fa]].size();i++)
            v[getfa[x]].push_back(v[getfa[fa]][i]);
        v[getfa[fa]].clear();
        getfa[fa]=getfa[x];
    }
}
void dfs(ll x,ll fa){

    for(ll i=head[x];i;i=nxt[i]){
        ll y=ver[i];
        if(y==fa||y==son[x])    continue;
        dfs(y,x);
    }
    if(son[x]) dfs(son[x],x);
    for(ll i=0;i<v[getfa[x]].size();i++){
        ll tim=v[getfa[x]][i].first,col=v[getfa[x]][i].second;
        if(!al[col])    al[col]=tim,insert(1,tim,1,1);
        else if(al[col]>tim){
            insert(1,al[col],0,-1);
            insert(1,tim,1,1);
            al[col]=tim;
        }
        else insert(1,tim,1,0);
    }
//    printf("t=%lld tong=%lld\n",tr[1].t,tong[x]);
    ans[x]=ask(1,min(tr[1].t,tong[x]));
    if(son[x])
        up(son[x],x);
    if(!isbigson[x]){
        for(ll i=0;i<v[getfa[x]].size();i++){
            ll tim=v[getfa[x]][i].first,col=v[getfa[x]][i].second;
            if(al[col]==tim)
                insert(1,tim,-1,-1),al[col]=0;
            else 
                insert(1,tim,-1,0);
        }
        up(x,fa);
    }    
/*    for(ll i=1;i<=5;i++){
        printf("ans=%lld ",ans[i]);
    }
*//*    cout<<endl;*/
}
int main(){
    scanf("%lld",&n);
    for(ll i=1;i<n;i++){
        ll xx,yy;
        scanf("%lld%lld",&xx,&yy);
        add(xx,yy),add(yy,xx);
    }
    for(ll i=1;i<=n;i++){
        scanf("%lld",&tong[i]);
        getfa[i]=i;
    }
    prdfs(1,0);
    scanf("%lld",&m);built(1,1,m);
    for(ll i=1,x,c;i<=m;i++){
        scanf("%lld%lld",&x,&c);
        if(!mp[c])
            mp[c]=++wwb;
        //离散化
        v[x].push_back(make_pair(i,mp[c]));
    }
    dfs(1,0);
    scanf("%lld",&Q);
    for(ll i=1,x;i<=Q;i++){
        scanf("%lld",&x);
        printf("%lld\n",ans[x]);
    }
}

 

posted @ 2019-07-30 21:44  znsbc  阅读(532)  评论(0编辑  收藏  举报