[笔记]线段树合并

线段树合并,就是将两棵线段树对应位置相加,得到一棵新的线段树。
由于实际应用中,通常要对很多棵线段树进行多次合并,所以和主席树类似地,我们使用动态开点线段树来实现。

算法概述

线段树合并的代码实现如下:

int merge(int x,int y,int l,int r){//将x,y为根的树都合并到x上
	if(!x||!y) return x+y;//如果x=y=0则返回空节点,x=0则返回y,如果y=0则返回x
	if(l==r) return sum(x)+=sum(y),x;
	int mid=(l+r)>>1;
	lc(x)=merge(lc(x),lc(y),l,mid);
	rc(x)=merge(rc(x),rc(y),mid+1,r);
	return pushup(x),x;
}

也可以通过引用改成void类型的:

void merge(int& x,int y,int l,int r){
	if(!x||!y) return x+=y,void();
	if(l==r) return sum(x)+=sum(y),void();
	int mid=(l+r)>>1;
	merge(lc(x),lc(y),l,mid);
	merge(rc(x),rc(y),mid+1,r);
	pushup(x);
}

下文规定\(T_1+T_2\)为线段树\(T_1,T_2\)合并后的结果,\(|T|\)表示线段树\(T\)的节点数。

关于线段树合并的时间复杂度,有结论:

  • 对于\(n\)棵线段树\(T_1,T_2,\dots,T_n\),将它们合并的时间复杂度是\(O(\sum\limits_{i=1}^n |T_i|-|\sum\limits_{i=1}^n T_i|)\)

下面的内容来自算法学习笔记(88): 线段树合并 by Pecco

使用归纳法证明:

  • \(n=0\)时,时间复杂度为\(O(0)\)
  • 假如对于\(n<k\)都成立,当\(n=k\)时,将\(T_1,T_2,\dots,T_n\),划分成两个非空集合\(S_1,S_2\)
    • \(S_1,S_2\)分别合并成\(T'_1,T'_2\),时间复杂度是:

      \[O(\sum\limits_{T\in S_1}|T|-|T'_1|)+O(\sum\limits_{T\in S_2}|T|-|T'_2|)\\=O(\sum\limits_{i=1}^n |T_i|-|T'_1|-|T'_2|) \]

    • 再将\(T'_1,T'_2\)合并,根据代码可以发现时间复杂度就是两树重叠的节点个数,即:

      \[O(|T'_1|+|T'_2|-|T'_1+T'_2|) \]

    相加可得总时间复杂度:

    \[O(\sum\limits_{i=1}^n |T_i|-|T'_1+T'_2|)\\=O(\sum\limits_{i=1}^n |T_i|-|\sum\limits_{i=1}^n T_i|) \]

所以,对于值域为\(n\)的若干线段树,如果对其进行了\(k\)次单点修改,总节点数是\(O(k\log n)\)的,合并它们的时间复杂度是\(O(k\log n-|\sum\limits_{i=1}^n T_i|)<O(k\log n)\)

