树链剖分学习笔记

因为被 N总 D了好久,所以痛下决心想要学习一下树链剖分。

强烈推荐 N总对树链剖分的优质讲解博客(感觉比我写得好多了。。。。)。

树链剖分

树链剖分可以通过给树上的点重新编号(重新编号后就变成若干条链),使得可以将树中的任意一条路径转化成 \(O(\log n)\) 段连续的区间。那么树中路径上的所有操作都可以转化为 \(O(\log n)\) 段连续区间中的操作,也就可以用线段树来维护(也可以用其他支持区间维护的数据结构维护)。

如何将树上的路径变成区间?可以用 DFS 序来实现。

一些定义

重儿子:对于每一个非叶子节点,以它的所有儿子为根的子树中,大小最大的那个子树的根节点被称为该节点的重儿子(如果有多个子树大小最大的儿子,任选一个作为重儿子即可)。

轻儿子:对于每一个非叶子节点,除了重儿子以外的所有儿子被称为轻儿子。

重边:重儿子和它父亲之间的连边被称为重边。

轻边:轻儿子和它父亲之间的连边被称为轻边。

重链:由重边构成的路径被称为重链。(每个点都属于且仅属于一条重链,对于一些叶子节点,它们自己本身单独是一条重链)

如何判断每个点属于哪一条重链?对于每一个重儿子,它所在的重链就是它父亲所在的重链;对于每一个轻儿子,它所在的重链就是以它开头的往下走的重链

在 DFS 求 DFS 序时,优先遍历重儿子。这样可以保证每一条重链上所有点的编号都是连续的。

如下图所示:

定理

树中的任意一条路径均可拆分成 \(O(\log n)\) 个连续的区间,即 \(O(\log n)\) 条重链(不一定完整)。

树链剖分的基础操作

通常情况下,可以通过两遍 DFS 求出所有的信息。

在第一次 DFS 时,求出所有的重儿子。

在第二次 DFS 时,求出 DFS 序,以及每个点所在重链的顶点

将树上的路径转化为区间

有点类似于倍增求 LCA 的算法。如下图所示,假设要将 \(x,y\) 之间的路径转化为区间。记 \(x\) 所在重链的顶端为 \(top[x]\)\(y\) 所在重链的顶端为 \(top[y]\)

比较 \(depth[top[x]]\)\(depth[top[y]]\)。将深度较大的那个点跳到链顶点的父节点上。同时记录下这一段区间。

重复上述操作,直到 \(top[x]=top[y]\)。此时再比较一下两个节点的深度大小即可算出最后一段区间(这就体现了为什么要优先遍历重儿子)。

具体过程如下图所示:

而通常情况下题目就是要求对这些区间进行一些修改,用线段树维护的时间复杂度就是 \(O(n \log^2 n)\)。但是常数比较小。

模板题

给定一棵包含 \(N\) 个结点的树,每个节点上包含一个数值,需要支持以下操作:

\(1\) \(x\) \(y\) \(z\),表示将树从 \(x\)\(y\) 结点最短路径上所有节点的值都加上 \(z\)

\(2\) \(x\) \(y\),表示求树从 \(x\)\(y\) 结点最短路径上所有节点的值之和。

\(3\) \(x\) \(z\),表示将以 xx 为根节点的子树内所有节点值都加上 \(z\)

\(4\) \(x\) 表示求以 \(x\) 为根节点的子树内所有节点值之和

数据范围

\(1 \leq N \leq 10^5,1 \leq M \leq 10^5\)

思路

区间修改,区间查询,显然本题用线段树维护信息是最好的。

对于 \(1,2\) 两种操作来说,上面都已经介绍过了。而对于一棵子树内的操作,注意到一棵子树内的 DFS 序一定是连续的,那么这棵子树在序列上就是 \([dfn[x],dfn[x]+size[x]-1]\) 这一段连续的区间。直接区间修改即可。

code:

