树上差分

 

应该算是树上统计问题的基础知识了

参考了这篇:Icy_Knight - 关于差分,树上差分的浅谈

 


 

~ 一维差分 ~

 

一维差分的最直观应用,就是维护 区间加减、单点查询 的树状数组

 

用$a_1,a_2,...,a_n$来保存区间修改对每一位置的贡献

一开始,$\forall a_i=0$

若对区间$[i,j]$都加上$x$,那么这个数组就变为$a_1=0,...,a_{i-1}=0,a_i=x,...,a_j=x,a_{j+1}=0,...a_n=0$

我们可以对此时的$a$数组逐项作差,即令$b_i=a_i-a_{i-1}$

那么此时$b$数组是这样的情况:$b_1=0,...,b_{i-1}=0,b_i=x,b_{i+1}=0,...,b_j=0,b_{j+1}=-x,b_{j+2}=0,...,a_n=0$

即,一次$a$数组中的$[i,j]$区间加减,只对应着$b$数组中的两个单点修改:$b_i+=x,b_{j+1}-=x$

 

而由于$b$是对$a$的逐项作差,所以有$\sum_{i=1}^x b_i=\sum_{i=1}^x (a_i-a_{i-1})=a_x-a_0=a_x$

即,对$b$数组做前缀和,就可以得到$a$数组

由于是从$b_1$开始的前缀,用树状数组维护即可

 

其实一维差分博大精深,还有各种插值方法;不过对于树上差分来说,只会用到最基本的概念

 


 

~ 树上差分 ~ 

 

类似一维数组上的区间修改,差分也可以方便的处理树上路径的区间修改

这里所说的路径,指的是树上两点$i,j$间的唯一路径

有两种可能的区间修改:对于路径上的所有点、对于路径上的所有边

 

先说修改路径上点的情况吧

类似的定义两个数组:记$a_i$表示,所有区间修改对于$i$点的总贡献;记$b_i$表示,对$a_i$树上差分的数组

一提到树上两点的路径,第一反应就应该是LCA;于是很自然地将$i,j$间路径拆成$i-LCA(i,j),j-LCA(i,j)$这两段

先抛结论:对于 将$i,j$间路径上所有点加上$x$ 的操作,可以转化为$b_i\text{+=}x,b_j\text{+=}x,b_{fa[LCA(i,j)]}\text{-=}2x$

 

那么就涉及到如何对$b_i$“作前缀和”来得到$a_i$了

树上的“前缀和”其实和序列中的有些差距:序列中是从前向后求和,而树上是对子树求和

即,$a_i=\sum_\text{j在i的子树中} b_j$

不过并不需要真的在dfs序中求和啦,可以通过$a_i=b_i+\sum_\text{j为i的儿子} a_j$轻松得到

 

现在回到上面的结论

对$b_i$加上$x$,在还原回$a$时相当于对$i$的祖先都加上$x$——因为$i$都在以它们为根的子树中

于是,$b_i\text{+=}x,b_{fa[LCA(i,j)]}\text{-=}x$对应着路径$i-LCA(i,j)$上的区间修改;$b_j\text{+=}x,b_{fa[LCA(i,j)]}\text{-=}x$对应着路径$i-LCA(i,j)$上的区间修改

 

对于修改路径上边的情况

可以将边$u-v$获得的贡献放到$u,v$中更深的点上,于是整体就和路径上点的情况很像了

对于 将$i,j$间路径上所有边加上$x$ 的操作,可以转化为$b_i\text{+=}x,b_j\text{+=}x,b_{LCA(i,j)}\text{-=}2x$

这里是$b_{LCA(i,j)}$而不是$b_{fa[LCA(i,j)]}$,是因为在修改路径上点的情况中,$LCA(i,j)$在路径上、会获得贡献;而在修改路径上边的情况中,$LCA(i,j)$对应的是$LCA(i,j)$向父亲走的那条边,并不在$i,j$间的路径上,所以不会获得贡献

 