例题(题单点这里

1. P4556 [Vani有约会] 雨天的尾巴 /【模板】线段树合并

在树上进行若干条路径的区间修改(静态),我们通常使用树上差分。

举个例子,对于一棵树,想要给\(u\)\(v\)路径上的每个节点增加\(k\),就相当于在该树的差分数组上进行如下操作:

  • \(u\)处增加\(k\)
  • \(v\)处增加\(k\)
  • \(\text{LCA}(u,v)\)处减少\(k\)
  • \(fa[\text{LCA}(u,v)]\)处减少\(k\)

将差分数组\(b\)还原成原数组\(a\),仅需令\(a[u]=(\sum\limits_{fa[v]=u} a[v])+b[u]\)即可。

我们给每个节点建一个值域为\(V\)的线段树,位置\(i\)上的值表示救济粮\(i\)有多少袋。

对于每次操作,按照上面的过程进行\(4\)次单点修改。

完成所有操作后,搜索每个节点\(u\),按上面的过程求出\(u\)还原后的线段树形态,此时\(u\)点的答案即为该线段树中最大值所在的下标。

不过这里我们要累加的不是整数,而是若干颗线段树,这就要用到线段树合并。

时间复杂度为\(O(m\log V)\)

注意节点总数是\(4m\log V\approx 1.6\times 10^6\),因为单点修改次数是\(4m\)

点击查看代码
#include<bits/stdc++.h>
#define N 100010
#define M 100010
#define V 100010
using namespace std;
struct edge{int nxt,to;}e[M<<1];
struct SEG{
	struct node{int lc,rc,maxx,pos;}tr[M*80];//4MlogV
	int idx;
	#define lc(x) (tr[x].lc)//请注意,不使用undef的话,define的作用域是从此处直到文件结尾
	#define rc(x) (tr[x].rc)//放在这里面只是为了条理一些
	#define maxx(x) (tr[x].maxx)
	#define pos(x) (tr[x].pos)
	void pushup(int x){
		maxx(x)=-1e9;
		if(lc(x)&&maxx(lc(x))>maxx(x)) maxx(x)=maxx(lc(x)),pos(x)=pos(lc(x));
		if(rc(x)&&maxx(rc(x))>maxx(x)) maxx(x)=maxx(rc(x)),pos(x)=pos(rc(x));
	}
	void chp(int &x,int a,int v,int l,int r){
		if(!x) x=++idx;
		if(l==r) return maxx(x)+=v,pos(x)=l,void();
		int mid=(l+r)>>1;
		if(a<=mid) chp(lc(x),a,v,l,mid);
		else chp(rc(x),a,v,mid+1,r);
		pushup(x);
	}
	void merge(int &x,int y,int l,int r){
		if(!x||!y) return x+=y,void();
		if(l==r) return maxx(x)+=maxx(y),void();
		int mid=(l+r)>>1;
		merge(lc(x),lc(y),l,mid);
		merge(rc(x),rc(y),mid+1,r);
		pushup(x);
	}
}tr;
int n,m,head[N],fa[N][20],dep[N],idx,root[N],ans[N];
void add(int u,int v){e[++idx]={head[u],v},head[u]=idx;}
void dfs(int u){
	dep[u]=dep[fa[u][0]]+1;
	for(int i=1;i<20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v!=fa[u][0]) fa[v][0]=u,dfs(v);
	}
}
int LCA(int u,int v){
	if(dep[u]<dep[v]) swap(u,v);
	for(int i=19;~i;i--) if(dep[fa[u][i]]>=dep[v]) u=fa[u][i];
	if(u==v) return u;
	for(int i=19;~i;i--) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
	return fa[u][0];
}
void dfs2(int u){
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v!=fa[u][0]) dfs2(v),tr.merge(root[u],root[v],1,V);
	}
	if(tr.tr[root[u]].maxx) ans[u]=tr.tr[root[u]].pos;
}
signed main(){
	cin>>n>>m;
	for(int i=1,u,v;i<n;i++){
		cin>>u>>v;
		add(u,v),add(v,u);
	}
	dfs(1);
	for(int i=1,x,y,z,l;i<=m;i++){
		cin>>x>>y>>z;
		l=LCA(x,y);
		tr.chp(root[x],z,1,1,V);
		tr.chp(root[y],z,1,1,V);
		tr.chp(root[l],z,-1,1,V);
		tr.chp(root[fa[l][0]],z,-1,1,V);
	}
	dfs2(1);
	for(int i=1;i<=n;i++) cout<<ans[i]<<"\n";
	return 0;
}

2. P3605 [USACO17JAN] Promotion Counting P

相当于上道题的弱化版,仍然给每个节点建一个线段树(注意值域太大需要离散化)。

对于每个节点\(u\),将它到根节点的路径上每个节点对应的线段树上第\(p[u]\)\(+1\),仍然使用树上差分来解决。

时间复杂度是\(O(n\log n)\)

节点总数是\(n\log n\)

点击查看代码
#include<bits/stdc++.h>
#define int long long
#define N 100010
using namespace std;
int n,nn,tmp[N],p[N],root[N],ans[N];
vector<int> G[N];
unordered_map<int,int> to;
struct SEG{
	struct node{int lc,rc,sum;}tr[20*N];//nlogn
	int idx;
	#define lc(x) (tr[x].lc)
	#define rc(x) (tr[x].rc)
	#define sum(x) (tr[x].sum)
	void pushup(int x){sum(x)=sum(lc(x))+sum(rc(x));}
	void chp(int &x,int a,int v,int l,int r){
		if(!x) x=++idx;
		if(l==r) return sum(x)+=v,void();
		int mid=(l+r)>>1;
		if(a<=mid) chp(lc(x),a,v,l,mid);
		else chp(rc(x),a,v,mid+1,r);
		pushup(x);
	}
	int query(int x,int a,int b,int l,int r){
		if(a<=l&&r<=b) return sum(x);
		int mid=(l+r)>>1,ans=0;
		if(a<=mid) ans+=query(lc(x),a,b,l,mid);
		if(b>mid) ans+=query(rc(x),a,b,mid+1,r);
		return ans;
	}
	void merge(int &x,int y,int l,int r){
		if(!x||!y) return x+=y,void();
		if(l==r) return sum(x)+=sum(y),void();
		int mid=(l+r)>>1;
		merge(lc(x),lc(y),l,mid);
		merge(rc(x),rc(y),mid+1,r);
		pushup(x);
	}
}tr;
void dfs(int u){
	for(int i:G[u]) dfs(i),tr.merge(root[u],root[i],1,nn);
	ans[u]=tr.query(root[u],to[p[u]]+1,nn,1,nn);
	tr.chp(root[u],to[p[u]],1,1,nn);
}
signed main(){
	cin>>n;
	for(int i=1;i<=n;i++) cin>>p[i],tmp[i]=p[i];
	sort(tmp+1,tmp+1+n);
	nn=unique(tmp+1,tmp+1+n)-tmp-1;
	for(int i=1;i<=nn;i++) to[tmp[i]]=i;
	nn++;//哨兵节点 
	for(int i=2,u;i<=n;i++) cin>>u,G[u].emplace_back(i);
	dfs(1);
	for(int i=1;i<=n;i++) cout<<ans[i]<<"\n";
	return 0;
}

