Loading

浅浅浅浅谈主席树

是什么

是神仙数据结构,可以维护一些持久化的东西,是我们的伟大领袖

可持久化的意思就是记录每个历史版本,可以根据历史版本做一些操作。

怎么做

如果对于每个版本都开一棵线段树那空间受不了,我们考虑对每个版本产生修改的链开新的节点,打补丁一样附加在原树上。

具体实现因题而异

void insert(int &y,int x,int l,int r,int p){//y是当前版本,x是要以它为基础修改的版本
    y=++tot,ls[y]=ls[x],rs[y]=rs[x],sz[y]=sz[x]+1;
    //sum[y]=sum[x]+? 这里修改每个节点存的值
    if(l==r) return;
    int mid=l+r>>1;
    if(p<=mid) insert(ls[y],ls[x],l,mid,p);
    else insert(rs[y],rs[x],mid+1,r,p);
}

然后是查询,就是在给定的版本上用类比线段树查询的方式查询就完事了,拓展很多,就不给出代码了,不要限制了思路。

一些模板

Luogu P3919 【模板】可持久化线段树 1(可持久化数组)

初始版本为序列本身,对每个修改都开个新的版本,查询和修改就是在给定的版本里面用线段树的方式修改就行了。

#include<iostream>
#include<cstdio>
#define N 1000005 
using namespace std;
int n,m,a[N],sum[N*20],ls[N*20],rs[N*20],rt[N],tot;
void update(int &y,int x,int l,int r,int p,int v){
	y=++tot;
	ls[y]=ls[x],rs[y]=rs[x],sum[y]=sum[x];
	if(l==r){
		sum[y]=v;
		return;
	}
	int mid=l+r>>1;
	if(p<=mid) update(ls[y],ls[x],l,mid,p,v);
	else update(rs[y],rs[x],mid+1,r,p,v);
}
int query(int y,int l,int r,int k){
	if(l==r) return sum[y];
	int mid=l+r>>1;
	if(k<=mid) return query(ls[y],l,mid,k);
	else return query(rs[y],mid+1,r,k);
}
void build(int &y,int l,int r){
	y=++tot;
	if(l==r){
		sum[y]=a[l];
		return;
	}
	int mid=l+r>>1;
	build(ls[y],l,mid);
	build(rs[y],mid+1,r);
}
int main(){
	scanf("%d%d",&n,&m);
	for(int i=1;i<=n;i++){
		scanf("%d",&a[i]);
	}
	build(rt[0],1,n);
	int vi,t,x,y,i=0;
	while(m--){
		i++;
		scanf("%d%d%d",&vi,&t,&x);
		if(t==1){
			scanf("%d",&y);
			update(rt[i],rt[vi],1,n,x,y);
		}
		else{
			printf("%d\n",query(rt[vi],1,n,x));
			rt[i]=rt[vi];
		}
	}
	return 0;
}

实际上主席树并不是一定要建树的

Luogu P3834 【模板】可持久化线段树 2(主席树)

双倍经验

经典静态区间第 \(k\)​ 大问题,考虑用主席树来做。

先进行离散化。

我们以权值为下标建主席树,对于每一个元素都开一个历史版本,维护一个 \(sum\)​ 数组表示这个节点的子树大小(也可以理解为这个点被加了几个历史版本或者前缀和),问题转化为了在这个区间的左右端点的历史版本之间在主席树中第 \(k\)​ 个数是什么(主席树中是以权值为下标并且我们钦定它为从小到大排序)。当前节点通过两个历史版本的 \(sum\) 数组相减得到当前节点在两个历史版本之间到底加了几个版本(因为对于每个插入的数都新建了一个历史版本,等价于这两个历史版本之间有几个数),来判断是选左子树还是右子树。

