线段树合并 学习笔记
其实就是把两颗线段树合到一起。
比如这题:P4556 [Vani有约会] 雨天的尾巴 /【模板】线段树合并。发现只会在最后查询,所以可以先考虑树上差分。给每种食物建一个桶,最后从下到上加起来就好了。但是这样还是太慢,用线段树的话可以 \(O(\log n)\) 查询最小值,但是要怎么合并信息呢?
不妨按照树上差分的思路,从下到上合并线段树!
我们递归处理:
- 如果两个节点都为叶子节点,那么可以直接合并。如本题就是相加。
- 如果两个节点一个有一个没有,那么可以直接沿用有的那个。
- 如果两个节点都是空的,不用管它。
然后就是线段树上查找最小值节点编号了。
注意事项:
本题由于树上差分会有小于 \(0\) 的节点出现,此时权值为 \(0\) 的节点会被判为最小。要注意到应当在最后记录答案的时候判断最小值是否为 \(0\)。
代码是自己没看过板子写的,很丑。
点击查看代码
#include<bits/stdc++.h>
#define pii pair<int,int>
#define pll pair<long long,long long>
#define ll long long
#define i128 __int128
#define mem(a,b) memset((a),(b),sizeof(a))
#define m0(a) memset((a),0,sizeof(a))
#define m1(a) memset(a,-1,sizeof(a))
#define lb(x) ((x)&-(x))
#define lc(x) ((x)<<1)
#define rc(x) (((x)<<1)|1)
#define pb(G,x) (G).push_back((x))
#define For(a,b,c) for(int a=(b);a<=(c);a++)
#define Rep(a,b,c) for(int a=(b);a>=(c);a--)
#define in1(a) a=read()
#define in2(a,b) a=read(), b=read()
#define in3(a,b,c) a=read(), b=read(), c=read()
#define in4(a,b,c,d) a=read(), b=read(), c=read(), d=read()
#define fst first
#define scd second
#define dbg puts("IAKIOI")
using namespace std;
int read() {
int x=0,f=1; char c=getchar();
for(;c<'0'||c>'9';c=getchar()) f=(c=='-'?-1:1);
for(;c<='9'&&c>='0';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
return x*f;
}
void write(int x) { if(x>=10) write(x/10); putchar('0'+x%10); }
const int mod = 998244353;
int qpo(int a,int b) {int res=1; for(;b;b>>=1,a=(a*a)%mod) if(b&1) res=res*a%mod; return res; }
int inv(int a) {return qpo(a,mod-2); }
#define maxn 100050
int iidx;
struct Tp {
int l,r,ls,rs,val,idx;
}tr[maxn<<6];
void psu(int idx) {
if((tr[idx].ls==0&&tr[idx].rs==0)||max(tr[tr[idx].ls].val,tr[tr[idx].rs].val)==0) return ;
if(tr[idx].rs==0) tr[idx].val=tr[tr[idx].ls].val,tr[idx].idx=tr[tr[idx].ls].idx;
else if(tr[idx].ls==0) tr[idx].val=tr[tr[idx].rs].val,tr[idx].idx=tr[tr[idx].rs].idx;
else {
if(tr[tr[idx].ls].val>=tr[tr[idx].rs].val)
tr[idx].val=tr[tr[idx].ls].val,tr[idx].idx=tr[tr[idx].ls].idx;
else
tr[idx].val=tr[tr[idx].rs].val,tr[idx].idx=tr[tr[idx].rs].idx;
}
}
void add(int idx,int l,int r,int k,int val) {
// cout<<"Add:"<<idx<<' '<<l<<' '<<r<<' '<<k<<' '<<val<<'\n';
if(l==r) {
tr[idx].idx=l;
tr[idx].val+=val;
return ;
}
int mid=l+r>>1;
if(k<=mid) {
if(tr[idx].ls==0) tr[idx].ls=++iidx,tr[tr[idx].ls].l=l,tr[tr[idx].ls].r=mid;
add(tr[idx].ls,l,mid,k,val);
} else {
if(tr[idx].rs==0) tr[idx].rs=++iidx,tr[tr[idx].rs].l=mid+1,tr[tr[idx].rs].r=r;
add(tr[idx].rs,mid+1,r,k,val);
}
psu(idx);
}
Tp query(int idx,int l,int r,int L,int R) {
if(L<=l&&r<=R) return tr[idx];
int mid=l+r>>1;
Tp ans={0,0,0,0,0};
if(L<=mid) {
if(tr[idx].ls!=0)
ans=query(tr[idx].ls,l,mid,L,R);
}
if(R>mid) {
if(tr[idx].rs!=0) {
Tp res=query(tr[idx].rs,mid+1,r,L,R);
if(res.val>ans.val) ans=res;
}
}
return ans;
}
void uni(int idx1,int idx2,int l,int r) { //将以 idx2 为根的子树合并到以 idx1 为根的子树里面去
// cout<<"Union:"<<idx1<<' '<<idx2<<' '<<l<<' '<<r<<'\n';
if(l==r) {
tr[idx1].val+=tr[idx2].val;
return ;
}
int mid=l+r>>1;
if(tr[idx1].ls==0&&tr[idx2].ls!=0) { tr[idx1].ls=tr[idx2].ls; }
else if(tr[idx1].ls!=0&&tr[idx2].ls!=0) uni(tr[idx1].ls,tr[idx2].ls,l,mid);
if(tr[idx1].rs==0&&tr[idx2].rs!=0) { tr[idx1].rs=tr[idx2].rs; }
else if(tr[idx1].rs!=0&&tr[idx2].rs!=0) uni(tr[idx1].rs,tr[idx2].rs,mid+1,r);
psu(idx1);
}
vector<int> G[maxn];
struct LCA {
int dep[maxn],fa[26][maxn];
void dfs(int u,int fath) {
fa[0][u]=fath;
dep[u]=dep[fath]+1;
int sz=log2(dep[u]);
For(i,1,sz) fa[i][u]=fa[i-1][fa[i-1][u]];
for(auto v:G[u]) if(v!=fath) dfs(v,u);
}
int query(int x,int y) {
if(dep[x]<dep[y]) swap(x,y);
while(dep[x]>dep[y]) {
x=fa[(int)(log2(dep[x]-dep[y]))][x];
}
if(x==y) return x;
int sz=log2(dep[x]);
Rep(i,sz,0) if(fa[i][x]!=fa[i][y]) x=fa[i][x],y=fa[i][y];
return fa[0][x];
}
}Lca;
int n,m;
const int N=1e5;
int rt[maxn],ans[maxn];
void dfs(int u,int fa) {
for(auto v:G[u]) if(v!=fa) {
dfs(v,u);
uni(rt[u],rt[v],1,N);
}
Tp res=query(rt[u],1,N,1,N);
ans[u]=res.idx;
// cout<<"dfs:"<<u<<' '<<res.l<<' '<<res.r<<' '<<res.val<<' '<<res.idx<<'\n';
}
void work() {
in2(n,m);
For(i,2,n) {
int x,y;
in2(x,y);
G[x].push_back(y);
G[y].push_back(x);
}
Lca.dfs(1,0);
For(i,1,n) { rt[i]=++iidx;tr[rt[i]].l=1;tr[rt[i]].r=N; }
For(i,1,m) {
int x,y,z;
in3(x,y,z);
int top=Lca.query(x,y);
// cout<<x<<' '<<y<<' '<<top<<' '<<Lca.fa[0][top]<<'\n';
if(Lca.fa[0][top]!=0) add(rt[Lca.fa[0][top]],1,N,z,-1);
add(rt[top],1,N,z,-1);
add(rt[x],1,N,z,1);
add(rt[y],1,N,z,1);
}
dfs(1,-1);
For(i,1,n) cout<<ans[i]<<'\n';
}
signed main() {
// freopen("data.in","r",stdin);
// freopen("myans.out","w",stdout);
// ios::sync_with_stdio(false);
// cin.tie(0); cout.tie(0);
double stt=clock();
int _=1;
// _=read();
// cin>>_;
For(i,1,_) {
work();
}
cerr<<"\nTotal Time is:"<<(clock()-stt)*1.0/1000<<" second(s)."<<'\n';
return 0;
}
也可以看看这题。
点击查看代码
#include<bits/stdc++.h>
#define pii pair<int,int>
#define pll pair<long long,long long>
#define ll long long
#define i128 __int128
#define mem(a,b) memset((a),(b),sizeof(a))
#define m0(a) memset((a),0,sizeof(a))
#define m1(a) memset(a,-1,sizeof(a))
#define lb(x) ((x)&-(x))
#define lc(x) ((x)<<1)
#define rc(x) (((x)<<1)|1)
#define pb(G,x) (G).push_back((x))
#define For(a,b,c) for(int a=(b);a<=(c);a++)
#define Rep(a,b,c) for(int a=(b);a>=(c);a--)
#define in1(a) a=read()
#define in2(a,b) a=read(), b=read()
#define in3(a,b,c) a=read(), b=read(), c=read()
#define in4(a,b,c,d) a=read(), b=read(), c=read(), d=read()
#define fst first
#define scd second
#define dbg puts("IAKIOI")
using namespace std;
int read() {
int x=0,f=1; char c=getchar();
for(;c<'0'||c>'9';c=getchar()) f=(c=='-'?-1:1);
for(;c<='9'&&c>='0';c=getchar()) x=(x<<1)+(x<<3)+(c^48);
return x*f;
}
void write(int x) { if(x>=10) write(x/10); putchar('0'+x%10); }
const int mod = 998244353;
int qpo(int a,int b) {int res=1; for(;b;b>>=1,a=(a*a)%mod) if(b&1) res=res*a%mod; return res; }
int inv(int a) {return qpo(a,mod-2); }
#define maxn 100050
int n,m;
int p[maxn];
struct Dsu {
int fa[maxn];
void pre(int n) {For(i,1,n) fa[i]=i; }
int fnd(int x) {return x==fa[x]?fa[x]:fa[x]=fnd(fa[x]); }
}x;
struct SegT {
struct node {
int sum,idx,l,r;
}tr[maxn<<6]; int idxcnt;
int root[maxn];
void psu(int idx) {
tr[idx].sum=0;
if(tr[idx].l) tr[idx].sum=tr[tr[idx].l].sum;
if(tr[idx].r) tr[idx].sum+=tr[tr[idx].r].sum;
}
void modi(int idx,int l,int r,int k,int val,int u) {
if(l==r) {
tr[idx].idx=u;
tr[idx].sum=val;
return ;
}
int mid=l+r>>1;
if(k<=mid) {
if(!tr[idx].l) tr[idx].l=++idxcnt;
modi(tr[idx].l,l,mid,k,val,u);
} else {
if(!tr[idx].r) tr[idx].r=++idxcnt;
modi(tr[idx].r,mid+1,r,k,val,u);
}
psu(idx);
}
int query(int idx,int l,int r,int k) {
if(l==r) return (k==1)?tr[idx].idx:-1;
int mid=l+r>>1;
if(tr[idx].l&&tr[tr[idx].l].sum>=k) return query(tr[idx].l,l,mid,k);
if(tr[idx].r) return query(tr[idx].r,mid+1,r,k-(tr[idx].l!=0?tr[tr[idx].l].sum:0));
return -1;
}
void uni(int idx1,int idx2,int l,int r) {//将 idx2 合并到 idx1 中
if(l==r) return ; //理论上这题应该不会有这种情况
int mid=l+r>>1;
if(tr[idx1].l==0&&tr[idx2].l) tr[idx1].l=tr[idx2].l;
else if(tr[idx1].l&&tr[idx2].l) uni(tr[idx1].l,tr[idx2].l,l,mid);
if(tr[idx1].r==0&&tr[idx2].r) tr[idx1].r=tr[idx2].r;
else if(tr[idx1].r&&tr[idx2].r) uni(tr[idx1].r,tr[idx2].r,mid+1,r);
psu(idx1);
}
}Tr;
void work() {
in2(n,m);x.pre(n);
For(i,1,n) in1(p[i]);
For(i,1,n) Tr.root[i]=i; Tr.idxcnt=n;
For(i,1,n) Tr.modi(Tr.root[i],1,n,p[i],1,i);
For(i,1,m) {
int u,v; in2(u,v); u=x.fnd(u),v=x.fnd(v);
if(u==v) continue;
x.fa[v]=u;
Tr.uni(Tr.root[u],Tr.root[v],1,n);
}
int q=read();
while(q--) {
char ch=getchar(); while(ch!='Q'&&ch!='B') ch=getchar();
int a,b; in2(a,b);
if(ch=='Q') cout<<Tr.query(Tr.root[x.fnd(a)],1,n,b)<<'\n';
else {
a=x.fnd(a),b=x.fnd(b); if(a==b) continue;
x.fa[b]=a; Tr.uni(Tr.root[a],Tr.root[b],1,n);
}
}
}
signed main() {
// freopen("data.in","r",stdin);
// freopen("myans.out","w",stdout);
// ios::sync_with_stdio(false);
// cin.tie(0); cout.tie(0);
double stt=clock();
int _=1;
// _=read();
// cin>>_;
For(i,1,_) {
work();
}
cerr<<"\nTotal Time is:"<<(clock()-stt)*1.0/1000<<" second(s)."<<'\n';
return 0;
}
本文来自博客园,作者:coding_goat_qwq,转载请注明原文链接:https://www.cnblogs.com/CodingGoat/p/18976968

浙公网安备 33010602011771号