#include<cstdio>
using namespace std;
const int N=1e5+10;
#define LL long long
LL w[N],a[N];//a是树上的点转化到序列上的点后,这个点在原树的权值 
struct edge{
	int v,nex;
}e[N<<1];
int h[N],idx,son[N],f[N],siz[N],dfn[N],num,top[N];
int depth[N],n,m,R,mod;
void add(int u,int v){e[++idx].v=v,e[idx].nex=h[u],h[u]=idx;}
struct tree{
	int l,r,mid;
	LL val,tag;
}tr[N<<2];
void swap(int &a,int &b){int t=a;a=b,b=t;}
void push_up(int p){tr[p].val=(tr[p<<1].val+tr[p<<1|1].val)%mod;}
void push_down(int p)
{
    if(tr[p].tag)
    {
    	int lenl=tr[p<<1].r-tr[p<<1].l+1;int lenr=tr[p<<1|1].r-tr[p<<1|1].l+1;
    	tr[p<<1].val=(tr[p<<1].val+tr[p].tag*lenl)%mod;
    	tr[p<<1|1].val=(tr[p<<1|1].val+tr[p].tag*lenr)%mod;
    	tr[p<<1].tag+=tr[p].tag;
    	tr[p<<1|1].tag+=tr[p].tag;
    	tr[p].tag=0;
	}
}
void build_tree(int p,int l,int r)
{
	tr[p].l=l,tr[p].r=r,tr[p].mid=(l+r)>>1;
	if(l==r)
	{
		tr[p].val=a[l];
		return ;
	}
	build_tree(p<<1,tr[p].l,tr[p].mid);
	build_tree(p<<1|1,tr[p].mid+1,tr[p].r);
	push_up(p);
}
void updata(int p,int l,int r,LL k)
{
	if(l<=tr[p].l&&tr[p].r<=r)
	{
		tr[p].val+=(tr[p].r-tr[p].l+1)*k;
		tr[p].tag+=k;
		return ;
	}
	push_down(p);
	if(l<=tr[p].mid) updata(p<<1,l,r,k);
	if(r>tr[p].mid) updata(p<<1|1,l,r,k);
	push_up(p);
}
LL query(int p,int l,int r)
{
	if(l<=tr[p].l&&tr[p].r<=r) return tr[p].val;
	LL res=0;
	push_down(p);
	if(l<=tr[p].mid) res+=query(p<<1,l,r);
	if(r>tr[p].mid) res+=query(p<<1|1,l,r);
	return res;
}
void dfs_yy(int u,int fa)
{
	depth[u]=depth[fa]+1;
	siz[u]=1;
	f[u]=fa;
	for(int i=h[u];i;i=e[i].nex)
	{
		int v=e[i].v;
		if(v==fa) continue;
		dfs_yy(v,u);
		siz[u]+=siz[v];
		if(!son[u]||siz[v]>siz[son[u]]) son[u]=v;
	}
}
void dfs_nlc(int u,int t)
{
	dfn[u]=++num;top[u]=t;a[num]=w[u];
	if(siz[u]==1) return ;
	dfs_nlc(son[u],t);
	for(int i=h[u];i;i=e[i].nex)
	{
		int v=e[i].v;
		if(v==f[u]||v==son[u]) continue;
		dfs_nlc(v,v);
	}
}
void updata_tree(int x,int y,LL k)
{
	while(top[x]!=top[y])
	{
		if(depth[top[x]]<depth[top[y]]) swap(x,y);
		updata(1,dfn[top[x]],dfn[x],k);
		x=f[top[x]];
	}
	if(depth[x]>depth[y]) swap(x,y);
	updata(1,dfn[x],dfn[y],k);
}
LL query_tree(int x,int y)
{
	LL res=0;
	while(top[x]!=top[y])
	{
		if(depth[top[x]]<depth[top[y]]) swap(x,y);
		res=(res+query(1,dfn[top[x]],dfn[x]))%mod;
		x=f[top[x]];
	}
	if(depth[x]>depth[y]) swap(x,y);
	res=(res+query(1,dfn[x],dfn[y]))%mod;
	return res;
}
int main()
{
	scanf("%d%d%d%d",&n,&m,&R,&mod);
	for(int i=1;i<=n;i++) scanf("%lld",&w[i]);
	for(int u,v,i=1;i<n;i++)
	{
		scanf("%d%d",&u,&v);
		add(u,v),add(v,u);
	}
	dfs_yy(R,0);
	dfs_nlc(R,R);
	build_tree(1,1,n);
	while(m--)
	{
		int opt,x,y;
		LL k;
		scanf("%d%d",&opt,&x);
		if(opt==1) scanf("%d%lld",&y,&k),updata_tree(x,y,k);
		if(opt==3) scanf("%lld",&k),updata(1,dfn[x],dfn[x]+siz[x]-1,k);
		if(opt==2) scanf("%d",&y),printf("%lld\n",query_tree(x,y)%mod);
		if(opt==4) printf("%lld\n",query(1,dfn[x],dfn[x]+siz[x]-1)%mod);
	}
	return 0;
} 