#include<bits/stdc++.h>
using namespace std;
template <class T>
inline T read(){
    int ans=0,f=0;char ch=getchar();
    while(!isdigit(ch)) f|=ch=='-',ch=getchar();
    while(isdigit(ch)) ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
    return f?-ans:ans;
}
const int N=3e5+5;
struct zfz{
    int v,id;
}a[N];
int n,m;
int sum[N*20],ls[N*20],rs[N*20],tot;
int rank[N],rt[N];
bool cmp(zfz x,zfz y){return x.v<y.v;}
void insert(int &y,int x,int l,int r,int p){
    y=++tot,sum[y]=sum[x]+1,ls[y]=ls[x],rs[y]=rs[x];
    if(l==r) return;
    int mid=l+r>>1;
    if(p<=mid) insert(ls[y],ls[x],l,mid,p);
    else insert(rs[y],rs[x],mid+1,r,p);
}
int query(int y,int x,int l,int r,int k){
    if(l==r) return l;
    int mid=l+r>>1;
    if(k<=sum[ls[y]]-sum[ls[x]]) return query(ls[y],ls[x],l,mid,k);
    else return query(rs[y],rs[x],mid+1,r,k-(sum[ls[y]]-sum[ls[x]]));
}
int main(){
    n=read<int>(),m=read<int>();
    for(int i=1;i<=n;++i) a[i].v=read<int>(),a[i].id=i;
    sort(a+1,a+1+n,cmp);
    for(int i=1;i<=n;++i) rank[a[i].id]=i;
    for(int i=1;i<=n;++i) insert(rt[i],rt[i-1],1,n,rank[i]);
    while(m--){
        int l=read<int>(),r=read<int>(),k=read<int>();
        printf("%d\n",a[query(rt[r],rt[l-1],1,n,k)].v);
    }
    return 0;
}

Luogu P2617 Dynamic Rankings

带修改的主席树。

静态区间的主席树是两个区间的左右端点对应版本的前缀和相减得到的这个区间,如果修改的话就要对每个历史版本的前缀和都修改一遍,复杂度是 \(O(nlogn)\) 的。考虑怎么优化,我们可以用树状数组维护前缀和,这样修改就变为了 \(O(logn)\)

查询时,依旧是 \(R\) 位置减去 \(L-1\) 位置,这时候不再是两棵线段树作差,而是 \(log\) 棵线段树与 \(log\) 棵线段树作差;跳的时候,\(log\)​ 个节点一起跳到左子树/右子树。

#include<bits/stdc++.h>
#define lowbit(x) x&(-x)
using namespace std;
inline int read(){
	int ans=0,f=1;char ch=getchar();
	while(!isdigit(ch)){if(ch=='-') f=-f;ch=getchar();}
	while(isdigit(ch)){ans=(ans<<3)+(ans<<1)+ch-48;ch=getchar();}
	return ans*f;
}
const int INF=1e9,N=1e5+5;
int a[N];
int b[N<<1],len;
struct QK{
	int x,y,z;
}q[N];
int rt[N],sum[400*N],ls[400*N],rs[400*N],n,m,tot;
int ggl,ggr,ll[N],rr[N];
void insert(int &y,int x,int l,int r,int p,int v){
	y=++tot,sum[y]=sum[x]+v;
	ls[y]=ls[x],rs[y]=rs[x];
	if (l==r) return;
	int mid=l+r>>1;
	if (p<=mid) insert(ls[y],ls[x],l,mid,p,v);
	else insert(rs[y],rs[x],mid+1,r,p,v);
}
int query(int l,int r,int q){
	if (l==r) return l;
	int mid=l+r>>1,cnt=0;
	for(int i=1;i<=ggl;i++) cnt-=sum[ls[ll[i]]];
	for(int i=1;i<=ggr;i++) cnt+=sum[ls[rr[i]]];
	if(q<=cnt){
		for(int i=1;i<=ggl;i++) ll[i]=ls[ll[i]];
		for(int i=1;i<=ggr;i++) rr[i]=ls[rr[i]];
		return query(l,mid,q);
	}
	else{
		for(int i=1;i<=ggl;i++) ll[i]=rs[ll[i]];
		for(int i=1;i<=ggr;i++) rr[i]=rs[rr[i]];
		return query(mid+1,r,q-cnt);
	}
}
void addn(int x,int v){
	int k=lower_bound(b+1,b+1+len,a[x])-b;
	for(int i=x;i<=n;i+=lowbit(i)) insert(rt[i],rt[i],1,len,k,v);//处理出要改哪些树
}

