可持久化线段树/并查集
引入
在有些题目中,在进行一系列更改以后,我们需要访问之前某次更改的版本。对于传统的线段树来说,这自然不易达成。然而我们可以记录一下每一次修改的版本。每改一次直接开一棵新的线段树即可。下面通过一道例题,阐述一下可持久化线段树的思想与实现。
SPOJ TTM - To the moon
题意翻译
一个长度为 \(N\) 的数组 \(\{A\}\),\(4\) 种操作 :
\(C\) \(l\) \(r\) \(d\):区间 \([l,r]\) 中的数都加 \(d\) ,同时当前的时间戳加 \(1\)。
\(Q\) \(l\) \(r\):查询当前时间戳区间 \([l,r]\) 中所有数的和 。
\(H\) \(l\) \(r\) \(t\):查询时间戳 \(t\) 区间 \([l,r]\) 的和 。
\(B\) \(t\):将当前时间戳置为 \(t\) 。
所有操作均合法 。
\(ps\):刚开始时时间戳为 \(0\)
输入格式,一行 \(N\) 和 \(M\),接下来 \(M\) 行每行一个操作
输出格式:对每个查询输出一行表示答案
数据保证:\(1\le N,M\le 10^5\),\(|A_i|\le 10^9\),\(1\le l \le r \le N\),\(|d|\le10^4\)。在刚开始没有进行操作的情况下时间戳为 \(0\),且保证 \(B\) 操作不会访问到未来的时间戳。
思路:
如果没有 \(H\) 和 \(B\) 操作的话,这一道题用普通的线段树就能很好地解决。现在,我们需要记录每一个历史版本,且要求能够对历史版本进行修改,即完全可持久化。
如何记录每一个版本?首先我们要建一棵动态开点的线段树,目的是为了节省空间和便于修改。假设我们现在有一个长度为 \(4\) 的数组,每个节点记录了该区间的区间和,如下图:

其中,一号节点对应区间 \([1,6]\),二号节点对应区间 \([1,3]\),三号节点对应区间 \([4,6]\),以此类推。
假设我们现在要给区间 \([2,3]\) 的所有数加上一个数,由于七号节点对应的位置的数组下标是二,五号节点对应的数组位置下标为三,那么值将要被修改的节点只有 \(1\),\(2\),\(4\),\(5\),\(7\) 这五个节点。我们新建五个点,将 \(1\),\(2\),\(4\),\(5\),\(7\) 几个节点的信息一一复制到新开的五个点中,再在这新建的五个点上修改,新建点后的树如下图:

其中,\(12\) 号节点对应 \(1\) 号节点,\(13\) 号节点对应 \(2\) 号节点。其它节点同理。容易发现,每进行一次区间修改,只会影响到有限个数的点。这样,我们每次修改所新建的节点个数就是 \(log\) 级别的,在接受的范围内。
在新建节点中,我们还在 \(14\) 和 \(6\), \(12\) 和 \(3\) 之间分别连接了一条边。为什么要这么做呢?
其实,我们只复制了需要更改的点,并将他们"替换"到以前的树上,这样,就可以以较小的修改得到完整的树。同时,为了记录每一个版本的树,我们需要一个 \(root\) 数组。\(root[i]\) 表示第 \(i\) 个版本的树的根节点的编号。
在这里还有一个问题:如果这样去写,区间标记的下传会变得很麻烦,这里采用标记永久化即可。
代码如下:
int root[MAXN],tot;//tot 为总节点数
struct node{
int ls,rs,sum,tag;//sum为该区间数字和,tag为在这个区间上的总标记
}tree[MAXN + 5];
vector<int> lsh;
int add(int i,int tl,int tr,int l,int r,int val){//返回值为进行这次操作后新增节点的编号
tree[++tot] = tree[i];//复制节点
int ci = tot;//最新节点编号
tree[ci].sum += (min(r,tr) - max(tl,l) + 1) * val;//修改值
if(tl >= l && tr <= r){
tree[ci].tag += val;
return ci;
}
int mid = (tl + tr) / 2;
if(tl <= mid){
tree[ci].ls = add(tree[i].ls,tl,mid,l,r,val);
}
if(r > mid){
tree[ci].rs = add(tree[i].rs,mid + 1,tr,l,r,val);
}
return ci;
}
查询代码:
int query(int i,int tl,int tr,int l,int r){
int ans = (min(r,tr) - max(l,tl) + 1) * tree[i].tag;
if(tl >=l && tr <= r)return tree[i].sum;
int mid = (tl + tr) / 2;
if(tl <= mid)ans += query(tree[i].ls,tl,mid,l,r);
if(tr > mid)ans += query(tree[i].rs,mid + 1,tr,l,r);
return ans;
}
并查集的结构是一个森林,因此只需建 \(n\) 棵树,类比上述代码即可。
部分可持久化:
完全可持久化要求对历史版本能进行操作和查值。它的空间复杂度一般比较大。在一些题目中,我们只需要查找值,不需要对过去版本进行修改,那么,这种只查找,在最新版本上进行修改的操作就叫部分可持久化。
部分可持久化其实很好实现。假设你有一个数组,你只需要给每一个位置开一个 \(vector<pair<int,int>>\),其中第一关键字记录其版本号,第二关键字记录其值,每修改一次就 \(push\) 一下,查找的时候二分一个版本号即可,这样就能实现部分可持久化。
第 \(k\) 大问题
第 \(k\) 大问题,即要求一个数列在区间 \([l,r]\) 中第 \(k\) 大的数是多少,且一般带有修改操作,强制在线。用可持久化线段树能较好地解决这些问题。
先看看不带修改操作的。
从左至右,每插入一个数,就构成了一个新的版本。我们建立 \(n\) 棵值域线段树,第 \(i\) 棵线段树表示从左至右插入了 \(i\) 个数的情况。那么,对于区间 \([l,r]\) 的情况,只需看看从第 \(l-1\) 棵线段树到第 \(r\) 棵线段树增加了多少元素,再在增加的这一部分元素形成的线段树上进行二分,求得第 \(k\) 大。当然实际操作中我们并不需要真的把那棵线段树求出来,只需要在两个版本的值域线段树上同时二分即可。
实现如下:
#define mid (tl + tr >> 1)
int root[MAXN + 5],n,a[MAXN + 5],tot,m;
struct node{
int sum,ls,rs;//sum表示该节点下一共有多少个数
}tree[MAXN + 5];
int add(int i,int tl,int tr,int pos){
int ci = ++tot;
tree[ci] = tree[i];
if(tl == tr){
tree[ci].sum ++;
return ci;
}
if(pos <= mid){
tree[ci].ls = add(tree[i].ls,tl,mid,pos);
}
else{
tree[ci].rs = add(tree[i].rs,mid + 1,tr,pos);
}
tree[ci].sum = tree[tree[ci].ls].sum + tree[tree[ci].rs].sum;
return ci;
}
int main(){//假设所有数字被离散化,值域为[1,m]
cin >> n;
for(int i = 1; i <= n; i++){
cin >> a[i];
add(root[i - 1],1,m,a[i]);
}
}
int query(int x,int y,int tl,int tr,int k){//查询操作
if(tl == tr)return tr;
int lsum = tree[tree[x].ls].sum;
int rsum = tree[tree[y].ls].sum;
if(rsum - lsum < k){
return query(x,y,mid + 1,tr,k - rsum + lsum);
}
else return query(x,y,tl,mid,k);
}
如果需要处理带修改的区间第 \(k\) 大,考虑到如果改变 \(i\) 位置的数,那么第 \(i\) 个版本及以后的线段树必然会需要修改,这也是一个区间操作,所以一般采用树状数组套在线段树外面。但其修改,查询思想还是类似的。
代码:
#include<bits/stdc++.h>
#define mid (tl + tr >> 1)
using namespace std;
const int MAXN = 3e6;
int n,m,a[MAXN + 5],l,r,k,tot,root[MAXN + 5],lseg[MAXN + 5],rseg[MAXN + 5];;
string s;
vector<int> lsh;
struct node{
int ls,rs,sum;
}tree[4 * MAXN + 5];
struct Mes{
bool flag;
int l,r,k;
}mes[MAXN + 5];
void insert(int i,int tl,int tr,int val,int ad){
if(tl == tr){
tree[i].sum += ad;
return;
}
if(mid >= val){
if(tree[i].ls == 0)tree[i].ls = ++tot;
insert(tree[i].ls,tl,mid,val,ad);
}
else{
if(tree[i].rs == 0)tree[i].rs = ++tot;
insert(tree[i].rs,mid + 1,tr,val,ad);
}
tree[i].sum = tree[tree[i].ls].sum + tree[tree[i].rs].sum;
}
int lowbit(int i){
return i & (-i);
}
void add(int pos,int num){
for(int i = pos; i <= n; i += lowbit(i)){
if(root[i] == 0)root[i] = ++tot;
insert(root[i],1,lsh.size(),num,1);
}
}
void change(int pos,int num){
for(int i = pos; i <= n; i += lowbit(i)){
insert(root[i],1,lsh.size(),a[pos],-1);
}
for(int i = pos; i <= n; i += lowbit(i)){
insert(root[i],1,lsh.size(),num,1);
}
a[pos] = num;
}
int query(int tl,int tr,int k)
{
if(tl==tr)return tl;
int sum=0;
for(int i=1;i<=rseg[0];++i){
int now = rseg[i];
sum += tree[tree[now].ls].sum;
}
for(int i=1;i<=lseg[0];++i){
int now = lseg[i];
sum -= tree[tree[now].ls].sum;
}
if(k<=sum)
{
for(int i=1;i<=rseg[0];++i)
rseg[i]=tree[rseg[i]].ls;
for(int i=1;i<=lseg[0];++i)
lseg[i]=tree[lseg[i]].ls;
return query(tl,mid,k);
}
else
{
for(int i=1;i<=rseg[0];++i)
rseg[i]=tree[rseg[i]].rs;
for(int i=1;i<=lseg[0];++i)
lseg[i]=tree[lseg[i]].rs;
return query(mid + 1,tr,k-sum);
}
}
int get_ans(int l,int r,int k)
{
lseg[0]=rseg[0]=0;
l--;
while(l)//分别记录组成两棵树的树根编号
{
lseg[++lseg[0]]=root[l];
l-=lowbit(l);
}
while(r)
{
rseg[++rseg[0]]=root[r];
r-=lowbit(r);
}
return query(1,lsh.size(),k);
}
signed main(){
scanf("%d%d",&n,&m);
for(int i = 1; i <= n; i++){//n个数
scanf("%d",&a[i]);
lsh.push_back(a[i]);
}
for(int i = 1; i <= m; i++){//m次操作,将他们记录下来
cin >> s;
if(s[0] == 'Q'){//询问操作
scanf("%d%d%d",&l,&r,&k);
mes[i].flag = 1;
mes[i].l = l,mes[i].r = r,mes[i].k = k;
}
else {//修改操作
scanf("%d%d",&l,&k);
mes[i].l = l,mes[i].k = k;
lsh.push_back(k);
}
}
sort(lsh.begin(),lsh.end());
lsh.erase(unique(lsh.begin(),lsh.end()),lsh.end());
for(int i = 1; i <= n; i++){
a[i] = lower_bound(lsh.begin(),lsh.end(),a[i]) - lsh.begin() + 1;
}
for(int i = 1; i <= n; i++)add(i,a[i]);
for(int i = 1; i <= m; i++){
if(mes[i].flag){
int ans = get_ans(mes[i].l,mes[i].r,mes[i].k);
printf("%d\n",lsh[ans - 1]);
}
else{
int k = lower_bound(lsh.begin(),lsh.end(),mes[i].k) - lsh.begin() + 1;
change(mes[i].l,k);//利用树状数组完成修改
}
}
}
可持久化线段树的综合应用
T1(bzoj4504):
兔子们在玩 \(k\) 个串的游戏。首先,它们拿出了一个长度为 \(n\) 的数字序列,选出其中的一个连续子串,然后统计其子串中所有数字之和(注意这里重复出现的数字只被统计一次)。兔子们想知道,在这个数字序列所有连续的子串中,按照以上方式统计其所有数字之和,第 \(k\) 大的和是多少。
关于可持久化线段树的题,我想最关键的的就是寻找问题中的 阶段性条件。即可以从之前的状态一步步推到后面的状态的条件,这一点与动态规划有一定相似之处。在这个题中,我们假设一个数组 \(sumax[i][l][r]\) 表示左端点在区间 \([l,r]\) 中,右端点在 \(i\) 这个位置上时区间的最大和(我们在具体实现的时候采用主席树实现这个功能)。容易发现,由 \(sumax[i][l][r]\) 能够推出 \(sum[i + 1][l][r]\),只需要加上 \(a[i + 1]\) 这个位置的数即可。通过枚举右端点,就可以如此一层层地推完整个 \(sumax\)。
具体实现来说,我们按照右端点的位置为线段树编号。当右端点位置在 \(i\) 时,它对应的版本就是 \(i\)。\(root[i]\) 表示第 \(i\) 个版本的线段树的根节点编号。
同时,这个题还要求的是对区间数去重后求和,对此,还需要引入一个 \(pre\) 数组记录数 \(x\) 在 \(i\) 位置前最后的出现位置。具体使用参见代码中的注释。主席树中这种pre数组的思想也是很常见的,具体还有以下题目:MEX,Boring Queries。
为了求第 \(k\) 大的和,我们需要优先队列,并引入一个五元组 \(v,x,l,r,p\),\(v\) 表示区间和,\(l,r\) 表示左端点范围, \(p\) 表示右端点的位置,\(x\) 表示对应的根节点编号。优先队列每次弹出 \(v\) 最大的一个五元组,对于这个五元组,\(v\) 的次大值只会存在 \(v1,x,l,p - 1,p1\),\(v2,x,p + 1,r,p2\)中,再将它们扔进去即可。反复 \(k\) 次,就可以求出第 \(k\) 大值。
具体实现如下:
#include<bits/stdc++.h>
#define int long long
#define mid (tl + tr >> 1)
using namespace std;
const int MAXN = 3e6;
int tot,n,k,a[MAXN + 5],root[MAXN + 5];
map<int,int> pre;//记录 i 前最后出现位置
struct no{//五元组
long long v;
int x,l,r,p;
bool operator<(const no a)const{
return this->v < a.v;;
}
no(){}
no(int a,int b,int c,int d,int e){
v =a;
x = b;
l = c;
r = d;
p = e;
}
}tmp;
struct node{
int ls,rs,tag;
pair<int,int> v;//v.first表示在这个节点里对应的区间和最大值,v.second表示当区间和取最大值时的左端点下标
}tree[4 * MAXN + 5];
priority_queue<no> q;
int add(int y,int p){//给节点y内的区间和值加上一个p
int x = ++tot;
tree[x] = tree[y];
tree[x].v.first += p;
tree[x].tag += p;
return x;
}
void push_down(int i){
tree[i].ls = add(tree[i].ls,tree[i].tag);
tree[i].rs = add(tree[i].rs,tree[i].tag);
tree[i].tag = 0;
}
int build(int tl,int tr){
int x = ++tot;
tree[x].v = make_pair(0,tl);
if(tl == tr)return x;
tree[x].ls = build(tl,mid);
tree[x].rs = build(mid + 1,tr);
return x;
}
int modify(int i,int tl,int tr,int l,int r,int p){
if(tl >= l && tr <= r)return add(i,p);
if(tree[i].tag)push_down(i);
int x = ++tot;
tree[x] = tree[i];
if(l <= mid)tree[x].ls = modify(tree[i].ls,tl,mid,l,r,p);
if(r > mid)tree[x].rs = modify(tree[i].rs,mid + 1,tr,l,r,p);
tree[x].v = max(tree[tree[x].ls].v,tree[tree[x].rs].v);
return x;
}
pair<int,int> query(int i,int tl,int tr,int l,int r){//右端点在()位置,左端点在区间[l,r]
if(tl == l && tr == r)return tree[i].v;
if(tree[i].tag)push_down(i);//标记处理
if(r <= mid)return query(tree[i].ls,tl,mid,l,r);
else if(l > mid)return query(tree[i].rs,mid + 1,tr,l,r);
return max(query(tree[i].ls,tl,mid,l,mid),query(tree[i].rs,mid +1 ,tr,mid + 1,r));
}
void extend(int i,int l,int r){//扩展一个五元组丢到优先队列里
if(l > r)return;
pair<int,int> t = query(i,1,n,l,r);
//cout << t.first << " " << t.second << "\n";
q.push(no(t.first,i,l,r,t.second));
}
signed main(){
//freopen("1.in","r",stdin);
scanf("%lld%lld",&n,&k);
root[0] = build(1,n);
for(int i = 1; i <= n; i++){
scanf("%lld",&a[i]);
root[i] = modify(root[i - 1],1,n,pre[a[i]] + 1,i,a[i]);//由上一个版本推到下一个版本。因为每个版本都是层层推进的,所以保证了1-pre[a[i]]之间是加上了a[i]的,因此只需要在[pre[a[i]] + 1,i]之间加上a[i]就可以满足去重的要求
pre[a[i]] = i;
extend(root[i],1,i);
}
while(k--){//循环 k 次求第 k 大
tmp = q.top();
q.pop();
extend(tmp.x,tmp.l,tmp.p - 1);
extend(tmp.x,tmp.p + 1,tmp.r);
}
cout << tmp.v << "\n";
}
T2 P2839 [国家集训队]middle
传送门
这里先介绍一个二分求中位数的方法:
对于一个序列 \(a\),我们假设 \(x\) 为这个数列的中位数,另开一个数组 \(b\),如果 \(a[i] >= x\),那么 \(b[i]\) 就赋值为 \(1\),否则赋值为 \(-1\)。之后对整个 \(b\) 数组求和,如果结果大于零,那么说明 \(x\) 小于真正的中位数,如果小于零,那么 \(x\) 大于真正的中位数,当等于零时,就找到了中位数。如此进行二分即可。
基于这个思想来思考这个题。我们同样地引入 \(b\) 数组,二分中位数 \(x\)。假设求出的最优区间的左端点为 \(l\) 在区间 \([a,b]\) 内,右端点 \(r\) 在区间 \([c,d]\) 内。可见 \([l,r]\) 一定会包含区间 \([b + 1,c - 1]\)。当我们对 \(b\) 数组中的 \([l,r]\) 区间求和时,也必将会求和 \([b + 1,c - 1]\) 这一段,因此我们需要用线段树来维护区间和。
另一方面,我们还希望求得尽可能大的中位数。那么也就需要在对 \(b\) 数组求出的区间和尽量大,因此还需要维护 \(b\) 数组的区间最大值。当然,对于一个不同的 \(x\),\(b\) 数组是会不同的。因此我们有不同版本的 \(b\) 数组,这里引入可持久化线段树进行维护。当 \(x\) 转为 \(x + 1\) 时,有使 \(a[i]\) 等于 \(x\) 的 \(i\),这些 \(i\) 所对应的 \(b[i]\) 会变为 \(-1\)。如此挨个修改,得到不同版本的 \(b\) 数组。
每一次得到一个 \(x\),我们就在 \(x\) 对应的 \(b\) 数组版本上求含区间 \([b + 1,c - 1]\) 的最大区间和,如果这个和大于 \(0\),说明 \(x\) 还可以更大,否则就将 \(x\) 变小点,如此不断二分,就能得到满足题意的最大中位数。
代码:
#include<bits/stdc++.h>
#define mid ((tl + tr) >> 1)
using namespace std;
const int MAXN = 5e6;
int tot,n,a[MAXN + 5],root[MAXN + 56],q;
vector<int> lsh;
int que[10],pos[20005][1000];
struct node{
int mx,mn,ls,rs,sum,lsx,rsx,cl,cr;
}tree[4 * MAXN + 5];
void push_up(int i){
tree[i].sum = tree[tree[i].ls].sum + tree[tree[i].rs].sum;
tree[i].lsx = max(tree[tree[i].ls].lsx,tree[tree[i].ls].sum + tree[tree[i].rs].lsx);
tree[i].rsx = max(tree[tree[i].rs].rsx,tree[tree[i].rs].sum + tree[tree[i].ls].rsx);
tree[i].mn = min(tree[tree[i].ls].mn,tree[tree[i].rs].mn);
tree[i].mx = max(tree[tree[i].ls].mx,tree[tree[i].rs].mx);
}
int build(int tl,int tr){
int x = ++tot;
if(tl == tr){
tree[x].sum = tree[x].lsx = tree[x].rsx = 1;
tree[x].mn = tree[x].mx = a[tl];
return x;
}
tree[x].ls = build(tl,mid);
tree[x].rs = build(mid + 1,tr);
push_up(x);
return x;
}
int update(int x,int y,int tl,int tr,int to,int val,bool flag){
if(!flag)x = ++tot;
if(tl == tr){
tree[x].sum = tree[x].lsx = tree[x].rsx = val;
tree[x].mx = -1e9;
tree[x].mn = 1e9;
return x;
}
if(to <= mid){
if(!tree[x].cr)tree[x].rs = tree[y].rs;
if(!tree[x].cl){
tree[x].cl = 1;
tree[x].ls = update(x,tree[y].ls,tl,mid,to,val,0);
}
else tree[x].ls = update(tree[x].ls,tree[y].ls,tl,mid,to,val,1);
}
else{
if(!tree[x].cl)tree[x].ls = tree[y].ls;
if(!tree[x].cr){
tree[x].cr = 1;
tree[x].rs = update(x,tree[y].rs,mid + 1,tr,to,val,0);
}
else tree[x].rs = update(tree[x].rs,tree[y].rs,mid + 1,tr,to,val,1);
}
push_up(x);
return x;
}
node query(int i,int tl,int tr,int l,int r){
node aa,b,ans = {(int)-1e9,(int)1e9,(int)-1e9,(int)-1e9,0,(int)-1e9,(int)-1e9,(int)-1e9,(int)-1e9};
if(l > r)return ans;
if(tl >= l && tr <= r)return tree[i];
if(tl > r || tr < l)return ans;
aa = query(tree[i].ls,tl,mid,l,r),b = query(tree[i].rs,mid + 1,tr,l,r);
ans.lsx = max(aa.lsx,aa.sum + b.lsx);
ans.rsx = max(b.rsx,b.sum + aa.rsx);
ans.sum = aa.sum + b.sum;
ans.mn = min(aa.mn,b.mn);
ans.mx = max(aa.mx,b.mx);
return ans;
}
int Query(int aa,int b,int c,int d){
int l = 0,r = lsh.size() + 1;
node k = query(root[1],1,n,aa,d);
l = k.mn - 1,r = k.mx + 1;
bool flag = 0;
while(l + 1 < r){
int m = (l + r) / 2;
node B = query(root[m],1,n,b + 1,c - 1),A = query(root[m],1,n,aa,b),C = query(root[m],1,n,c,d);
int sum = B.sum + A.rsx + C.lsx;
if(sum >= 0){
flag = 1;
l = m;
}
else r = m;
}
if(!flag)return 0;
return l;
}
int main(){
// freopen("data","r",stdin);
// freopen("ans1","w",stdout);
scanf("%d",&n);
for(int i = 1; i <= n; i++){
scanf("%d",&a[i]);
lsh.push_back(a[i]);
}
sort(lsh.begin(),lsh.end());
lsh.erase(unique(lsh.begin(),lsh.end()),lsh.end());
for(int i = 1; i <= n; i++){
a[i] = lower_bound(lsh.begin(),lsh.end(),a[i]) - lsh.begin() + 1;
pos[a[i]][++pos[a[i]][0]] = i;
//pos[a[i]].push_back(i);
}
root[1] = build(1,n);
for(int i = 2; i <= (int)lsh.size(); i++){
for(int j = 1; j <= pos[i - 1][0]; j++){
int to = pos[i - 1][j];
root[i] = update(root[i],root[i - 1],1,n,to,-1,root[i] > 0);
}
}
scanf("%d",&q);
int lastans = 0;
int aa,b,c,d;
for(int i = 1; i <= q; i++){
scanf("%d%d%d%d",&aa,&b,&c,&d);
aa = (aa + lastans) % n;b = (b + lastans) % n;c = (c + lastans) % n;d = (d + lastans) % n;
que[0] = 0;que[1] = aa;que[2] = b;que[3] = c;que[4] = d;
sort(que + 1,que + 5);
aa = que[1],b = que[2],c = que[3],d = que[4];
aa++;b++;c++;d++;
int ans = Query(aa,b,c,d);
if(ans == 0){
cout << "0\n";
lastans = 0;
continue;
}
printf("%d\n",lsh[ans - 1]);
lastans = lsh[ans - 1];
}
}

浙公网安备 33010602011771号