可持久化线段树(主席树)

  • 主席树作为最常用的可持久化数据结构,广泛运用与各种区间、树上问题的在线求解已经对DP的优化上。这里主要讨论其单纯作为数据结构的应用。

P1972 [SDOI2009] HH的项链

  • 这是一道极其经典的题——静态区间种类数,其变体非常多,树上的,带修的,强制在线的等等。
    这题做法也很多样,离线后树状数组、线段树以及分块都可以做。不过既然是主席树,那就要讲最正统的在线的主席树写法。

思路

  • 主席树是如何解决“种类数”这一问题的呢?
    首先,主席树维护区间问题与一般的线段树可能有些不同。普通线段树是叶子节点的值对应区间上的相应位置店的值。
    而主席树一般而言不这么干。主席树的叶子节点一般是一个数轴,维护每种数的信息,在本质上是一棵权值线段树。
    当然这也不绝对,还是要具体情况具体分析。

  • 如果考虑直接用主席树维护每种数的出现次数,这是没办法做到的。
    因此考虑改变一下维护的东西。有一个很常见的trick是,对于类似种类数的问题,维护与其数值相等的下一个数出现的位置。
    想一下这个东西有什么性质。如果对于一个询问区间,当且仅当这个区间中的某一种数下一个数出现的位置比询问区间的右端点大,这种数就能对答案产生1的贡献。

  • 因此直接维护对于每个数下一个与其数值相等的数的位置,每次查询的就是以r和l-1为根的线段树区间右端点到n+1(为什么是n+1一会说)的值

小细节

  1. 首先仍然是主席树的套路。对于原序列上的每个点,都建立线段树记录前缀和。对于上面我们维护的信息显然有可差分性,因此对于询问直接正常差分做就可以了。

  2. 为什么右端点是n+1呢?因为对于一些数,其之后就没有出现过了,因此我们就直接将其的值赋为n+1,因此值域就是1~n+1了

  3. 主席树的空间复杂度需要特别注意。每次新插入一个节点时最多新建 \(log_n\) 个节点,因此一般而言开n的20到30倍左右。我为了保险开的25倍

code

(略有压行)

点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int N=2e6+7;
int n,m,a[N],tr[N*25],rt[N],idcnt=0,lst[N],val[N],ls[N*25],rs[N*25];
inline void push_up(int u){tr[u]=tr[ls[u]]+tr[rs[u]];}
inline int insert(int last,int l,int r,int w)
{
	int u=++idcnt,mid=(l+r)>>1;
	ls[u]=ls[last],rs[u]=rs[last],tr[u]=tr[last];
	if(l==r){tr[u]+=1;return u;}
	if(w<=mid) ls[u]=insert(ls[last],l,mid,w);
	else rs[u]=insert(rs[last],mid+1,r,w);
	push_up(u);return u;
}
inline int query(int u1,int u2,int l,int r,int ql,int qr)
{
	if(l>=ql&&r<=qr) return tr[u1]-tr[u2];
	int mid=(l+r)>>1,ans=0;
	if(ql<=mid) ans+=query(ls[u1],ls[u2],l,mid,ql,qr);
	if(qr>mid) ans+=query(rs[u1],rs[u2],mid+1,r,ql,qr);
	return ans;
}
int main()
{
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n;rt[0]=0;
	for(int i=1;i<=n;i++) cin>>a[i],lst[a[i]]=n+1;
	for(int i=n;i>=1;i--) val[i]=lst[a[i]],lst[a[i]]=i;
	for(int i=1;i<=n;i++) rt[i]=insert(rt[i-1],1,n+1,val[i]);
	cin>>m;
	for(int i=1,l,r;i<=m;i++) {cin>>l>>r;cout<<query(rt[r],rt[l-1],1,n+1,r+1,n+1)<<'\n';}
	return 0;
}