int main(){
	cin>>n>>m;
	for(int i=1;i<=n;i++) cin>>a[i],b[++len]=a[i];
	for(int i=1;i<=m;i++){
		char ch;
		cin>>ch>>q[i].x>>q[i].y;
		if(ch=='Q') cin>>q[i].z;
		else b[++len]=q[i].y;
	}
	sort(b+1,b+1+len);
	len=unique(b+1,b+1+len)-b-1;
	for(int i=1;i<=n;i++) addn(i,1);
	for(int i=1;i<=m;i++){
		if(q[i].z){
			ggl=ggr=0;
			for(int j=q[i].x-1;j;j-=lowbit(j)) ll[++ggl]=rt[j];//处理出哪些树一起跳
			for(int j=q[i].y;j;j-=lowbit(j)) rr[++ggr]=rt[j];
			cout<<b[query(1,len,q[i].z)]<<endl;
		}
		else addn(q[i].x,-1),a[q[i].x]=q[i].y,addn(q[i].x,1);
	}
	return 0;
}

Luogu P2633 Count on a tree

树上主席树,因为主席树是维护前缀和的结构,联系到树上差分所以我们可以对于每个节点维护根到当前节点的前缀和主席树,最后询问就是 \(siz[x]+siz[y]-siz[lca]-siz[falca]\),询问的时候 \(x,y,lca,falca\) 一起跳就好了。

#include<bits/stdc++.h>
#define id(u) lower_bound(b+1,b+len+1,a[u])-b
#define int64 long long
using namespace std;
template <class T>
inline T read(){
    int ans=0,f=0;char ch=getchar();
    while(!isdigit(ch)) f|=ch=='-',ch=getchar();
    while(isdigit(ch)) ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
    return f?-ans:ans;
}
const int N=1e5+5,INF=1e9;
int n,Q,ans,len;
int b[N],a[N];
int hd[N],nx[N<<1],to[N<<1],tote;
int dep[N],fa[N],top[N],sz[N],son[N];
int rt[N],ls[N*20],rs[N*20],siz[N*20],tot;
void adde(int u,int v){
    nx[++tote]=hd[u];to[tote]=v;hd[u]=tote;
    nx[++tote]=hd[v];to[tote]=u;hd[v]=tote;
}
void insert(int &y,int x,int l,int r,int p){
    y=++tot;ls[y]=ls[x],rs[y]=rs[x],siz[y]=siz[x]+1;
    if(l==r) return;
    int mid=l+r>>1;
    if(p<=mid) insert(ls[y],ls[x],l,mid,p);
    else insert(rs[y],rs[x],mid+1,r,p);
}
void dfs1(int u,int father){
    dep[u]=dep[father]+1;fa[u]=father;sz[u]=1;
    insert(rt[u],rt[father],1,len,id(u));
    for(int i=hd[u];i;i=nx[i]){
        int v=to[i];
        if(v==father) continue;
        dfs1(v,u);
        sz[u]+=sz[v];
        if(sz[v]>sz[son[u]]) son[u]=v;
    }
}
void dfs2(int u,int anc){
    top[u]=anc;
    if(son[u]) dfs2(son[u],anc);
    for(int i=hd[u];i;i=nx[i]){
        int v=to[i];
        if(v==fa[u]||v==son[u]) continue;
        dfs2(v,v);
    }
}
int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}
int query(int x,int y,int lca,int falca,int l,int r,int p){
    if(l==r) return l;
    int mid=l+r>>1,sum=siz[ls[x]]+siz[ls[y]]-siz[ls[lca]]-siz[ls[falca]];
    if(p<=sum) return query(ls[x],ls[y],ls[lca],ls[falca],l,mid,p);
    else return query(rs[x],rs[y],rs[lca],rs[falca],mid+1,r,p-sum);
}
int query(int x,int y,int k){
    int lca=LCA(x,y);
    return query(rt[x],rt[y],rt[lca],rt[fa[lca]],1,len,k);
}
int main(){
    n=read<int>(),Q=read<int>();
    for(int i=1;i<=n;++i) b[i]=a[i]=read<int>();
    sort(b+1,b+1+n);
    len=unique(b+1,b+1+n)-b-1;
    for(int i=1;i<n;++i){
        int x=read<int>(),y=read<int>();
        adde(x,y);
    }
    dfs1(1,0);
    dfs2(1,1);
    while(Q--){
        int x=read<int>()^ans,y=read<int>(),k=read<int>();
        printf("%d\n",ans=b[query(x,y,k)]);
    }
    return 0;
}

