线段树合并与分裂

前言

前置芝士:
动态开点线段树
权值线段树
比肩 pushdown 的灵活。

原理

如何高效的存储多棵线段树

既然要把线段树合并,那么程序里肯定就不止一棵线段树了。这时要考虑怎么把这么多线段树高效的储存起来。

传统线段树里起手一个 \(N\times 4\) 大小的数组显然不行了,我们要用动态开点线段树。

我们知道动态开点线段树给每个区间代表的点赋予了一个唯一的编号,编号与线段树代表的区间之间完全没有关系。既然如此,一整棵线段树的编号可以不连续,这也就是说,我们可以把一棵线段树塞到一个数组里了。

这种存储方法有时称作 内存池法

要注意,当你要写线段树合并分裂的时候,万万不可使用 vector 这种动态分配内存的数据结构,原因在动态开点线段树中有解释。

我们另外定义一个数组 root 存储每一棵线段树根节点的编号,这样就能实现快速锁定需要的一棵线段树了。

当然这个数组就可以用 vector 啦。

下面是简短的示例:

t[0]=seg{0,0,0};//哨兵
t[1]=seg{0,0,0};//一号线段树的根节点
cnt=1;//计数器一定一定要变成 1!!!
root.push_back(0);
root.push_back(1);

合并

这一块如果你知道 Fhq-Treap 就会感觉很熟悉。

定义函数 merge (x,y) 表示将编号为 \(y\) 的线段树合并到 \(x\) 上,然后有三种情况:

  • 同一个区间,\(x,y\) 都有节点,这时候先把两个节点的值合并,然后递归处理左右儿子。处理完成后,删除 \(y\) 对应的节点。
  • 同一个区间,\(x\) 有但 \(y\) 没有对应节点,这时不用操作。
  • 同一个区间,\(x\) 没有但 \(y\) 有对应节点,这时候把 \(y\) 的节点给 \(x\)

当然,这里是否要进行删除操作随题目而定,如果不进行删除,则 \(x\)\(y\) 合并的节点却可以通过 \(y\) 访问到,可能会对 \(y\) 这个节点的信息造成影响,如果之后还会用到 \(y\) 的情况下。但是如果之后不再用 \(y\) 直接删了就可以。

下面直接给出合并和删除的代码:

点击查看代码
int malloc(){
  if(recyc.size()!=0){
    int p=recyc.back();
    recyc.pop_back();
    return p;
  }
  else{
    cnt++;
    return cnt;
  }
}
void del(int &p){
  recyc.push_back(p);
  t[p]=DEF;
  p=0;
  return ;
}
void merge(int &x,int y){
  if(x==0||y==0){
    x=x|y;
    return ;
  }
  else{
    t[x].v+=t[y].v;
    merge(t[x].l,t[y].l);
    merge(t[x].r,t[y].r);
    del(y);
  }
  return ;
}
删除操作的目的就是为了重新利用定长数组中的内存碎片。

merge 里用了一个引用传参来方便的更新关于 \(x\) 的值,不算难理解。

当然啦如果你想的话 merge 也可以这么写:

void merge(int &x,int &y,int l,int r){
  if(x==0||y==0){
    x=x|y;
    return ;
  }
  else if(l==r){// 不要少了这个!
    t[x].v+=t[y].v;
    del(y);
    return ;
  }
  else{
    int mid=l+r>>1;
    merge(t[x].l,t[y].l,l,mid);
    merge(t[x].r,t[y].r,mid+1,r);
    t[x].v=t[t[x].l].v+t[t[x].r].v;
    del(y);
  }
  return ;
}

分裂

把一棵线段树中的某一段连续区间分出去组成一个新的线段树。

定义函数 split (p,l,r,rl,rr) 表示要把 \(p\) 代表的 \([l,r]\) 区间内的 \([rl,rr]\) 区间分出去。并且返回分出去组成的线段树的根节点的编号。

首先如果 \([l,r] \subseteq [rl,rr]\) 显然这个区间整个都是要分出去的然后分开就行了。

如果不完全属于的话就分两半,每半向下递归。

然后因为我们不管是合并还是分裂的线段树其结构(即满线段树的样子)都是相同的,所以递归时可以直接接受返回值作为分裂出来线段树的左右孩子。

这个区间处理完成后更新一下就行了。

代码

int split(int &p,int l,int r,int rl,int rr){
  cnt++;
  int c=cnt;
  if(rl<=l&&r<=rr){
    t[c]=t[p];
    p=0;
  }
  else{
    int mid=l+r>>1;
    if(rl<=mid){
      t[c].l=split(t[p].l,l,mid,rl,rr);
    }
    if(mid<rr){
      t[c].r=split(t[p].r,mid+1,r,rl,rr);
    }
    t[c].v=t[t[c].l].v+t[t[c].r].v;
    t[p].v=t[t[p].l].v+t[t[p].r].v;
  }
  return c;
}

还是用了一下引用传参。