比较基础的应用是,给出一棵$n$个点的树和$m$个点对;询问$m$条点对间的路径一共经过每个点/每条边多少次

对于每个点对$(x,y)$,就相当于将$x,y$间路径上的点/边全部加$1$

按照上面的方法差分出$b$数组后,再通过一个dfs就能还原出$a$数组,即答案

例题:Luogu P2680 (运输计划,$NOIP2015$)

考虑二分答案

对于每一个二分中点$mid$,可以$O(n)$地差分一次来检验是否可行

首先可以在二分之前$O(m)$地预处理出来每个运输计划的用时

那么假设这$m$个计划中有$x$个用时大于$mid$、$m-x$个用时小于$mid$

显然只需要考虑用时大于$mid$的$x$个计划,设其中用时最多的为$maxlen$;因为只能将一条边的权值改成$0$,所以相当于要找到一条被这$x$个计划经过、且边权超过$maxlen-mid$的边

将$x$个计划的端点差分进去,dfs一遍就能得到每条边被经过的次数;然后看被经过$x$次的边权是否满足要求即可

可以不用二分答案,而是二分被修改边的边权,这样能砍掉$2/3$的常数

// luogu-judger-enable-o2
#include <cstdio>
#include <locale>
#include <cstring>
#include <algorithm>
using namespace std;

inline void read(int &x)
{
    char ch=getchar();
    while(!isdigit(ch))
        ch=getchar();
    x=0;
    while(isdigit(ch))
    {
        x=x*10+ch-'0';
        ch=getchar();
    }
}

typedef long long ll;
const int N=300005;
const int LOG=20;

struct Edge
{
    int to,cost,nxt;
    Edge(int a=0,int b=0,int c=0)
    {
        to=a,cost=b,nxt=c;
    }
};

int n,m;

int tot;
int v[N];
Edge e[N<<1];

inline void AddEdge(int x,int y,int w)
{
    e[++tot]=Edge(y,w,v[x]);
    v[x]=tot;
}

int dep[N],dist[N];
int id;
int st[N],ed[N];
int rmq[N<<1][LOG];

inline void dfs(int x,int f)
{
    dep[x]=dep[f]+1;
    rmq[++id][0]=x;
    st[x]=id;
    for(int i=v[x];i;i=e[i].nxt)
    {
        int to=e[i].to;
        if(to==f)
            continue;
        
        dist[to]=dist[x]+e[i].cost;
        dfs(to,x);
        rmq[++id][0]=x;
    }
    ed[x]=id;
}

int log[N<<1];

inline int cmp(int x,int y)
{
    return (dep[x]<dep[y]?x:y);
}

void ST()
{
    log[0]=-1;
    for(int i=1;i<=id;i++)
        log[i]=log[i>>1]+1;
    
    for(int i=0,t=1;i<LOG-1;i++,t<<=1)
        for(int j=1;j<=id;j++)
        {
            int l=rmq[j][i],r=(j+t>id?rmq[j][i]:rmq[j+t][i]);
            rmq[j][i+1]=cmp(l,r);
        }
}

inline int LCA(int x,int y)
{
    int lb=min(st[x],st[y]),rb=max(ed[x],ed[y]);
    int k=log[rb-lb+1];
    return cmp(rmq[lb][k],rmq[rb-(1<<k)+1][k]);
}

inline int Dist(int x,int y)
{
    return dist[x]+dist[y]-dist[LCA(x,y)]*2;
}

int cnt[N];
int D,maxlen;
int coincide;

inline void Calc(int x,int f)
{
    for(int i=v[x];i;i=e[i].nxt)
    {
        int to=e[i].to;
        if(to==f)
            continue;
        
        Calc(to,x);
        cnt[x]+=cnt[to];
        if(e[i].cost>=maxlen-D)
            coincide=max(coincide,cnt[to]);
    }
}

int xx[N],yy[N];