一些应用

Luogu P1383 高级打字机

先从水题开始。

显然我们需要维护一个可持久化的东西来支持撤销和插入,所以我们考虑用模板1的思路来做,记录 \(len\) 数组为当前版本的串的长度,对于撤销操作就是把要撤销到的版本的信息复制到当前的版本既可。

#include<bits/stdc++.h>
#define int64 long long
using namespace std;
template <class T>
inline T read(){
    int ans=0,f=0;char ch=getchar();
    while(!isdigit(ch)) f|=ch=='-',ch=getchar();
    while(isdigit(ch)) ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
    return f?-ans:ans;
}
const int N=1e5+5,INF=1e9,M=1e5;
int now;
int sum[N*20],rt[N],tot,ls[N*20],rs[N*20];
int len[N];
void insert(int &y,int x,int l,int r,int p,int v){
    y=++tot,sum[y]=sum[x],ls[y]=ls[x],rs[y]=rs[x];
    if(l==r){sum[y]=v;return;}
    int mid=l+r>>1;
    if(p<=mid) insert(ls[y],ls[x],l,mid,p,v);
    else insert(rs[y],rs[x],mid+1,r,p,v);
}
int query(int y,int l,int r,int p){
    if(l==r) return sum[y];
    int mid=l+r>>1;
    if(p<=mid) return query(ls[y],l,mid,p);
    else return query(rs[y],mid+1,r,p);
}
int main(){
    for(int Q=read<int>();Q;--Q){
        char opt,s;
        int x;
        cin>>opt;
        if(opt=='T') cin>>s,++now,len[now]=len[now-1]+1,insert(rt[now],rt[now-1],1,M,len[now],(int)s);
        else if(opt=='U'){
            cin>>x;
            int pre=max(now-x,0);
            len[++now]=len[pre],rt[now]=rt[pre];
        }
        else cin>>x,printf("%c\n",query(rt[now],1,M,x));
    }
    return 0;
}

Luogu P2468 [SDOI2010]粟粟的书架

这题可以尝试着写一个二维主席树练习一下卡空间。

正常做法还是分开两个问题来做,对于第一个问题,我们二分最少要用的书的页数,然后搞一下就可以搞出答案。(这个二分好神仙啊)

对于第二个问题,我们考虑用主席树来维护,对于主席树内每个节点的值,我们考虑维护区间和,查询的时候先查右半部分的,如果不满足则答案区间一定是右半部分一整个加上左半部分的一些,根据这个性质查询即可。