应用 [NOI2015] 软件包管理器

\(n\) 个软件安装包,它们的依赖关系构成一棵,其中 \(1\) 号软件不依赖任何软件,为根节点。每次有两种操作:

install \(x\) 表示安装 \(x\) 号软件包

uninstall \(x\) 表示卸载 \(x\) 号软件包

\(A\) 依赖 \(B\),则安装 \(A\) 之前必须先安装 \(B\),卸载 \(B\) 之前必须先卸载 \(A\)

求每次操作更改了多少个软件的状态。

数据范围

思路

通过观察可以发现,如果将软件包的状态标记为 \(0/1\) (未安装/已安装)。那么对于每一次安装操作,就是把 \(x\) 到根节点的路径上的 \(0\) 全部修改成 \(1\)。对于每一次卸载操作,就是把以 \(x\) 为根的子树内部的点全部修改成 \(0\)

要统计每一次更改多少软件的状态,其实就是比较一下更改前和更改后,树中 \(1\) 的数量的变化。那么也就可以先将原树进行树链剖分操作,用线段树维护区间内 \(1\) 的数量。每次操作的答案就是更改的 \(1\) 的数量。

code:

#include<cstdio>
#include<cstring>
using namespace std;
const int N=1e5+10;
int h[N],idx,n,m,top[N],depth[N],siz[N],dfn[N],num,son[N],f[N];
struct tree{
	int l,r,mid;
	int val,tag;
}tr[N<<2];
struct edge{
	int v,nex;
}e[N];
void swap(int &a,int &b){int t=a;a=b,b=t;}
void add(int u,int v){e[++idx].v=v;e[idx].nex=h[u];h[u]=idx;}
void push_up(int p){tr[p].val=tr[p<<1].val+tr[p<<1|1].val;}
void push_down(int p)
{
	if(tr[p].tag!=-1)
	{
		tr[p<<1].val=(tr[p<<1].r-tr[p<<1].l+1)*tr[p].tag;
		tr[p<<1|1].val=(tr[p<<1|1].r-tr[p<<1|1].l+1)*tr[p].tag;
		tr[p<<1].tag=tr[p].tag;
		tr[p<<1|1].tag=tr[p].tag;
		tr[p].tag=-1;
	}
}
void build_tree(int p,int l,int r)
{
	tr[p].l=l,tr[p].r=r,tr[p].mid=(l+r)>>1;
	if(l==r) return ;
	build_tree(p<<1,tr[p].l,tr[p].mid);
	build_tree(p<<1|1,tr[p].mid+1,tr[p].r);
	push_up(p);
}
void updata(int p,int l,int r,int k)
{
	if(l<=tr[p].l&&tr[p].r<=r)
	{
		tr[p].val=(tr[p].r-tr[p].l+1)*k;
		tr[p].tag=k;
		return ;
	}
	push_down(p);
	if(l<=tr[p].mid) updata(p<<1,l,r,k);
	if(r>tr[p].mid) updata(p<<1|1,l,r,k);
	push_up(p);
}
int query(int p,int l,int r)
{
	if(l<=tr[p].l&&tr[p].r<=r) return tr[p].val;
	push_down(p);
	int res=0;
	if(l<=tr[p].mid) res+=query(p<<1,l,r);
	if(r>tr[p].mid) res+=query(p<<1|1,l,r);
	return res;
}
void dfs_yy(int u,int fa)
{
	f[u]=fa,depth[u]=depth[fa]+1;siz[u]=1;
	for(int i=h[u];i;i=e[i].nex)
	{
		int v=e[i].v;
		dfs_yy(v,u);
		siz[u]+=siz[v];
		if(siz[v]>siz[son[u]]) son[u]=v;
	}
}
void dfs_nlc(int u,int t)
{
	dfn[u]=++num;top[u]=t;
	if(siz[u]==1) return;
	dfs_nlc(son[u],t);
	for(int i=h[u];i;i=e[i].nex)
	{
		int v=e[i].v;
		if(v!=son[u]) dfs_nlc(v,v);
	}
}
void updata_tree(int x,int y,int k)
{
	while(top[x]!=top[y])
	{
		if(depth[top[x]]<depth[top[y]]) swap(x,y);
		updata(1,dfn[top[x]],dfn[x],k);
		x=f[top[x]];
	}
	if(depth[x]>depth[y]) swap(x,y);
	updata(1,dfn[x],dfn[y],k);
}
int main()
{
	scanf("%d",&n);
	for(int u,i=2;i<=n;i++)
	{
		scanf("%d",&u);
		add(u+1,i);
	}
	dfs_yy(1,0);
	dfs_nlc(1,1);
	build_tree(1,1,n);
	scanf("%d",&m);
	for(int x,i=1;i<=m;i++)
	{
		char s[110];
		scanf("%s%d",s,&x);
		x++;
		if(s[0]=='i')
		{
			int t=tr[1].val;
			updata_tree(x,1,1);
			printf("%d\n",tr[1].val-t);
		}
		else
		{
			int t=tr[1].val;
			updata(1,dfn[x],dfn[x]+siz[x]-1,0);
			printf("%d\n",t-tr[1].val);
		}
	}
	return 0;
}

