线段树合并 学习笔记

其实就是把两颗线段树合到一起。

比如这题:P4556 [Vani有约会] 雨天的尾巴 /【模板】线段树合并。发现只会在最后查询,所以可以先考虑树上差分。给每种食物建一个桶,最后从下到上加起来就好了。但是这样还是太慢,用线段树的话可以 \(O(\log n)\) 查询最小值,但是要怎么合并信息呢?

不妨按照树上差分的思路,从下到上合并线段树!

我们递归处理:

  • 如果两个节点都为叶子节点,那么可以直接合并。如本题就是相加。
  • 如果两个节点一个有一个没有,那么可以直接沿用有的那个。
  • 如果两个节点都是空的,不用管它。

然后就是线段树上查找最小值节点编号了。

注意事项:

本题由于树上差分会有小于 \(0\) 的节点出现,此时权值为 \(0\) 的节点会被判为最小。要注意到应当在最后记录答案的时候判断最小值是否为 \(0\)

代码是自己没看过板子写的,很丑。

点击查看代码
#include<bits/stdc++.h>

#define pii pair<int,int> 
#define pll pair<long long,long long> 
#define ll long long
#define i128 __int128

#define mem(a,b) memset((a),(b),sizeof(a))
#define m0(a) memset((a),0,sizeof(a))
#define m1(a) memset(a,-1,sizeof(a))
#define lb(x) ((x)&-(x))
#define lc(x) ((x)<<1)
#define rc(x) (((x)<<1)|1)
#define pb(G,x) (G).push_back((x))
#define For(a,b,c) for(int a=(b);a<=(c);a++)
#define Rep(a,b,c) for(int a=(b);a>=(c);a--)
#define in1(a) a=read()
#define in2(a,b) a=read(), b=read()
#define in3(a,b,c) a=read(), b=read(), c=read()
#define in4(a,b,c,d) a=read(), b=read(), c=read(), d=read()
#define fst first 
#define scd second 
#define dbg puts("IAKIOI")

using namespace std;