#include<bits/stdc++.h>
#define int64 long long
using namespace std;
template <class T>
inline T read(){
    int ans=0,f=0;char ch=getchar();
    while(!isdigit(ch)) f|=ch=='-',ch=getchar();
    while(isdigit(ch)) ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
    return f?-ans:ans;
}
const int N=5e5+5,INF=1e9;
int r,c,m;
namespace solve1{
    int a[205][205],sum[205][205][1005],cnt[205][205][1005],Max=-INF;
    int getsum(int a1,int b1,int a2,int b2,int v){
        return sum[a2][b2][v]-sum[a1-1][b2][v]-sum[a2][b1-1][v]+sum[a1-1][b1-1][v];
    }
    int getcnt(int a1,int b1,int a2,int b2,int v){
        return cnt[a2][b2][v]-cnt[a1-1][b2][v]-cnt[a2][b1-1][v]+cnt[a1-1][b1-1][v];
    }
    void main(){
        for(int i=1;i<=r;++i)
            for(int j=1;j<=c;++j)
                a[i][j]=read<int>(),Max=max(Max,a[i][j]);
        for(int k=0;k<=Max;++k)
            for(int i=1;i<=r;++i)
                for(int j=1;j<=c;++j)
                    sum[i][j][k]=sum[i-1][j][k]+sum[i][j-1][k]-sum[i-1][j-1][k]+(a[i][j]>=k?a[i][j]:0),
                    cnt[i][j][k]=cnt[i-1][j][k]+cnt[i][j-1][k]-cnt[i-1][j-1][k]+(a[i][j]>=k?1:0);
        while(m--){
            int a1=read<int>(),b1=read<int>(),a2=read<int>(),b2=read<int>(),h=read<int>();
            int l=0,r=Max,ans=-1;
            while(l<=r){
                int mid=l+r>>1;
                if(getsum(a1,b1,a2,b2,mid)>=h) l=mid+1,ans=mid;
                else r=mid-1;
            }
            if(!~ans) printf("Poor QLW\n");
            else printf("%d\n",getcnt(a1,b1,a2,b2,ans)-(getsum(a1,b1,a2,b2,ans)-h)/ans);
        }
    }
}
namespace solve2{
    int n;
    int rt[N],sum[N*20],ls[N*20],rs[N*20],sz[N*20],tot;
    void insert(int &y,int x,int l,int r,int p){
        y=++tot;sum[y]=sum[x]+p;ls[y]=ls[x];rs[y]=rs[x];sz[y]=sz[x]+1;
        if(l==r) return;
        int mid=l+r>>1;
        if(p<=mid) insert(ls[y],ls[x],l,mid,p);
        else insert(rs[y],rs[x],mid+1,r,p);
    }
    int query(int y,int x,int l,int r,int k){
        if(l==r) return (k-1)/l+1;//ceil()
        int mid=l+r>>1;
        if(k<=sum[rs[y]]-sum[rs[x]]) return query(rs[y],rs[x],mid+1,r,k);
        else return query(ls[y],ls[x],l,mid,k-(sum[rs[y]]-sum[rs[x]]))+(sz[rs[y]]-sz[rs[x]]);
    }
    void main(){
        for(int i=1,a;i<=c;++i) a=read<int>(),insert(rt[i],rt[i-1],1,1000,a);
        while(m--){
            int a1=read<int>(),b1=read<int>(),a2=read<int>(),b2=read<int>(),h=read<int>();
            if(sum[rt[b2]]-sum[rt[b1-1]]<h){printf("Poor QLW\n");continue;}
            else printf("%d\n",query(rt[b2],rt[b1-1],1,1000,h));
        }
    }
}
int main(){
    r=read<int>(),c=read<int>(),m=read<int>();
    if(r>1) solve1::main();
    else solve2::main();
    return 0;
}

Luogu P2839 [国家集训队]middle

是神题,我们考虑二分答案中位数的位置,并把大于等于二分出的数的数标为 \(1\),把小于的数标为 \(-1\),显然,如果对于给定的区间,区间和大于等于 \(0\) 则说明这个二分出来的数小于等于中位数,反之大于,所以我们可以由此确定中位数。

但是这题的区间不是给定的,而是左右端点都在一个区间里选,而且为了让中位数尽可能的大,要取尽可能多的 \(1\) 才行,所以我们考虑维护左/右端点出发的最大连续子段和,对于左端点的区间我们取右边的最大连续子段和,右端点区间反之,对于题目中的 \([b+1,c-1]\) 的区间,如果存在则必须选,所以我们求一下这个区间的区间和即可。