3. CF600E Lomsat gelral

和上道题类似,不过线段树要维护的东西变成了最大值所在的下标之和,对pushup()进行一些修改即可;不需要离散化。

时间复杂度\(O(n\log V)=O(n\log n)\)

节点总数是\(n\log V=n\log n\)

点击查看代码
#include<bits/stdc++.h>
#define int long long
#define N 100010 
using namespace std;
int n,ans[N],root[N];
vector<int> G[N];
struct SEG{
	struct node{int lc,rc,maxx,ans;}tr[20*N];//nlogV=nlogn
	int idx;
	#define lc(x) (tr[x].lc)
	#define rc(x) (tr[x].rc)
	#define maxx(x) (tr[x].maxx)
	#define ans(x) (tr[x].ans)
	void pushup(int x){
		if(maxx(lc(x))>maxx(rc(x))) maxx(x)=maxx(lc(x)),ans(x)=ans(lc(x));
		else if(maxx(lc(x))<maxx(rc(x))) maxx(x)=maxx(rc(x)),ans(x)=ans(rc(x));
		else maxx(x)=maxx(lc(x)),ans(x)=ans(lc(x))+ans(rc(x));
	}
	void chp(int &x,int a,int v,int l,int r){
		if(!x) x=++idx;
		if(l==r) return maxx(x)+=v,ans(x)=l,void();
		int mid=(l+r)>>1;
		if(a<=mid) chp(lc(x),a,v,l,mid);
		else chp(rc(x),a,v,mid+1,r);
		pushup(x);
	}
	void merge(int &x,int y,int l,int r){
		if(!x||!y) return x+=y,void();
		if(l==r) return maxx(x)+=maxx(y),ans(x)=l,void();
		int mid=(l+r)>>1;
		merge(lc(x),lc(y),l,mid);
		merge(rc(x),rc(y),mid+1,r);
		pushup(x);
	}
}tr;
void add(int u,int v){G[u].emplace_back(v);}
void dfs(int u,int fa){
	for(int i:G[u]) if(i!=fa) dfs(i,u),tr.merge(root[u],root[i],1,n);
	ans[u]=tr.tr[root[u]].ans;
}
signed main(){
	cin>>n;
	for(int i=1,c;i<=n;i++) cin>>c,tr.chp(root[i],c,1,1,n);
	for(int i=1,u,v;i<n;i++) cin>>u>>v,add(u,v),add(v,u);
	dfs(1,0);
	for(int i=1;i<=n;i++) cout<<ans[i]<<" ";
	return 0;
}

4. P3521 [POI 2011] ROT-Tree Rotations

递归的过程中,逆序对只可能:

  • 在左子树中。
  • 在右子树中。
  • 跨越左右子树。

如果前\(2\)种情况已经计算出来了,那么我们有\(2\)种决策:“交换左右子树”or“不交换左右子树”。

显然交不交换对只对第\(3\)种情况的答案有影响,所以我们可以贪心地取两种决策答案较小者。

至于如何计算两个子树之间的逆序对个数,可以为每个节点\(u\)开一个权值线段树,来表示子树\(u\)中每个数出现次数。

根据上面的分析,可以写出下面的代码:

int query(int x,int y,int l,int r){//子树x(左)和子树y(右)之间产生的逆序对数量 
	if(!x||!y) return 0;
	if(l==r) return 0;
	int mid=(l+r)>>1;
	return query(lc(x),lc(y),l,mid)+query(rc(x),rc(y),mid+1,r)+sum(rc(x))*sum(lc(y));
}

query()合并到merge()中即可在合并的同时求出两种决策的答案,在常数上有显著的效率提升。

时间复杂度是\(O(n\log V)=O(n\log n)\)

节点总数是\(n\log V=n\log n\)

此题有点卡空间,把不需要ll的变量尽量开int就好。