bool Check()
{
    int exceed=0;
    memset(cnt,0,sizeof(cnt));
    for(int i=1;i<=m;i++)
        if(Dist(xx[i],yy[i])>D)
        {
            cnt[xx[i]]++;
            cnt[yy[i]]++;
            cnt[LCA(xx[i],yy[i])]-=2;
            exceed++;
        }
    
    coincide=0;
    Calc(1,0);
    return coincide==exceed;
}

int main()
{
    read(n),read(m);
    for(int i=1;i<n;i++)
    {
        int x,y,w;
        read(x),read(y),read(w);
        AddEdge(x,y,w);
        AddEdge(y,x,w);
    }
    
    dfs(1,0);
    ST();
    
    for(int i=1;i<=m;i++)
    {
        read(xx[i]),read(yy[i]);
        maxlen=max(maxlen,Dist(xx[i],yy[i]));
    }
    
    int l=0,r=1000,mid;
    while(l<r)
    {
        mid=(l+r)>>1;
        D=maxlen-mid;
        if(Check())
            l=mid+1;
        else
            r=mid;
    }
    D=maxlen-l;
    if(!Check())
        l--;
    printf("%d\n",maxlen-l);
    return 0;
}
View Code

 

训练赛被队友切掉的一题:Gym 102012G ($Rikka\ with\ Intersections\ of\ Paths$,$2018\ ICPC$徐州)

困难的地方在于,如何不重不漏地统计每一种方案

有一个不难证明、但是有些隐蔽的性质能够帮助我们:若几条路径至少交于一点,那么都相交的部分中有且最多只有一个点为某路径的LCA(且该点为相交部分中深度最浅的点)

如果我们对于每个点为这个LCA的情况分别统计,就不会重复了

若当前点被$n$条路径经过,且是$m$条路径的LCA

那么会对这个点会产生贡献的情况是,被选的$k$条路径中 至少有一条的LCA为当前点,其个数为$C^k_n-C^k_{n-m}$

计算当前点被多少条路径经过,就是上面说过的基本树上差分问题了

#include <map>
#include <locale>
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;
 
inline void read(int &x)
{
    x=0;
    char ch=getchar();
    while(!isdigit(ch))
        ch=getchar();
    while(isdigit(ch))
    {
        x=x*10+ch-'0';
        ch=getchar();
    }
}
 
typedef long long ll;
const int N=300005;
const int LOG=20;
const int MOD=1000000007;
 
int n,m,k;
vector<int> v[N];
 
ll C[N];
 
int dep[N];
int to[N][LOG];
 
void dfs(int x,int f)
{
    dep[x]=dep[f]+1;
    to[x][0]=f;
    
    for(int i=0;i<v[x].size();i++)
    {
        int nxt=v[x][i];
        if(nxt!=f)
            dfs(nxt,x);
    }
}
 
inline int LCA(int x,int y)
{
    if(dep[x]<dep[y])
        swap(x,y);
    
    for(int i=LOG-1;i>=0;i--)
        if(dep[to[x][i]]>=dep[y])
            x=to[x][i];
    for(int i=LOG-1;i>=0;i--)
        if(to[x][i]!=to[y][i])
            x=to[x][i],y=to[y][i];
    return (x!=y?to[x][0]:x);
}
 
ll ans;
ll sum[N];
ll cnt[N];
 
void Solve(int x,int f)
{
    for(int i=0;i<v[x].size();i++)
    {
        int nxt=v[x][i];
        if(nxt==f)
            continue;
        
        Solve(nxt,x);
        sum[x]+=sum[nxt];
    }
    ans=(ll(ans)+C[sum[x]]-C[sum[x]-cnt[x]]+MOD)%MOD;
}
 
inline ll rev(ll x)
{
    ll res=1;
    int k=MOD-2;
    while(k)
    {
        if(k&1)
            res=ll(res)*x%MOD;
        x=ll(x)*x%MOD;
        k>>=1;
    }
    return res;
}
 
void Init()
{
    ans=0;
    for(int i=1;i<=n;i++)
    {
        v[i].clear();
        sum[i]=cnt[i]=0;
    }
    
    for(int i=0;i<k;i++)
        C[i]=0;
    C[k]=1;
    for(int i=k+1;i<=m;i++)
        C[i]=ll(C[i-1])*i%MOD*rev(i-k)%MOD;
}
 