P4137 Rmq Problem / mex

  • 此题也是非常经典的——静态区间求mex。mex指是指定区间最小的没出现过的自然数(包括0)

思路

  • 还是套路地对于每个点建立线段树作为前缀,考虑如何求mex
    由于我们求的是前缀,因此我们肯定是基于查询区间右端点为基础去查询的。
    对于这种“某数没出现过”的题,我们将维护的值设为其最后一次出现的位置,这样主席树仍然本质上维护的是前缀,即从头开始的、所有数出现的最后的位置。没出现过的数设为0

  • 而从叶子节点向上合并统计的时候,我们让非叶子节点记录其管辖的区域内所有点出现的最前面的位置。
    这样统计的原因是我们可以以查询区间左端点直接查找。如果区间出现的最左边的数比查询区间左端点小那就向左搜,反之向右搜。

小细节

  1. 这道题的数值范围很大,比n大得多,因此有很多人使用了离散化。然而其实并不用。显然最终的答案不会大于n+1(给出的是n的排列)
    因此对于比n大的数直接赋为n+1就好了。

  2. 空间复杂度不多赘述,但还是要注意。

  3. 注意,由于主席树是动态开点的,因此要注意查询时的边界条件。如果这个点从来没有被创建或访问过那就要直接返回其所代表区间的左端点
    即这个区间里没有一个点,那这个区间的左端点就是最小的没出现过的数(请注意这是权值线段树)

code

(还是小压行)

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N=2e5+7;int M;
int n,m,a[N],loc[N<<5],ls[N<<5],rs[N<<5],cnt=0,rt[N];
void push_up(int u){loc[u]=min(loc[ls[u]],loc[rs[u]]);}
int modify(int last,int l,int r,int x,int w)
{
	int u=++cnt;ls[u]=ls[last],rs[u]=rs[last];
	if(l==r) {loc[u]=w;return u;}
	int mid=(l+r)>>1;
	if(x<=mid) ls[u]=modify(ls[last],l,mid,x,w);
	else rs[u]=modify(rs[last],mid+1,r,x,w);
	push_up(u);return u;
}
int build(int l,int r)
{
	int u=++cnt;int mid=(l+r)>>1;
	if(l==r) {loc[u]=0;return u;}
	ls[u]=build(l,mid),rs[u]=build(mid+1,r);
	push_up(u);return u;
}
int query(int u,int l,int r,int x)
{
	if(!u||l==r) return l;
	int mid=(l+r)>>1;
	if(loc[ls[u]]<x) return query(ls[u],l,mid,x);
	else return query(rs[u],mid+1,r,x);
}
int main()
{
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n>>m;M=n+1;rt[0]=build(0,M);
	for(int i=1;i<=n;i++){cin>>a[i];if(a[i]>n) a[i]=M;rt[i]=modify(rt[i-1],0,M,a[i],i);}
	for(int i=1,l,r;i<=m;i++){cin>>l>>r;cout<<query(rt[r],0,M,l)<<'\n';}
	return 0;
}

P2633 Count on a tree

  • 众所周知,静态区间第k小是主席树的模板题,而这道题将其搬到了树上,同时强制在线。

思路

  • 主席树对于强制在线是不惧怕的。而对于树上区间问题,我们一般都是将其转化为点到根节点的问题。比如说对于这道题而言我们就同样将主席树搬到树上,其处理流程是一样的。
    考虑其查询。由于我们主席树记录的是从根节点到某节点的前缀,因此我们直接用类似于树上差分的形式求出两点间的值的数量。
    对于 \(u,v\) 两点,设 \(x\) 点代表的线段树上记录的值是 \(tr[x]\),差分后的值设为 \(val\) 则有

\[{\large val=tr[u]+tr[v]-tr[lca_{u,v}]-tr[fa_{lca}]} \]

  • 树上差分需要求lca
    然后就直接做