应用 [SDOI2011]染色

给定一棵有 \(n\) 个节点的无根树和 \(m\) 个操作,操作共两类。

将节点 \(a\) 到节点 \(b\) 路径上的所有节点都染上颜色;

询问节点 \(a\) 到节点 \(b\) 路径上的颜色段数量。

连续相同颜色的认为是同一段,例如 \(112221\) 由三段组成:\(11,222,1\)

数据范围

\(1 \leq n,m\leq10^5\),\(0 \leq c \leq 10^9\)\(1 \leq a,b \leq n\)

思路

看到树上操作,显然可以用树链剖分求解。本题要维护的是颜色的段数,可以用线段树直接维护。具体操作就是记录区间左右端点的颜色,在合并时比较一下两段中间的颜色是否相同即可。难点在于查询。

对于查询树上的路径,因为要被分为 \(O(\log N)\) 段连续的区间,所以不能和线段树一样简单的合并。但是可以发现,由于树链剖分的性质。将一条链转化为一段连续的区间时,右端点一定是深度最大的那个点。所以可以在向上跳的时候记录当前链的最顶点的颜色。再向上跳的时候,判断这一次跳的链的深度最大的点和上一次跳的顶点颜色是否相同。

但是问题又来了,\(x,y\) 往上跳,但是每次只会跳一边,如果无脑合并显然是错误的。所以需要记录两边往上跳的顶点。

而在最后跳到一条链的时候,显然需要特判一下。

后面跳的那个点深度更低。因为两个点最后一次跳的时候,是顶点深度更大的点先跳,那么也就一定跳上最终链更深的地方(最后的判断深度来交换的操作是特判两个点在一条链的情况)。设最终链的左右端点颜色为 \(l,r\),先跳上最终链的顶点为 \(v2\),后跳上的为 \(v1\),那么只需判断一下 \(l\)\(v1\)\(r\)\(v2\) 是否相同即可。

code:

#include<cstdio>
using namespace std;
const int N=1e5+10;
int a[N],f[N],w[N],n,q,top[N],son[N],siz[N],depth[N],h[N],idx;
struct edge{
	int v,nex;
}e[N<<1];
int dfn[N],num;
void add(int u,int v){e[++idx].v=v;e[idx].nex=h[u];h[u]=idx;}
struct tree{
	int l,r,mid;
	int lv,rv,sum,tag;
}tr[N<<2];
int max(int a,int b){return a>b?a:b;}
void swap(int &a,int &b){int t=a;a=b,b=t;}
void push_up(int p)
{
    tr[p].sum=tr[p<<1].sum+tr[p<<1|1].sum-(tr[p<<1].rv==tr[p<<1|1].lv);
    tr[p].lv=tr[p<<1].lv;
    tr[p].rv=tr[p<<1|1].rv;
}
void push_down(int p)
{
	if(tr[p].tag)
	{
		tr[p<<1].rv=tr[p<<1].lv=tr[p<<1|1].lv=tr[p<<1|1].rv=tr[p].tag;
		tr[p<<1].sum=tr[p<<1|1].sum=1;
		tr[p<<1].tag=tr[p].tag;
		tr[p<<1|1].tag=tr[p].tag;
		tr[p].tag=0;
	}
}
void build_tree(int p,int l,int r)
{
	tr[p].l=l,tr[p].r=r,tr[p].mid=(l+r)>>1;
	if(l==r)
	{
		tr[p].sum=1;
		tr[p].lv=tr[p].rv=a[l];
		tr[p].tag=0;
		return ;	
	}
	build_tree(p<<1,tr[p].l,tr[p].mid);
	build_tree(p<<1|1,tr[p].mid+1,tr[p].r);
	push_up(p);
}
void updata(int p,int l,int r,int k)
{
	if(l<=tr[p].l&&tr[p].r<=r)
	{
		tr[p].sum=1;
		tr[p].lv=tr[p].rv=k;
		tr[p].tag=k;
		return ;
	}
	push_down(p);
	if(l<=tr[p].mid) updata(p<<1,l,r,k);
	if(r>tr[p].mid) updata(p<<1|1,l,r,k);
	push_up(p);
}
tree query(int p,int l,int r)
{
	if(l<=tr[p].l&&tr[p].r<=r) return tr[p];
	push_down(p);
	tree res1,res2,res;
	if(l<=tr[p].mid&&r>tr[p].mid)
	{
		res1=query(p<<1,l,r);res2=query(p<<1|1,l,r);
		res.sum=res1.sum+res2.sum-(res1.rv==res2.lv);
		res.lv=res1.lv;res.rv=res2.rv;
	}
	else if(l<=tr[p].mid) res=query(p<<1,l,r);
	else res=query(p<<1|1,l,r);
	return res;
}
void dfs_yy(int u,int fa)
{
	f[u]=fa,depth[u]=depth[fa]+1;siz[u]=1;
	for(int i=h[u];i;i=e[i].nex)
	{
		int v=e[i].v;
		if(v==fa) continue;
		dfs_yy(v,u);
		siz[u]+=siz[v];
		if(siz[v]>siz[son[u]]) son[u]=v;
	}
}
void dfs_nlc(int u,int t)
{
    dfn[u]=++num;top[u]=t;a[num]=w[u];
	if(siz[u]==1) return ;
	dfs_nlc(son[u],t);
	for(int i=h[u];i;i=e[i].nex)
	{
		int v=e[i].v;
		if(v==f[u]||v==son[u]) continue;
		dfs_nlc(v,v);
	}	
}
void updata_tree(int x,int y,int k)
{
	while(top[x]!=top[y])
	{
		if(depth[top[x]]<depth[top[y]]) swap(x,y);
		updata(1,dfn[top[x]],dfn[x],k);
		x=f[top[x]];
	}
	if(depth[x]>depth[y]) swap(x,y);
	updata(1,dfn[x],dfn[y],k);
}
int query_sum(int x,int y)
{
	tree res;
	int lv1=-1,lv2=-1,ans=0;
	while(top[x]!=top[y])
	{
		if(depth[top[x]]<depth[top[y]]) swap(x,y),swap(lv1,lv2);
		res=query(1,dfn[top[x]],dfn[x]);
		ans+=res.sum-(lv1==res.rv);
		lv1=res.lv;
		x=f[top[x]];
	}
	if(depth[x]>depth[y]) swap(x,y),swap(lv1,lv2);
	res=query(1,dfn[x],dfn[y]);
	ans+=res.sum-(res.rv==lv2)-(res.lv==lv1);
	return ans;
}
int main()
{
	scanf("%d%d",&n,&q);
	for(int i=1;i<=n;i++) scanf("%d",&w[i]);
	for(int u,v,i=1;i<n;i++)
	{
		scanf("%d%d",&u,&v);
		add(u,v),add(v,u);
	}
	dfs_yy(1,0);
	dfs_nlc(1,1);
	build_tree(1,1,n);
	while(q--)
	{
		int x,y,k;
		char op[2];
		scanf("%s%d%d",op,&x,&y);
		if(op[0]=='C') scanf("%d",&k),updata_tree(x,y,k); 
		else if(op[0]=='Q') printf("%d\n",query_sum(x,y));
	} 
	return 0;
}
posted @ 2021-08-02 12:29  曙诚  阅读(132)  评论(0)    收藏  举报