int main()
{
    int T;
    read(T);
    while(T--)
    {
        read(n),read(m),read(k);
        Init();
        
        for(int i=1;i<n;i++)
        {
            int x,y;
            read(x),read(y);
            v[x].push_back(y);
            v[y].push_back(x);
        }
        
        dfs(1,0);
        for(int i=1;i<LOG;i++)
            for(int j=1;j<=n;j++)
                to[j][i]=to[to[j][i-1]][i-1];
        
        for(int i=1;i<=m;i++)
        {
            int x,y,lca;
            read(x),read(y);
            lca=LCA(x,y);
            
            sum[x]++,sum[y]++;
            sum[lca]--,sum[to[lca][0]]--;
            
            cnt[lca]++;
        }
        
        Solve(1,0);
        printf("%d\n",ans);
    }
    return 0;
}
View Code

 

再来一道$NOIP$的:Luogu P1600 (天天爱跑步,$NOIP2016$)

当年不会做的确情有可原...

这题没办法按路径统计,因为实在没有办法对一条路径上能正好遇到观察者的点同时给出贡献

所以考虑按观察者统计,即对于每一个观察者,在多少路径上能够正好遇上他

考虑“路径”其实不太方便,但是通过LCA,可以将$x,y$间的路径拆成两条有向链,即$x\rightarrow LCA(x,y),LCA(x,y)\rightarrow y$

这样的好处是,可以直接根据端点判断该有向链能否遇上观察者

记$dep$数组表示各点在树上的深度,令$i=LCA(x,y)$

那么,$x\rightarrow LCA(x,y)$链(即向上走的链)能够遇上$i$处观察者的条件是$dep[x]-dep[i]=w[i]$,移项后为$dep[i]+w[i]=dep[x]$

此时,等式右端的$dep[x]$是与$i$无关的!于是只要在$i$为根的子树中,存在一个起点为$j$、经过$i$的有向链,且$dep[j]=dep[i]+w[i]$,那么就对$i$点的答案有$1$的贡献

判断对答案有贡献链的链是否存在,可以用上差分的思想,即仅在起点、终点打上标记

开一个$cnt$数组,$cnt[dep[j]]$表示子树中起点为$j$、经过$i$的链有多少条

在dfs中,若遇上某点$j$为$dlt$条链的起点,则把$cnt[dep[j]]$加上$dlt$;若遇到某点$j$为$dlt$条链的终点,就对于每条链的起点$k$,将$cnt[dep[k]]$减$1$

实际代码中,是对于每个点开两个vector<int>,叫做$add[i],del[i]$,向其中塞的是 以$i$为起点/终点的链 的起点的深度(即移项后的等号右边),这样一来打标记就很容易实现了

类似的,$LCA(x,y)\rightarrow y$链(即向下走的链)能够遇上$i$处观察者的条件是$Dist(x,y)-(dep[y]-dep[i])=w[i]$($Dist(x,y)$表示$x,y$间路径的长度),移项后为$dep[i]-w[i]=dep[y]-Dist(x,y)$

差分的操作也是类似的,只不过向vector<int>中塞的数值变成上式等号右边的东西

之后就是利用树上启发式合并,暴力地获得对于每个点的子树所对应的$cnt$数组;对向上/向下链分开统计即可

由于在树上启发式合并中,每个点最多被访问$logn$次,所以一共只会打$n\cdot logn$级别的差分标记,并不会存在问题

总时间复杂度$O(n\cdot logn)$,建议用ST表预处理LCA来降常数(其实倍增就够了,ST表常数巨大)

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;

typedef pair<int,int> pii;
const int N=300005;
const int LOG=20;

int n,m;
int w[N];
vector<int> v[N];

int id;
int dep[N],sz[N],son[N];
int st[N],ed[N];
int rmq[N<<1][LOG];