int read() {
	int x=0,f=1; char c=getchar();
	for(;c<'0'||c>'9';c=getchar()) f=(c=='-'?-1:1); 
	for(;c<='9'&&c>='0';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
	return x*f;
}
void write(int x) { if(x>=10) write(x/10); putchar('0'+x%10); }

const int mod = 998244353;
int qpo(int a,int b) {int res=1; for(;b;b>>=1,a=(a*a)%mod) if(b&1) res=res*a%mod; return res; }
int inv(int a) {return qpo(a,mod-2); }

#define maxn 100050

int iidx;
struct Tp {
	int l,r,ls,rs,val,idx;
}tr[maxn<<6];

void psu(int idx) {
	if((tr[idx].ls==0&&tr[idx].rs==0)||max(tr[tr[idx].ls].val,tr[tr[idx].rs].val)==0) return ;
	if(tr[idx].rs==0) tr[idx].val=tr[tr[idx].ls].val,tr[idx].idx=tr[tr[idx].ls].idx;
	else if(tr[idx].ls==0) tr[idx].val=tr[tr[idx].rs].val,tr[idx].idx=tr[tr[idx].rs].idx;
	else {
		if(tr[tr[idx].ls].val>=tr[tr[idx].rs].val)
			tr[idx].val=tr[tr[idx].ls].val,tr[idx].idx=tr[tr[idx].ls].idx;
		else 
			tr[idx].val=tr[tr[idx].rs].val,tr[idx].idx=tr[tr[idx].rs].idx;
	}
}

void add(int idx,int l,int r,int k,int val) {
//	cout<<"Add:"<<idx<<' '<<l<<' '<<r<<' '<<k<<' '<<val<<'\n';
	if(l==r) {
		tr[idx].idx=l;
		tr[idx].val+=val;
		return ;
	}
	int mid=l+r>>1;
	if(k<=mid) {
		if(tr[idx].ls==0) tr[idx].ls=++iidx,tr[tr[idx].ls].l=l,tr[tr[idx].ls].r=mid;
		add(tr[idx].ls,l,mid,k,val);
	} else {
		if(tr[idx].rs==0) tr[idx].rs=++iidx,tr[tr[idx].rs].l=mid+1,tr[tr[idx].rs].r=r;
		add(tr[idx].rs,mid+1,r,k,val);
	}
	psu(idx);
}

Tp query(int idx,int l,int r,int L,int R) {
	if(L<=l&&r<=R) return tr[idx];
	int mid=l+r>>1;
	Tp ans={0,0,0,0,0};
	if(L<=mid) {
		if(tr[idx].ls!=0) 
			ans=query(tr[idx].ls,l,mid,L,R);
	}
	if(R>mid) {
		if(tr[idx].rs!=0) {
			Tp res=query(tr[idx].rs,mid+1,r,L,R);
			if(res.val>ans.val) ans=res;
		}
	}
	return ans;
}

void uni(int idx1,int idx2,int l,int r) { //将以 idx2 为根的子树合并到以 idx1 为根的子树里面去
	
//	cout<<"Union:"<<idx1<<' '<<idx2<<' '<<l<<' '<<r<<'\n';
	
	if(l==r) {
		tr[idx1].val+=tr[idx2].val;
		return ;
	}
	int mid=l+r>>1;
	
	if(tr[idx1].ls==0&&tr[idx2].ls!=0) { tr[idx1].ls=tr[idx2].ls; }
	else if(tr[idx1].ls!=0&&tr[idx2].ls!=0) uni(tr[idx1].ls,tr[idx2].ls,l,mid);
	
	if(tr[idx1].rs==0&&tr[idx2].rs!=0) { tr[idx1].rs=tr[idx2].rs; }
	else if(tr[idx1].rs!=0&&tr[idx2].rs!=0) uni(tr[idx1].rs,tr[idx2].rs,mid+1,r);

	psu(idx1);
}

vector<int> G[maxn];

struct LCA {
	int dep[maxn],fa[26][maxn];
	void dfs(int u,int fath) {
		fa[0][u]=fath;
		dep[u]=dep[fath]+1;
		int sz=log2(dep[u]);
		For(i,1,sz) fa[i][u]=fa[i-1][fa[i-1][u]];
		for(auto v:G[u]) if(v!=fath) dfs(v,u);
	}
	int query(int x,int y) {
		if(dep[x]<dep[y]) swap(x,y);
		while(dep[x]>dep[y]) {
			x=fa[(int)(log2(dep[x]-dep[y]))][x];
		}
		if(x==y) return x;
		int sz=log2(dep[x]);
		Rep(i,sz,0) if(fa[i][x]!=fa[i][y]) x=fa[i][x],y=fa[i][y];
		return fa[0][x];
	}
}Lca;

int n,m;
const int N=1e5;
int rt[maxn],ans[maxn];

void dfs(int u,int fa) {
	
	for(auto v:G[u]) if(v!=fa) {
		dfs(v,u);
		uni(rt[u],rt[v],1,N);
	}
	Tp res=query(rt[u],1,N,1,N);
	ans[u]=res.idx;
//	cout<<"dfs:"<<u<<' '<<res.l<<' '<<res.r<<' '<<res.val<<' '<<res.idx<<'\n';
}
void work() {
	in2(n,m);
	For(i,2,n) {
		int x,y;
		in2(x,y);
		G[x].push_back(y);
		G[y].push_back(x);
	}
	Lca.dfs(1,0);
	For(i,1,n) { rt[i]=++iidx;tr[rt[i]].l=1;tr[rt[i]].r=N; }
	For(i,1,m) {
		int x,y,z;
		in3(x,y,z);
		int top=Lca.query(x,y);
//		cout<<x<<' '<<y<<' '<<top<<' '<<Lca.fa[0][top]<<'\n';
		if(Lca.fa[0][top]!=0) add(rt[Lca.fa[0][top]],1,N,z,-1);
		add(rt[top],1,N,z,-1);
		add(rt[x],1,N,z,1);
		add(rt[y],1,N,z,1);
	}
	dfs(1,-1);
	For(i,1,n) cout<<ans[i]<<'\n';
}

signed main() {
//	freopen("data.in","r",stdin);
//	freopen("myans.out","w",stdout);
//	ios::sync_with_stdio(false); 
//	cin.tie(0); cout.tie(0);
	double stt=clock();
	int _=1;
//	_=read();
//	cin>>_;
	For(i,1,_) {
		work();
	}
	cerr<<"\nTotal Time is:"<<(clock()-stt)*1.0/1000<<" second(s)."<<'\n';
	return 0;
}

也可以看看这题

点击查看代码
#include<bits/stdc++.h>

#define pii pair<int,int> 
#define pll pair<long long,long long> 
#define ll long long
#define i128 __int128

#define mem(a,b) memset((a),(b),sizeof(a))
#define m0(a) memset((a),0,sizeof(a))
#define m1(a) memset(a,-1,sizeof(a))
#define lb(x) ((x)&-(x))
#define lc(x) ((x)<<1)
#define rc(x) (((x)<<1)|1)
#define pb(G,x) (G).push_back((x))
#define For(a,b,c) for(int a=(b);a<=(c);a++)
#define Rep(a,b,c) for(int a=(b);a>=(c);a--)
#define in1(a) a=read()
#define in2(a,b) a=read(), b=read()
#define in3(a,b,c) a=read(), b=read(), c=read()
#define in4(a,b,c,d) a=read(), b=read(), c=read(), d=read()
#define fst first 
#define scd second 
#define dbg puts("IAKIOI")

using namespace std;

int read() {
	int x=0,f=1; char c=getchar();
	for(;c<'0'||c>'9';c=getchar()) f=(c=='-'?-1:1); 
	for(;c<='9'&&c>='0';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
	return x*f;
}
void write(int x) { if(x>=10) write(x/10); putchar('0'+x%10); }

const int mod = 998244353;
int qpo(int a,int b) {int res=1; for(;b;b>>=1,a=(a*a)%mod) if(b&1) res=res*a%mod; return res; }
int inv(int a) {return qpo(a,mod-2); }

#define maxn 100050

int n,m;
int p[maxn];
struct Dsu {
	int fa[maxn];
	void pre(int n) {For(i,1,n) fa[i]=i; }
	int fnd(int x) {return x==fa[x]?fa[x]:fa[x]=fnd(fa[x]); }
}x;

struct SegT {
	struct node {
		int sum,idx,l,r;
	}tr[maxn<<6]; int idxcnt;
	int root[maxn];
	void psu(int idx) {
		tr[idx].sum=0;
		if(tr[idx].l) tr[idx].sum=tr[tr[idx].l].sum;
		if(tr[idx].r) tr[idx].sum+=tr[tr[idx].r].sum;
	}
	void modi(int idx,int l,int r,int k,int val,int u) {
		if(l==r) {
			tr[idx].idx=u;
			tr[idx].sum=val;
			return ;
		}
		int mid=l+r>>1;
		if(k<=mid) {
			if(!tr[idx].l) tr[idx].l=++idxcnt;
			modi(tr[idx].l,l,mid,k,val,u);
		} else {
			if(!tr[idx].r) tr[idx].r=++idxcnt;
			modi(tr[idx].r,mid+1,r,k,val,u);
		}
		psu(idx);
	}
	int query(int idx,int l,int r,int k) {
		if(l==r) return (k==1)?tr[idx].idx:-1;
		int mid=l+r>>1;
		if(tr[idx].l&&tr[tr[idx].l].sum>=k) return query(tr[idx].l,l,mid,k);
		if(tr[idx].r) return query(tr[idx].r,mid+1,r,k-(tr[idx].l!=0?tr[tr[idx].l].sum:0));
		return -1;
	}
	void uni(int idx1,int idx2,int l,int r) {//将 idx2 合并到 idx1 中
		if(l==r) return ; //理论上这题应该不会有这种情况
		int mid=l+r>>1;
		
		if(tr[idx1].l==0&&tr[idx2].l) tr[idx1].l=tr[idx2].l;
		else if(tr[idx1].l&&tr[idx2].l) uni(tr[idx1].l,tr[idx2].l,l,mid);
		
		if(tr[idx1].r==0&&tr[idx2].r) tr[idx1].r=tr[idx2].r;
		else if(tr[idx1].r&&tr[idx2].r) uni(tr[idx1].r,tr[idx2].r,mid+1,r);
		
		psu(idx1);
	}
}Tr;

void work() {
	in2(n,m);x.pre(n);
	For(i,1,n) in1(p[i]);
	For(i,1,n) Tr.root[i]=i; Tr.idxcnt=n;
	For(i,1,n) Tr.modi(Tr.root[i],1,n,p[i],1,i);
	For(i,1,m) {
		int u,v; in2(u,v); u=x.fnd(u),v=x.fnd(v);
		if(u==v) continue;
		x.fa[v]=u;
		Tr.uni(Tr.root[u],Tr.root[v],1,n);
	}
	int q=read();
	while(q--) {
		char ch=getchar(); while(ch!='Q'&&ch!='B') ch=getchar();
		int a,b; in2(a,b);
		if(ch=='Q') cout<<Tr.query(Tr.root[x.fnd(a)],1,n,b)<<'\n';
		else {
			a=x.fnd(a),b=x.fnd(b); if(a==b) continue;
			x.fa[b]=a; Tr.uni(Tr.root[a],Tr.root[b],1,n);
		}
	}
}

signed main() {
//	freopen("data.in","r",stdin);
//	freopen("myans.out","w",stdout);
//	ios::sync_with_stdio(false); 
//	cin.tie(0); cout.tie(0);
	double stt=clock();
	int _=1;
//	_=read();
//	cin>>_;
	For(i,1,_) {
		work();
	}
	cerr<<"\nTotal Time is:"<<(clock()-stt)*1.0/1000<<" second(s)."<<'\n';
	return 0;
}
posted @ 2025-07-10 15:24  coding_goat_qwq  阅读(27)  评论(0)    收藏  举报