点击查看代码
#include<bits/stdc++.h>
#define ll long long
#define N 200010
using namespace std;
int n,idx;
ll ans,s1,s2;
struct SEG{
	struct node{int lc,rc,sum;}tr[20*N];//nlogV=nlogn
	int idx;
	#define lc(x) (tr[x].lc)
	#define rc(x) (tr[x].rc)
	#define sum(x) (tr[x].sum)
	void pushup(int x){sum(x)=sum(lc(x))+sum(rc(x));}
	void chp(int &x,int a,int v,int l,int r){
		if(!x) x=++idx;
		if(l==r) return sum(x)+=v,void();
		int mid=(l+r)>>1;
		if(a<=mid) chp(lc(x),a,v,l,mid);
		else chp(rc(x),a,v,mid+1,r);
		pushup(x);
	}
	void merge(int &x,int y,int l,int r){
		if(!x||!y) return x+=y,void();
		if(l==r) return sum(x)+=sum(y),void();
		int mid=(l+r)>>1;
		s1+=1ll*sum(lc(x))*sum(rc(y));
		s2+=1ll*sum(rc(x))*sum(lc(y));
		merge(lc(x),lc(y),l,mid);
		merge(rc(x),rc(y),mid+1,r);
		pushup(x);
	}
}tr;
int dfs(){//由于节点编号是1~n的排列,所以直接用权值当编号 
	int p=0,x;
	cin>>x;
	if(!x){
		int lc=dfs(),rc=dfs();
		s1=s2=0,tr.merge(lc,rc,1,n);
		p=lc,ans+=min(s1,s2);
	}else tr.chp(p,x,1,1,n);
	return p;
}
signed main(){
	cin>>n;
	dfs();
	cout<<ans<<"\n";
	return 0;
}

5.CF208E Blood Cousins

节点\(u\)的线段树的第\(i\)位存储“子树\(u\)中,深度为\(i\)的节点个数”。

对于询问“节点\(v\)有多少个\(p\)级表亲”,在\(u\)\(k\)级祖先处统计贡献即可。贡献为线段树上\(dep[v]\)处的值。

时间复杂度是\(O((n+q)\log n)\)

节点总数是\(n\log n\)

点击查看代码
#include<bits/stdc++.h>
#define eb emplace_back
#define N 100010
#define Q 100010
using namespace std;
int n,q,r[N],dep[N],root[N],fa[N][20],ans[Q];
vector<int> G[N];
struct Que{int id,d;};
vector<Que> que[N];
struct SEG{
	struct node{int lc,rc,sum;}tr[20*N];//nlogn
	int idx;
	#define lc(x) (tr[x].lc)
	#define rc(x) (tr[x].rc)
	#define sum(x) (tr[x].sum)
	void pushup(int x){sum(x)=sum(lc(x))+sum(rc(x));}
	void chp(int &x,int a,int v,int l,int r){
		if(!x) x=++idx;
		if(l==r) return sum(x)+=v,void();
		int mid=(l+r)>>1;
		if(a<=mid) chp(lc(x),a,v,l,mid);
		else chp(rc(x),a,v,mid+1,r);
		pushup(x);
	}
	int query(int x,int a,int l,int r){
		if(l==r) return sum(x);
		int mid=(l+r)>>1;
		if(a<=mid) return query(lc(x),a,l,mid);
		else return query(rc(x),a,mid+1,r);
	}
	void merge(int &x,int y,int l,int r){
		if(!x||!y) return x+=y,void();
		if(l==r) return sum(x)+=sum(y),void();
		int mid=(l+r)>>1;
		merge(lc(x),lc(y),l,mid);
		merge(rc(x),rc(y),mid+1,r);
		pushup(x);
	}
}tr;
void dfs(int u){
	dep[u]=dep[fa[u][0]]+1;
	for(int i=1;i<20;i++) fa[u][i]=fa[fa[u][i-1]][i-1];
	for(int i:G[u]) fa[i][0]=u,dfs(i);
}
int kthp(int u,int k){
	for(int i=0;i<20;i++) if((k>>i)&1) u=fa[u][i];
	return u; 
}
void dfs2(int u){
	if(!u){for(int i:G[u]) dfs2(i);return;}
	tr.chp(root[u],dep[u],1,1,n);
	for(int i:G[u]){
		dfs2(i);
		tr.merge(root[u],root[i],1,n);
	}
	for(Que i:que[u]) ans[i.id]=tr.query(root[u],i.d,1,n)-1;
}
signed main(){
	cin>>n;
	for(int i=1;i<=n;i++) cin>>r[i],G[r[i]].eb(i);
	cin>>q;
	dep[0]=-1,dfs(0);
	for(int i=1,x,y;i<=q;i++) cin>>x>>y,que[kthp(x,y)].eb(Que{i,dep[x]});
	dfs2(0);
	for(int i=1;i<=q;i++) cout<<ans[i]<<" ";
	return 0;
}