但是还有一个问题,如果对于每一个二分出来的数都重新赋值 \(1/-1\)​ 时间上过不去,我们发现对于把当前二分出的 \(mid\)​ 的 \(1/-1\)​ 序列变为 \(mid+1\)​ 的 \(1/-1\)​ 序列时,只有 \(mid\)​ 会由 \(1\)​ 变为 \(-1\)​。所以我们用主席树来维护,考虑先建一个全是 \(1\)​ 的初始版本对每一个数都建一个历史版本,并由从小到大的顺序来依次将对应的位置改为 \(-1\) 并记录版本,查询时只用在对应版本查询即可。

#include<bits/stdc++.h>
#define int64 long long
using namespace std;
template <class T>
inline T read(){
    int ans=0,f=0;char ch=getchar();
    while(!isdigit(ch)) f|=ch=='-',ch=getchar();
    while(isdigit(ch)) ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
    return f?-ans:ans;
}
const int N=3e4+5,INF=1e9;
int n,ans;
int Q,q[4];
struct slzs{
    int v,id;
}a[N];
struct zfz{
    int sum,lsum,rsum;
}seg[N*20];
int rt[N],ls[N*20],rs[N*20],tot;
bool cmp(slzs x,slzs y){return x.v<y.v;}
void pushup(int x,int l,int r){
    seg[x].sum=seg[l].sum+seg[r].sum;
    seg[x].lsum=max(seg[l].lsum,seg[l].sum+seg[r].lsum),seg[x].rsum=max(seg[r].rsum,seg[r].sum+seg[l].rsum);
}
void insert(int &y,int x,int l,int r,int p,int v){
    y=++tot;ls[y]=ls[x],rs[y]=rs[x];seg[y]=seg[x];
    if(l==r){
        seg[y].lsum=seg[y].rsum=seg[y].sum=v;
        return;
    }
    int mid=l+r>>1;
    if(p<=mid) insert(ls[y],ls[x],l,mid,p,v);
    else insert(rs[y],rs[x],mid+1,r,p,v);
    pushup(y,ls[y],rs[y]);
}
zfz query(int y,int l,int r,int xl,int xr){
    if(xl<=l&&xr>=r) return seg[y];
    int mid=l+r>>1,flag=0;
    zfz tmp1,tmp2;
    if(xl<=mid) ++flag,tmp1=query(ls[y],l,mid,xl,xr);
    if(xr>mid) flag+=2,tmp2=query(rs[y],mid+1,r,xl,xr);
    if(flag==1) return tmp1;
    if(flag==2) return tmp2;
    return (zfz){tmp1.sum+tmp2.sum,max(tmp1.lsum,tmp1.sum+tmp2.lsum),max(tmp2.rsum,tmp2.sum+tmp1.rsum)};
}
void build(int &y,int l,int r){
    y=++tot;
    if(l==r){seg[y].sum=seg[y].lsum=seg[y].rsum=1;return;}
    int mid=l+r>>1;
    build(ls[y],l,mid);
    build(rs[y],mid+1,r);
    pushup(y,ls[y],rs[y]);
}
bool check(int mid){
    int sum=0;
    if(q[2]-1>=q[1]+1) sum+=query(rt[mid],1,n,q[1]+1,q[2]-1).sum;
    sum+=query(rt[mid],1,n,q[0],q[1]).rsum;
    sum+=query(rt[mid],1,n,q[2],q[3]).lsum;
    return sum>=0;
}
int main(){
    n=read<int>();
    for(int i=1;i<=n;++i) a[i].v=read<int>(),a[i].id=i;
    sort(a+1,a+1+n,cmp);
    build(rt[0],1,n);
    for(int i=1;i<=n;++i) insert(rt[i],rt[i-1],1,n,a[i].id,-1);
    for(Q=read<int>();Q;--Q){
        for(int i=0;i<4;++i) q[i]=(read<int>()+ans)%n+1;
        sort(q,q+4);
        int l=1,r=n;
        while(l<=r){
            int mid=l+r>>1;
            if(check(mid-1)) l=mid+1,ans=mid;
            else r=mid-1;
        }
        printf("%d\n",a[ans].v);
        ans=a[ans].v;
    }
    return 0;
}

未完待续……

posted @ 2021-08-03 22:24  Quick_Kk  阅读(65)  评论(4)    收藏  举报