原理很简单,接下来看看例题吧!

例题

luogu P5494 【模板】线段树分裂

模板题嘛,没什么好说的。

注意使用权值线段树可以方便的把对值操作转化为对区间操作。

下面是代码了,可以通过右侧导航栏跳转。

代码时间

#define psb push_back
#define mkp make_pair
#define rep(i,a,b) for( int i=(a); i<=(b); ++i)
#define per(i,a,b) for( int i=(a); i>=(b); --i)
#define rd read()
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define DEF seg{0,0,0}
#define int ll
ll read(){
  ll x=0,f=1;
  char c=getchar();
  while(c>'9'||c<'0'){if(c=='-') f=-1;c=getchar();}
  while(c>='0'&&c<='9'){x=(x<<3)+(x<<1)+(c^48);c=getchar();}
  return x*f;
}
struct seg{
  int l,r;
  ll v;
};
seg t[4000005];
vector<int> recyc;
int cnt=0;
vector<int> root;
int malloc(){
  if(recyc.size()!=0){
    int p=recyc.back();
    recyc.pop_back();
    return p;
  }
  else{
    cnt++;
    return cnt;
  }
}
void del(int &p){
  recyc.push_back(p);
  t[p]=DEF;
  p=0;
  return ;
}
void ins(int p,int l,int r,int rl,int rr,ll k){
  if(rl<=l&&r<=rr){
    t[p].v+=k;
    return ;
  }
  int mid=l+r>>1;
  if(rl<=mid){
    if(t[p].l==0)t[p].l=malloc();
    ins(t[p].l,l,mid,rl,rr,k);
  }
  if(mid<rr){
    if(t[p].r==0)t[p].r=malloc();
    ins(t[p].r,mid+1,r,rl,rr,k);
  }

  t[p].v=t[t[p].l].v+t[t[p].r].v;
  return ;
}
int split(int &p,int l,int r,int rl,int rr){
  cnt++;
  int c=cnt;
  if(rl<=l&&r<=rr){t[c]=t[p];del(p);p=0;}
  else{
    int mid=l+r>>1;
    if(rl<=mid) t[c].l=split(t[p].l,l,mid,rl,rr);
    if(mid<rr)  t[c].r=split(t[p].r,mid+1,r,rl,rr);
    t[c].v=t[t[c].l].v+t[t[c].r].v;
    t[p].v=t[t[p].l].v+t[t[p].r].v;
  }
  return c;
}
void merge(int &x,int &y,int l,int r){
  if(x==0||y==0){x=x|y;return ;}
  else if(l==r){
    t[x].v+=t[y].v;
    del(y);
    return ;
  }
  else{
    int mid=l+r>>1;
    merge(t[x].l,t[y].l,l,mid);
    merge(t[x].r,t[y].r,mid+1,r);
    t[x].v=t[t[x].l].v+t[t[x].r].v;
    del(y);
  }
  return ;
}
ll query(int p,int l,int r,int rl,int rr){
  if(rl<=l&&r<=rr)return t[p].v;
  int mid=l+r>>1;
  ll sub=0;
  if(rl<=mid) sub+=query(t[p].l,l,mid,rl,rr);
  if(mid<rr)  sub+=query(t[p].r,mid+1,r,rl,rr);
  return sub;
}
ll kmin(int p,int l,int r,ll k){
  if(l==r)return l;
  int mid=l+r>>1;
  if(t[t[p].l].v>=k) return kmin(t[p].l,l,mid,k);
  else return kmin(t[p].r,mid+1,r,k-t[t[p].l].v);
}
signed main(){

  int n,m;
  cin>>n>>m;
  t[0]=seg{0,0,0};
  t[1]=seg{0,0,0};
  cnt=1;
  root.push_back(0);
  root.push_back(1);
  rep(i,1,n){
    ll cnt;
    cin>>cnt;
    if(cnt==0)continue;
    else ins(root[1],1,n,i,i,cnt);
  }
  rep(i,1,m){
    int op;
    cin>>op;
    if(op==0){
      ll p,x,y;
      cin>>p>>x>>y;
      root.push_back(split(root[p],1,n,x,y));
    }
    else if(op==1){
      ll p,t;
      cin>>p>>t;
      merge(root[p],root[t],1,n);
    }
    else if(op==2){
      ll p,x,q;
      cin>>p>>x>>q;
      ins(root[p],1,n,q,q,x);
    }
    else if(op==3){
      ll p,l,r;
      cin>>p>>l>>r;
      cout<<query(root[p],1,n,l,r)<<'\n';
    }
    else if(op==4){
      ll p,k;
      cin>>p>>k;
      if(t[root[p]].v<k)cout<<-1<<'\n';
      else cout<<kmin(root[p],1,n,k)<<'\n';
    }
  }
  
  return 0;
}

没做多少呢,过几天补上(逃

posted @ 2025-02-21 20:52  hm2ns  阅读(37)  评论(0)    收藏  举报