6.P5384 [Cnoi2019] 雪松果树

和5.的题意相同。存在\(O(n)\)的DFS+差分做法,然后此题就卡线段树合并了。跳过。

7.P3899 [湖南集训] 更为厉害

对于询问\((p,k)\),答案有\(2\)种情况,分别统计:

  • \(b\)\(a\)的祖先:\(b\)\(a\)以上\(k\)个节点内任意选择,\(c\)在子树\(a\)中任意选择。
    贡献为:\(\min(dis(a,1),k)\times(siz[a]-1)\)
  • \(a\)\(b\)的祖先:\(b\)在子树\(a\)中任意选择,\(c\)在子树\(b\)中任意选择。
    贡献为:\(\sum\limits_{a是b的祖先,dis(a,b)\le k}(siz[b]-1)\)

其中统计后面的式子,可以每个节点\(u\)开一个线段树,第\(i\)个位置表示深度为\(i\)且在\(u\)子树中的节点\(b\)\((siz[b]-1)\)之和。

时间复杂度是\(O((n+q)\log n)\)

节点总数是\(n\log n\)

点击查看代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=3e5+10,Q=3e5+10;
struct SEG{
	struct node{int lc,rc,sum;}tr[N*20];//nlogn
	int idx;
	#define lc(x) (tr[x].lc)
	#define rc(x) (tr[x].rc)
	#define sum(x) (tr[x].sum)
	void pushup(int x){sum(x)=sum(lc(x))+sum(rc(x));}
	void chp(int &x,int a,int v,int l,int r){
		if(!x) x=++idx;
		if(l==r) return sum(x)+=v,void();
		int mid=(l+r)>>1;
		if(a<=mid) chp(lc(x),a,v,l,mid);
		else chp(rc(x),a,v,mid+1,r);
		pushup(x);
	}
	int query(int x,int a,int b,int l,int r){
		if(a<=l&&r<=b) return sum(x);
		int mid=(l+r)>>1,ans=0;
		if(a<=mid) ans+=query(lc(x),a,b,l,mid);
		if(b>mid) ans+=query(rc(x),a,b,mid+1,r);
		return ans;
	}
	void merge(int &x,int y,int l,int r){
		if(!x||!y) return x+=y,void();
		if(l==r) return sum(x)+=sum(y),void();
		int mid=(l+r)>>1;
		merge(lc(x),lc(y),l,mid);
		merge(rc(x),rc(y),mid+1,r);
		pushup(x);
	}
}tr;
struct Que{int id,k;};
int n,q,root[N],dep[N],siz[N],ans[Q];
vector<int> G[N];
vector<Que> que[N];
void add(int u,int v){G[u].emplace_back(v);}
void dfs(int u,int fa){
	dep[u]=dep[fa]+1,siz[u]=1;
	for(int i:G[u]) if(i!=fa) dfs(i,u),siz[u]+=siz[i];
}
void dfs2(int u,int fa){
	for(int i:G[u]) if(i!=fa) dfs2(i,u),tr.merge(root[u],root[i],1,n);
	for(Que i:que[u]){
		ans[i.id]=tr.query(root[u],dep[u]+1,min(dep[u]+i.k,n),1,n)+(siz[u]-1)*min(dep[u]-1,i.k);
	}
	tr.chp(root[u],dep[u],siz[u]-1,1,n);
}
signed main(){
	cin>>n>>q;
	for(int i=1,u,v;i<n;i++) cin>>u>>v,add(u,v),add(v,u);
	for(int i=1,p,k;i<=q;i++) cin>>p>>k,que[p].emplace_back(Que{i,k});
	dfs(1,0),dfs2(1,0);
	for(int i=1;i<=q;i++) cout<<ans[i]<<"\n";
	return 0;
}

8.CF1009F Dominant Indices

节点\(u\)的线段树的第\(i\)位表示“\(u\)子树内有多少个节点深度为\(i\)”。

其答案为该线段树内“最大值出现的最小下标”,有两种实现方法:

  • 记录\(maxx,mpos\)分别表示最大值和最大值所在最小下标,然后正常转移。
  • 仅记录\(maxx\),然后使用线段树上二分。

后者代码实现和时空占用上都更优,两种方法的具体实现详见代码。


此题有点卡线段树合并,所以需要节点回收

具体来说,在线段树合并时,将\(T_B\)合并到\(T_A\)上之后,\(T_B\)中不被转移到\(T_A\)上的节点就用不上了(不排除需要使用这些节点的情况,不过此题和之前的题都不需要),所以在merge的过程中将这些节点压入一个栈,表示我们已经将它们回收。

在创建新节点时,我们优先从栈中取,如果栈是空的再用++idx开辟新节点。

这样不难发现任何时刻线段树中的节点数都不会超过\(n\)