void dfs(int x,int f)
{
    dep[x]=dep[f]+1;
    rmq[++id][0]=x;
    st[x]=id;
    sz[x]=1;
    
    for(int i=0;i<v[x].size();i++)
    {
        int nxt=v[x][i];
        if(nxt==f)
            continue;
        
        dfs(nxt,x);
        
        rmq[++id][0]=x;
        sz[x]+=sz[nxt];
        if(!son[x] || sz[son[x]]<sz[nxt])
            son[x]=nxt;
    }
    ed[x]=id;
}

int log[N<<1];

inline int cmp(int x,int y)
{
    return (dep[x]<dep[y]?x:y);
}

void ST()
{
    log[0]=-1;
    for(int i=1;i<=id;i++)
        log[i]=log[i>>1]+1;
    
    for(int i=0,t=1;i<LOG-1;i++,t<<=1)
        for(int j=1;j<=id;j++)
        {
            int l=rmq[j][i],r=(j+t>id?rmq[j][i]:rmq[j+t][i]);
            rmq[j][i+1]=cmp(l,r);
        }
}

inline int LCA(int x,int y)
{
    int lb=min(st[x],st[y]),rb=max(ed[x],ed[y]);
    int k=log[rb-lb+1];
    return cmp(rmq[lb][k],rmq[rb-(1<<k)+1][k]);
}

inline int Dist(int x,int y)
{
    return dep[x]+dep[y]-dep[LCA(x,y)]*2;
}

int ans[N];
int cnt[N*3];
vector<int> add[2][N];
vector<int> del[2][N];

void Add(int p,int x,int f,int dlt)
{
    for(int i=0;i<add[p][x].size();i++)
        cnt[add[p][x][i]]+=dlt;
    for(int i=0;i<del[p][x].size();i++)
        cnt[del[p][x][i]]-=dlt;
    
    for(int i=0;i<v[x].size();i++)
    {
        int nxt=v[x][i];
        if(nxt==f)
            continue;
        Add(p,nxt,x,dlt);
    }
}

void Solve(int p,int x,int f,int keep)
{
    for(int i=0;i<v[x].size();i++)
    {
        int nxt=v[x][i];
        if(nxt==f || nxt==son[x])
            continue;
        Solve(p,nxt,x,0);
    }
    
    if(son[x])
        Solve(p,son[x],x,1);
    
    for(int i=0;i<v[x].size();i++)
    {
        int nxt=v[x][i];
        if(nxt==f || nxt==son[x])
            continue;
        Add(p,nxt,x,1);
    }
    
    for(int i=0;i<add[p][x].size();i++)
        cnt[add[p][x][i]]++;
    if(!p)
        ans[x]+=cnt[dep[x]+w[x]];
    else
        ans[x]+=cnt[dep[x]-w[x]+N];
    for(int i=0;i<del[p][x].size();i++)
        cnt[del[p][x][i]]--;
    
    if(!keep)
        Add(p,x,f,-1);
}

int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        v[x].push_back(y);
        v[y].push_back(x);
    }
    
    dfs(1,0);
    ST();
    
    for(int i=1;i<=n;i++)
        scanf("%d",&w[i]);
    for(int i=1;i<=m;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        int lca=LCA(x,y),val0,val1;
        
        val0=dep[x];
        add[0][x].push_back(val0);
        del[0][lca].push_back(val0);
        
        val1=dep[y]-Dist(x,y)+N;
        add[1][y].push_back(val1);
        del[1][lca].push_back(val1);
        
        if(val0==dep[lca]+w[lca] && val1==dep[lca]-w[lca]+N)
            ans[lca]--;
    }
    
    Solve(0,1,0,0);
    Solve(1,1,0,0);
    
    for(int i=1;i<=n;i++)
        printf("%d ",ans[i]);
    return 0;
}
View Code

 


 

感觉天天爱跑步算是复杂一点的树上差分了,一般遇到的都是裸的

差不多是这样吧,之后做到的话就放上来

 

seuOJ 223 (小雅米的舒服了)

二分+树上差分还是挺有意思的,至少比树剖线段树优美很多