小细节

  1. 注意本题的数据范围。可以不离散化但是离散化要好一点,空间没有那么极限,同时如果不离散化很容易在求mid的时候挂掉
    我的写法是离散化了的

  2. 注意一下求lca和离散化时的细节

  3. 如果有问题看看讨论区,此题错误方法五花八门树上问题是这样的

code

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+7;
const int M=3e6+7;
int n,m,a[N],tot[N],rt[M],idcnt=0,tr[M],ls[M],rs[M],f[N][20],dep[N],b[N],len=0;
vector <int> q[M];
void add(int u,int v)
{
	q[u].push_back(v),tot[u]++;
	q[v].push_back(u),tot[v]++;
}
void push_up(int u){tr[u]=tr[ls[u]]+tr[rs[u]];}
int init(int l,int r)
{
	int u=++idcnt,mid=(l+r)>>1;
	if(l==r) return u;
	ls[u]=init(l,mid),rs[u]=init(mid+1,r);
	return u;
}
int insert(int lst,int l,int r,int w)
{
	int u=++idcnt,mid=1ll*(l+r)>>1;
	ls[u]=ls[lst],rs[u]=rs[lst],tr[u]=tr[lst];
	if(l==r){tr[u]++;return u;}
	if(w<=mid) ls[u]=insert(ls[lst],l,mid,w);
	else rs[u]=insert(rs[lst],mid+1,r,w);
	push_up(u);return u;
}
void build(int u,int fa)
{
	rt[u]=insert(rt[fa],1,len+1,a[u]);
	f[u][0]=fa;dep[u]=dep[fa]+1;
	for(int i=1;i<=18;i++) f[u][i]=f[f[u][i-1]][i-1];
	for(int i=0;i<tot[u];i++)
	{
		int v=q[u][i];if(v==fa) continue;
		build(v,u);
	}
}
int lca(int x,int y)
{
	if(dep[x]<dep[y]) swap(x,y);
	for(int i=18;i>=0;i--) if(dep[f[x][i]]>=dep[y]) x=f[x][i];
	if(x==y) return x;
	for(int i=18;i>=0;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	return f[x][0];
}
int query(int u1,int u2,int fa1,int fa2,int l,int r,int k)
{
	if(l==r) return b[l];
	int mid=1ll*(l+r)>>1,val=tr[ls[u1]]+tr[ls[u2]]-tr[ls[fa1]]-tr[ls[fa2]];
	if(val>=k) return query(ls[u1],ls[u2],ls[fa1],ls[fa2],l,mid,k);
	else return query(rs[u1],rs[u2],rs[fa1],rs[fa2],mid+1,r,k-val);
}
signed main()
{
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n>>m;
	for(int i=1;i<=n;i++) cin>>a[i],b[i]=a[i];
	for(int i=1,u,v;i<=n-1;i++) cin>>u>>v,add(u,v);
	sort(b+1,b+n+1);len=unique(b+1,b+n+1)-(b+1);
	for(int i=1;i<=n;i++) a[i]=lower_bound(b+1,b+len+1,a[i])-b;
	rt[0]=init(1,len);
	build(1,0);
	int lst=0;
	for(int i=1,u,v,k;i<=m;i++)
	{
		cin>>u>>v>>k;u=u^lst;int lc=lca(u,v);
		lst=query(rt[u],rt[v],rt[lc],rt[f[lc][0]],1,len+1,k);
		cout<<lst<<'\n';
	}
	return 0;
}

P2839 [国家集训队] middle

  • 求中位数是相当套路的。当然是二分答案然后将序列中比其小的视为-1,比其大的视为1,如果所求区间的和大于等于零证明中位数大于等于当前二分到的数。

  • 那如何去维护呢?如果暴力去修改肯定T飞。
    一个比较显然的观察是如果将数值离散化,排序后从小到大去枚举中位数,那我们所生成的-1/1的序列每次一定只有一个数被修改(相同的数不影响答案)

  • 但是二分并不是顺着枚举的。但是这提示我们通过可持久化数据结构来预处理-1/1数列(由小到大每次只有一个数被修改)
    同时我们发现题目数据范围不大、不带修、强制在线,几乎是把“主席树”三个字写在脸上了。

  • 因此我们用主席树来预处理-1/1的序列。
    而对于求子区间的中位数,假设我们我们已经二分出中位数,那只需要将-1/1数列拆开成一段前缀、一段必须选的值、一段后缀。
    而对于前缀后缀我们都取最大值(一定不劣),然后看加起来是否大于零就可以了。

code

实现上没什么细节,只是套路的维护前缀最大值与后缀最大值。唯一关注一下结构体的应用。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=5e6+7;
const int inf=1e9+7;
int n,tmp[N],rt[N],ls[N],rs[N],idcnt=0;
struct node{
	int w,lw,rw;
}tr[N];
struct edge{
	int w,id;
}a[N];
bool cmp1(edge x,edge y){return x.w<y.w;}
void push_up1(int u){tr[u].w=tr[ls[u]].w+tr[rs[u]].w;tr[u].lw=max(tr[ls[u]].lw,tr[ls[u]].w+tr[rs[u]].lw);tr[u].rw=max(tr[rs[u]].rw,tr[rs[u]].w+tr[ls[u]].rw);}
int build(int u,int l,int r)
{
	u=++idcnt;
	if(l==r) {tr[u]={1,1,1};return u;}
	int mid=(l+r)>>1;
	ls[u]=build(ls[u],l,mid),rs[u]=build(rs[u],mid+1,r);
	push_up1(u);return u;
}
int modify(int lu,int l,int r,int loc)
{
	int u=++idcnt;
	tr[u]=tr[lu],ls[u]=ls[lu],rs[u]=rs[lu];
	if(l==r) {tr[u]={-1,-1,-1};return u;}
	int mid=(l+r)>>1;
	if(loc<=mid) ls[u]=modify(ls[lu],l,mid,loc);else rs[u]=modify(rs[lu],mid+1,r,loc);
	push_up1(u);return u;
}
node query(int u,int l,int r,int ql,int qr)
{
	if(ql>qr) return {0,-inf,-inf};
	if(l>=ql&&r<=qr) return tr[u];
	int mid=(l+r)>>1;node res={0,-inf,-inf},tmp1={0,-inf,-inf},tmp2={0,-inf,-inf};
	if(ql<=mid) tmp1=query(ls[u],l,mid,ql,qr);
	if(qr>mid)  tmp2=query(rs[u],mid+1,r,ql,qr);
	res.w=tmp1.w+tmp2.w;res.lw=max({-inf,tmp1.lw,tmp1.w+tmp2.lw}),res.rw=max({-inf,tmp2.rw,tmp2.w+tmp1.rw});
	return res;
}
bool check(int a,int b,int c,int d,int x){
	return (query(rt[x],1,n,a,b).rw+query(rt[x],1,n,b+1,c-1).w+query(rt[x],1,n,c,d).lw)>=0;
}
signed main()
{
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n;for(int i=1;i<=n;i++) cin>>a[i].w,a[i].id=i;
	sort(a+1,a+n+1,cmp1);rt[1]=build(rt[1],1,n);
	for(int i=2;i<=n;i++) rt[i]=modify(rt[i-1],1,n,a[i-1].id);
	int q[10],x=0,aa,b,c,d,res;cin>>q[0];
	while(q[0]--){
		cin>>aa>>b>>c>>d;q[1]=(aa+x)%n+1,q[2]=(b+x)%n+1,q[3]=(c+x)%n+1,q[4]=(d+x)%n+1;
		sort(q+1,q+5);aa=q[1],b=q[2],c=q[3],d=q[4];
		int l=1,r=n;res=1ll;
		while(l<=r){
			int mid=(l+r)>>1;
			if(check(aa,b,c,d,mid)) l=mid+1,res=mid;
			else r=mid-1;
		}
		x=a[res].w;cout<<x<<'\n';
	}
	return 0;
}

P4587 [FJOI2016] 神秘数

  • 先考虑暴力怎么做。一个显然的观察是如果我们可以表示出 \([1,x]\),那对于一个新加进来的数 \(a_i\le x+1\),就一定可以表示出区间 \([1,x+a_i]\)。注意对于 \(a_i=x+1\)\(a_i\) 本身就可以单独成为一个子集使 \(x+1\) 这个数可以被表示。

  • 那暴力就是将原区间提取出来单独排序后设我们对于前 \(i-1\) 个数已经可以表示出区间 \([1,ans_{i-1}]\),当前的数为 \(a_i\)
    如果 \(a_i>ans_{i-1}+1\),那“神秘数”就是 \(ans_{i-1}+1\)
    否则答案就可以更新,即 \(ans_i=ans_{i-1}+a_i\)
    然后我们就发现答案本质上是一个排完序后序列的前缀和序列,同时我们也只需要这个序列,同时整个序列并不带修改。

  • 因此考虑用主席树来维护区间,我们要求的是小于等于 \(ans_i\) 的所有数的和,直接套板子。

  • 那如何保证复杂度呢?
    复杂度正确是因为对于上一个ans,当前的ans是小于等于上一个ans的数的和得来的。那下一个ans比上一个ans大的部分就是上一个ans与当前ans之间的部分。
    而这些数都比上一个ans大,也就是说下一个ans至少比上一个ans大一倍。这样两次计算翻一倍那一次查询的时间复杂度就是 \(\log \sum_ia_i\) 的。
    这样一直翻倍直到小于等于ans的数的和小于ans就是答案。
    总复杂度 \(O(m\log n\log \sum_ia_i)\)

code

代码一如既往的短。注意主席树是动态开点的,但是一开始的值域过于巨大(\(1e9\)),因此不能建树,直接向空树中加点即可。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int N=5e6+7;
const int R=1e9;
int n,m,tr[N],rt[N],idcnt=1,ls[N],rs[N];
void push_up(int u){tr[u]=tr[ls[u]]+tr[rs[u]];}
int insert(int lu,int l,int r,int x,int w)
{
	int u=++idcnt,mid=(l+r)>>1;tr[u]=tr[lu],ls[u]=ls[lu],rs[u]=rs[lu];
	if(l==r) {tr[u]+=w;return u;}
	if(x<=mid) ls[u]=insert(ls[lu],l,mid,x,w);
	else rs[u]=insert(rs[lu],mid+1,r,x,w);
	push_up(u);return u;
}
int query(int lu,int ru,int l,int r,int ql,int qr)
{
	int val=tr[ru]-tr[lu],mid=(l+r)>>1,res=0;
	if((l>=ql&&r<=qr)) return val;
	if(ql<=mid) res+=query(ls[lu],ls[ru],l,mid,ql,qr);if(qr>mid) res+=query(rs[lu],rs[ru],mid+1,r,ql,qr);
	return res;
}
signed main()
{
	ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
	cin>>n;rt[0]=1;
	for(int i=1,x;i<=n;i++){cin>>x;rt[i]=insert(rt[i-1],1,R,x,x);}
	cin>>m;for(int i=1,l,r,ans=0,lst=1;i<=m;i++){
		cin>>l>>r;ans=query(rt[l-1],rt[r],1,R,1,1)+1;
		if(ans-1==0) {cout<<"1\n";continue;}
		while("666yan都不yan了") {
			int tmp=query(rt[l-1],rt[r],1,R,1,ans);
			if(tmp>=ans) ans=tmp+1;else {cout<<ans<<'\n';break;}
		} 
	}
	return 0;
}
posted @ 2024-11-11 15:16  all_for_god  阅读(62)  评论(0)    收藏  举报