实际上前面的题都可以这样优化。


时间复杂度是\(O(n\log n)\)

节点总数是\(n\log n\rightarrow n\)

实现$1$
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
int n,head[N],idx,dep[N],root[N],ans[N];
struct edge{int nxt,to;}e[N<<1];
void add(int u,int v){e[++idx]={head[u],v},head[u]=idx;}
struct SEG{
	stack<int> gar;
	struct Data{int maxx,mpos;};
	struct Node{
		int lc,rc;
		Data data;
		void init(){data={0,0},lc=rc=0;}
	}tr[N];
	Data add(Data a,Data b){
		if(a.maxx<b.maxx) return {b.maxx,b.mpos};
		else return {a.maxx,a.mpos};
	}
	int idx;
	#define lc(x) (tr[x].lc)
	#define rc(x) (tr[x].rc)
	#define maxx(x) (tr[x].data.maxx)
	#define mpos(x) (tr[x].data.mpos)
	int newnode(){
		int k;
		if(!gar.empty()){
			k=gar.top(),gar.pop();
		}else k=++idx;
		return tr[k].init(),k;
	}
	void pushup(int x){tr[x].data=add(tr[lc(x)].data,tr[rc(x)].data);}
	void chp(int &x,int a,int v,int l,int r){
		if(!x) x=newnode();
		if(l==r) return maxx(x)+=v,mpos(x)=l,void();
		int mid=(l+r)>>1;
		if(a<=mid) chp(lc(x),a,v,l,mid);
		else chp(rc(x),a,v,mid+1,r);
		pushup(x);
	}
	Data query(int x,int a,int b,int l,int r){
		if(a<=l&&r<=b) return tr[x].data;
		int mid=(l+r)>>1;
		if(a<=mid&&b>mid) return add(query(lc(x),a,b,l,mid),query(rc(x),a,b,mid+1,r));
		if(a<=mid) return query(lc(x),a,b,l,mid);
		return query(rc(x),a,b,mid+1,r);
	}
	void merge(int &x,int y,int l,int r){
		if(!x||!y) return x+=y,void();
		gar.push(y);
		if(l==r) return maxx(x)+=maxx(y),void();
		int mid=(l+r)>>1;
		merge(lc(x),lc(y),l,mid);
		merge(rc(x),rc(y),mid+1,r);
		pushup(x);
	}
}tr;
void dfs(int u,int fa){
	dep[u]=dep[fa]+1;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v!=fa) dfs(v,u);
	}
}
void dfs2(int u,int fa){
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v!=fa) dfs2(v,u),tr.merge(root[u],root[v],1,n);
	}
	tr.chp(root[u],dep[u],1,1,n);
	ans[u]=tr.query(root[u],dep[u],n,1,n).mpos-dep[u];
}
signed main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr),cout.tie(nullptr);
	cin>>n;
	for(int i=1,u,v;i<n;i++) cin>>u>>v,add(u,v),add(v,u);
	dfs(1,0),dfs2(1,0);
	for(int i=1;i<=n;i++) cout<<ans[i]<<"\n";
	return 0;
}
实现$2$
#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
int n,head[N],idx,dep[N],root[N],ans[N];
struct edge{int nxt,to;}e[N<<1];
void add(int u,int v){e[++idx]={head[u],v},head[u]=idx;}
struct SEG{
	stack<int> gar;
	struct Node{
		int lc,rc,maxx;
		void init(){lc=rc=maxx=0;}
	}tr[N];
	int idx;
	#define lc(x) (tr[x].lc)
	#define rc(x) (tr[x].rc)
	#define maxx(x) (tr[x].maxx)
	int newnode(){
		int k;
		if(!gar.empty()){
			k=gar.top(),gar.pop();
		}else k=++idx;
		return tr[k].init(),k;
	}
	void pushup(int x){maxx(x)=max(maxx(lc(x)),maxx(rc(x)));}
	void chp(int &x,int a,int v,int l,int r){
		if(!x) x=newnode();
		if(l==r) return maxx(x)+=v,void();
		int mid=(l+r)>>1;
		if(a<=mid) chp(lc(x),a,v,l,mid);
		else chp(rc(x),a,v,mid+1,r);
		pushup(x);
	}
	int query(int x,int l,int r){
		if(l==r) return l;
		int mid=(l+r)>>1;
		if(maxx(lc(x))==maxx(x)) return query(lc(x),l,mid);
		else return query(rc(x),mid+1,r);
	}
	void merge(int &x,int y,int l,int r){
		if(!x||!y) return x+=y,void();
		gar.push(y);
		if(l==r) return maxx(x)+=maxx(y),void();
		int mid=(l+r)>>1;
		merge(lc(x),lc(y),l,mid);
		merge(rc(x),rc(y),mid+1,r);
		pushup(x);
	}
}tr;
void dfs(int u,int fa){
	dep[u]=dep[fa]+1;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v!=fa) dfs(v,u);
	}
}
void dfs2(int u,int fa){
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v!=fa) dfs2(v,u),tr.merge(root[u],root[v],1,n);
	}
	tr.chp(root[u],dep[u],1,1,n);
	ans[u]=tr.query(root[u],1,n)-dep[u];
}
signed main(){
	ios::sync_with_stdio(false);
	cin.tie(nullptr),cout.tie(nullptr);
	cin>>n;
	for(int i=1,u,v;i<n;i++) cin>>u>>v,add(u,v),add(v,u);
	dfs(1,0),dfs2(1,0);
	for(int i=1;i<=n;i++) cout<<ans[i]<<"\n";
	return 0;
}