二分答案后,就可以用类似莫队的方法将$[1,mid]$的标签打到树上(由于只有右端点的改动,所以整体是$O(n\cdot logn)$的)

接着就是一遍dfs,看看是否存在不合法的点即可

(由于树上差分一般是离线做法,所以其实不需要用ST表求LCA,先离线用倍增预存就够了)

#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
using namespace std;

const int N=500005;
const int LOG=20;

int n;
vector<int> v[N];

int id;
int dep[N],fa[N];
int st[N],ed[N];
int rmq[N<<1][LOG];

void dfs(int x,int f)
{
    fa[x]=f;
    dep[x]=dep[f]+1;
    rmq[++id][0]=x;
    st[x]=id;
    
    for(int i=0;i<v[x].size();i++)
    {
        int nxt=v[x][i];
        if(nxt==f)
            continue;
        
        dfs(nxt,x);
        rmq[++id][0]=x;
    }
    ed[x]=id;
}

inline int cmp(int x,int y)
{
    return (dep[x]<dep[y]?x:y);
}

int log[N<<1];

void ST()
{
    log[0]=-1;
    for(int i=2;i<=id;i++)
        log[i]=log[i>>1]+1;
    
    for(int i=0,t=1;i<LOG-1;i++,t<<=1)
        for(int j=1;j<=id;j++)
        {
            int l=rmq[j][i],r=(j+t>id?rmq[j][i]:rmq[j+t][i]);
            rmq[j][i+1]=cmp(l,r);
        }
}

inline int LCA(int x,int y)
{
    int lb=min(st[x],st[y]),rb=max(ed[x],ed[y]);
    int k=log[rb-lb+1];
    return cmp(rmq[lb][k],rmq[rb-(1<<k)+1][k]);
}

int lim[N];
int to[N];

int tag[N],num[N],vis[N];

inline void Add(int x,int y)
{
    int lca=LCA(x,y);
    tag[x]++,tag[y]++;
    
    tag[lca]--;
    tag[fa[lca]]--;
}

inline void Del(int x,int y)
{
    tag[x]--,tag[y]--;
    
    int lca=LCA(x,y);
    tag[lca]++;
    tag[fa[lca]]++;
}

bool flag;

void Calc(int x,int f)
{
    num[x]=tag[x];
    
    for(int i=0;i<v[x].size();i++)
    {
        int nxt=v[x][i];
        if(nxt==f)
            continue;
        
        Calc(nxt,x);
        num[x]+=num[nxt];
    }
    
    if(num[x]-vis[x]>lim[x])
        flag=false;
}

int main()
{
    int n;
    scanf("%d",&n);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        v[x].push_back(y);
        v[y].push_back(x);
    }
    
    dfs(1,0);
    ST();
    
    for(int i=1;i<=n;i++)
        scanf("%d",&lim[i]);
    
    to[0]=1;
    for(int i=1;i<=n;i++)
        scanf("%d",&to[i]);
    
    int l=0,r=n,mid;
    int p=0;
    while(l<r)
    {
        mid=(l+r)>>1;
        
        while(p<mid)
        {
            Add(to[p],to[p+1]);
            if(p)
                vis[to[p]]++;
            p++;
        }
        while(p>mid)
        {
            Del(to[p-1],to[p]);
            if(p-1)
                vis[to[p-1]]--;
            p--;
        }
        
        flag=true;
        Calc(1,0);
        
        if(flag)
            l=mid+1;
        else
            r=mid;
    }
    
    while(p<l)
    {
        Add(to[p],to[p+1]);
        if(p)
            vis[to[p]]++;
        p++;
    }
    while(p>l)
    {
        Del(to[p-1],to[p]);
        if(p-1)
            vis[to[p-1]]--;
        p--;
    }
    
    flag=true;
    Calc(1,0);
        
    if(!flag)
        l--;
    
    printf("%d\n",l);
    return 0;
}
View Code

 

(待续)

posted @ 2019-09-10 14:35  LiuRunky  阅读(337)  评论(0编辑  收藏  举报