顺便推荐一下这篇文章:浅谈如何优美地实现线段树? by Creed-qwq

这篇文章介绍了线段树的通用框架,在面对较复杂的信息维护下,能保证代码有很强的复用性,不用再写大量本质相同的代码。

实现\(1\)的线段树框架就部分参考了该文章(注意使用动态开点线段树的话,记录的左右孩子要独立于DataTag之外)。

8.CF570D Tree Requests

节点\(u\)的线段树的第\(i\)位表示“深度为\(i\)且在\(u\)子树中的节点上的字母的信息”。

显然我们只需要知道每个字母出现次数的奇偶性,因此我们可以把该信息压成一个整数,第\(i\)个二进制位表示第\(i\)个字母出现的次数的奇偶性。

合并时,将要合并的两个叶节点求异或即可。

本题中我们所要维护的仅有叶子结点的信息,因此不需要pushup

时间复杂度\(O((n+q)\log n)\)

节点总数是\(n\log n\)

点击查看代码
#include<bits/stdc++.h>
#define eb emplace_back
#define pc __builtin_popcount
using namespace std;
const int N=5e5+10,Q=5e5+10;
int n,q,root[N],ans[Q],dep[N];
string s;
vector<int> G[N];
vector<pair<int,int>> que[N];
struct SEG{
	struct node{int lc,rc,v;}tr[N*20];//nlogn
	int idx;
	#define lc(x) (tr[x].lc)
	#define rc(x) (tr[x].rc)
	#define v(x) (tr[x].v)
	void chp(int &x,int a,int v,int l,int r){
		if(!x) x=++idx;
		if(l==r) return v(x)^=v,void();
		int mid=(l+r)>>1;
		if(a<=mid) chp(lc(x),a,v,l,mid);
		else chp(rc(x),a,v,mid+1,r);
	}
	int query(int x,int a,int l,int r){
		if(l==r) return v(x);
		int mid=(l+r)>>1;
		if(a<=mid) return query(lc(x),a,l,mid);
		return query(rc(x),a,mid+1,r);
	}
	void merge(int &x,int y,int l,int r){
		if(!x||!y) return x+=y,void();
		if(l==r) return v(x)^=v(y),void();
		int mid=(l+r)>>1;
		merge(lc(x),lc(y),l,mid);
		merge(rc(x),rc(y),mid+1,r);
	}
}tr;
void dfs(int u){
	for(int i:G[u]){
		dep[i]=dep[u]+1,dfs(i);
		tr.merge(root[u],root[i],1,n);
	}
	tr.chp(root[u],dep[u],(1<<(s[u]-'a')),1,n);
	for(auto i:que[u]){
		ans[i.first]=(pc(tr.query(root[u],i.second,1,n))<=1);
	}
}
signed main(){
	ios::sync_with_stdio(0);
	cin.tie(nullptr),cout.tie(nullptr);
	cin>>n>>q;
	for(int i=2,u;i<=n;i++) cin>>u,G[u].eb(i);
	cin>>s,s=' '+s;
	for(int i=1,u,d;i<=q;i++) cin>>u>>d,que[u].eb(i,d);
	dep[1]=1,dfs(1);
	for(int i=1;i<=q;i++) cout<<(ans[i]?"Yes\n":"No\n");
	return 0;
}

8.P1600 [NOIP 2016 提高组] 天天爱跑步

参考:此文 by Engulf

对于\((s,t)\)这条路径,考虑它对\(x\)节点产生贡献的情况:

  1. \(x\)\((s,\text{lca})\)上。
    图片
    则有\(dep_s-dep_x=w_x\),即\(dep_s=dep_x+w_x\)

  2. \(x\)\((\text{lca},t)\)上。
    图片
    则有\((dep_\text{lca}-dep_s)+(dep_x-dep_\text{lca})=w_x\),即\(2\times dep_\text{lca}-dep_s=dep_x-w_x\)

因此我们用两个线段树合并。

  • 第一次,对于每个\((s_i,\text{lca}_i)\),将\(dep_s\)加入其上节点对应的权值线段树。对\(ans_x\)的贡献为\(dep_x+w_x\)处的值。
  • 第二次,对于每个\((\text{lca}_i,t_i)\),将\(2\times dep_\text{lca}-dep_s\)加入其上节点对应的权值线段树。对\(ans_x\)的贡献为\(dep_x-w_x\)处的值。

由于修改都是针对链来进行的,所以于同理P4556,使用树上差分即可解决。

注意不要重复/漏统计\(\text{lca}\)

时间复杂度\(O(n\log^2 n)\)

节点总数懒得算了(逃

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N=3e5+10;
int n,m,w[N],head[N],idx,root[N][2];
int dep[N],fa[N][20],mxdep,ans[N];
struct SEG{
	int idx;
	int lc[N*80],rc[N*80],s[N*80];
	void pushup(int x){s[x]=s[lc[x]]+s[rc[x]];}
	void chp(int &x,int a,int v,int l,int r){
		if(!x) x=++idx;
		if(l==r) return s[x]+=v,void();
		int mid=(l+r)>>1;
		if(a<=mid) chp(lc[x],a,v,l,mid);
		else chp(rc[x],a,v,mid+1,r);
		pushup(x); 
	}
	int qry(int x,int a,int l,int r){
		if(l==r) return s[x];
		int mid=(l+r)>>1;
		if(a<=mid) return qry(lc[x],a,l,mid);
		return qry(rc[x],a,mid+1,r);
	}
	void merge(int &x,int y,int l,int r){
		if(!x||!y) return x+=y,void();
		if(l==r) return s[x]+=s[y],void();
		int mid=(l+r)>>1;
		merge(lc[x],lc[y],l,mid);
		merge(rc[x],rc[y],mid+1,r);
		pushup(x); 
	}
}tr;
struct Edge{int nxt,to;}e[N<<1];
void add(int u,int v){e[++idx]={head[u],v},head[u]=idx;}
void dfs(int u){
	mxdep=max(mxdep,dep[u]=dep[fa[u][0]]+1);
	for(int i=0;i<19;i++) fa[u][i+1]=fa[fa[u][i]][i];
	for(int i=head[u],v;i;i=e[i].nxt)
		if((v=e[i].to)!=fa[u][0]) fa[v][0]=u,dfs(v);
}
int LCA(int u,int v){
	if(dep[u]<dep[v]) swap(u,v);
	for(int i=19;~i;i--) if(dep[fa[u][i]]>=dep[v]) u=fa[u][i];
	if(u==v) return u;
	for(int i=19;~i;i--) if(fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
	return fa[u][0];
}
void dfs2(int u){
	for(int i=head[u],v;i;i=e[i].nxt){
		if((v=e[i].to)==fa[u][0]) continue;
		dfs2(v);
		tr.merge(root[u][0],root[v][0],1,mxdep);
		tr.merge(root[u][1],root[v][1],-mxdep,mxdep<<1);
	}
	if(dep[u]+w[u]<=mxdep) ans[u]+=tr.qry(root[u][0],dep[u]+w[u],1,mxdep);
	if(dep[u]-w[u]>=-mxdep) ans[u]+=tr.qry(root[u][1],dep[u]-w[u],-mxdep,mxdep<<1);
}
signed main(){
	ios::sync_with_stdio(0),cin.tie(0),cout.tie(0);
	cin>>n>>m;
	for(int i=1,u,v;i<n;i++) cin>>u>>v,add(u,v),add(v,u);
	for(int i=1;i<=n;i++) cin>>w[i];
	dfs(1);
	for(int i=1,u,v,l;i<=m;i++){
		cin>>u>>v,l=LCA(u,v);
		tr.chp(root[u][0],dep[u],1,1,mxdep);
		tr.chp(root[l][0],dep[u],-1,1,mxdep);
		tr.chp(root[v][1],(dep[l]<<1)-dep[u],1,-mxdep,mxdep<<1);
		tr.chp(root[fa[l][0]][1],(dep[l]<<1)-dep[u],-1,-mxdep,mxdep<<1);
	}
	dfs2(1);
	for(int i=1;i<=n;i++) cout<<ans[i]<<" ";
	return 0;
}

\(\text{[Fin.]}\)

posted @ 2025-06-21 10:51  Sinktank  阅读(222)  评论(4)    收藏  举报
★CLICK FOR MORE INFO★ TOP-BOTTOM-THEME
Enable/Disable Transition
Copyright © 2023 ~ 2025 Sinktank - 1328312655@qq.com
Illustration from 稲葉曇『リレイアウター/Relayouter/中继输出者』,by ぬくぬくにぎりめし.