洛谷 P3285 - [SCOI2014]方伯伯的OJ(平衡树)

洛谷题面传送门

在酒店写的,刚了一整晚终于调出来了……

首先考虑当 \(n\) 比较小(\(10^5\) 级别)的时候怎么解决,我们考虑将所有用户按排名为关键字建立二叉排序树,我们同时再用一个 map 维护下编号为 \(x\) 的用户在原平衡树上对应的节点编号是什么。那么对于每次操作我们需进行的操作如下:

  • \(1\) 类操作:直接在 map 中找到 \(x\) 对应的节点编号,将该节点对应的用户编号改为 \(y\),同时更新 map 中用户编号为 \(y\) 对应的节点编号。
  • \(2\) 类操作:在 map 中找到 \(x\) 对应的节点编号 \(id\),然后将 \(id\) 从原平衡树中分离出来,然后将 \(id\) 与根节点合并,其中 \(id\) 排名小于根节点的排名。那么怎么实现分离这一操作呢?相较于普通的平衡树不同的一点是,这次我们对于每个点记录其父亲编号,然后假设我们要将 \(p\) 节点分离出来,那么我们就考察其父亲,如果其没有父亲我们就直接将根节点设为 \(p\) 左右儿子合并的结果,如果 \(p\) 是其父亲的左儿子(类似于 splay 里的 identify 函数)就将其父亲的左儿子设为 \(p\) 左右儿子合并的结果,否则将其父亲的右儿子设为 \(p\) 左右儿子合并的结果。同时将 \(p\) 左右儿子即父亲都设为空。
  • \(3\) 类操作:与 \(2\) 类操作几乎一致,只不过这次我们将 \(id\) 合并到根节点后面。
  • \(4\) 类操作:直接在平衡树上二分,然后输出对应节点用户编号。

接下来考虑原问题。注意到不同编号虽然很多,但是如果我们把编号连续排名也连续的这些编号合并起来,那么每次操作最多增加两个合并后的连续段,也就是说这个连续段的个数是 \(\mathcal O(m)\) 的,因此我们考虑平衡树上每个节点维护一个编号的连续段。同样地我们也可以用某种数据结构找到每个点所在连续段对应的节点编号,只不过由于此题用户个数很多,使用 map 逐一存储不可取,因此我们考虑建一个 set 并将所有连续段左端点及其编号看作一个二元组压入一个 set,查询在 set 中二分即可。同时由于每次修改可能会增加新的连续段,因此我们要将一个节点裂成两个,具体来说假设我们要将 \([L,R]\) 从中间 \(p\) 处断开,那么我们就新建两个节点表示 \([L,p-1]\)\([p+1,R]\),然后将 \([L,p-1]\) 放在该节点左边,\([p+1,R]\) 放在该节点右边,原节点编号区间改为 \([p,p]\) 即可。注意这里“将 \([L,p-1]\) 挂在该节点左边”不能直接简简单单地将 \([L,p-1]\) 设为原节点的左儿子,同时将原来该节点的左儿子挂在 \([L,p-1]\) 的左儿子处,而要将 \([L,p-1]\) 与该节点原来的左儿子做一遍 merge 操作,否则复杂度会退化。

时间复杂度 \(m\log m\)

const int MAXM=2e5;
int n,qu;
struct node{
	int sum,val,key,st,ch[2],f;
	node(int _sum=0,int _val=0,int _st=0){
		sum=_sum;val=_val;st=_st;key=rand();
		ch[0]=ch[1]=f=0;
	}
} s[MAXM+5];
int rt=1,ncnt=1;set<pii> nds;
void pushup(int k){s[k].sum=s[s[k].ch[0]].sum+s[s[k].ch[1]].sum+s[k].val;}
void setson(int k,int c,int v){s[v].f=k;s[k].ch[c]=v;}
int get(int x){pii p=*--nds.upper_bound(mp(x,INF));return p.se;}
int merge(int x,int y){
	if(!x||!y) return x+y;
	if(s[x].key<s[y].key) return setson(x,1,merge(s[x].ch[1],y)),pushup(x),x;
	else return setson(y,0,merge(x,s[y].ch[0])),pushup(y),y;
}
void split_nd(int k,int p){
//	printf("split %d %d\n",k,p);
	int L=s[k].st,R=s[k].st+s[k].val-1;
	if(L==R) return;
	nds.erase(nds.find(mp(L,k)));
	if(p!=L){
		int ls=++ncnt;s[ls]=node(p-L,p-L,L);
		nds.insert(mp(L,ls));setson(k,0,merge(s[k].ch[0],ls));
	} if(p!=R){
		int rs=++ncnt;s[rs]=node(R-p,R-p,p+1);
		nds.insert(mp(p+1,rs));setson(k,1,merge(rs,s[k].ch[1]));
	} s[k].val=1;s[k].st=p;nds.insert(mp(p,k));
}
int query(int sz){
	int k=rt;
	while(1){
//		printf("%d %d %d\n",k,sz,s[k].st);
		if(sz<=s[s[k].ch[0]].sum) k=s[k].ch[0];
		else if(sz>s[s[k].ch[0]].sum+s[k].val) sz-=s[s[k].ch[0]].sum+s[k].val,k=s[k].ch[1];
		else return s[k].st+sz-s[s[k].ch[0]].sum-1;
	}
}
void print(int k){
	if(!k) return;print(s[k].ch[0]);
	printf("node %d [%d,%d] %d %d %d %d\n",k,s[k].st,s[k].st+s[k].val-1,s[k].sum,s[k].f,s[k].ch[0],s[k].ch[1]);
	print(s[k].ch[1]);
}
int walk(int k){
//	printf("walk %d\n",k);print(rt);
	int res=1+s[s[k].ch[0]].sum,pr=k;
	while(k){
		k=s[k].f;//printf("%d\n",k);
		if(pr==s[k].ch[1]) res+=s[s[k].ch[0]].sum,res+=s[k].val;
		pr=k;
	} return res;
}
int main(){
	scanf("%d%d",&n,&qu);srand(20211005203353);
	s[1]=node(n,n,1);nds.insert(mp(1,1));int pre=0;
	while(qu--){
		int opt;scanf("%d",&opt);
		if(opt==1){
			int x,y;scanf("%d%d",&x,&y);
			x-=pre;y-=pre;int pt=get(x);
			split_nd(pt,x);s[pt].st=y;
			nds.erase(nds.find(mp(x,pt)));
			nds.insert(mp(y,pt));
			printf("%d\n",pre=walk(pt));
		} else if(opt==2){
			int x;scanf("%d",&x);x-=pre;
			int pt=get(x);split_nd(pt,x);
			printf("%d\n",pre=walk(pt));
			int nd=merge(s[pt].ch[0],s[pt].ch[1]);
			s[pt].ch[0]=s[pt].ch[1]=0;pushup(pt);
			if(s[pt].f){
				int fa=s[pt].f;
				if(s[fa].ch[0]==pt) setson(fa,0,nd);
				else setson(fa,1,nd);
			} else rt=nd,s[rt].f=0;
			for(int j=s[pt].f;j;j=s[j].f) pushup(j);
			s[pt].f=0;
			rt=merge(pt,rt);
		} else if(opt==3){
			int x;scanf("%d",&x);x-=pre;
			int pt=get(x);split_nd(pt,x);
			printf("%d\n",pre=walk(pt));
			int nd=merge(s[pt].ch[0],s[pt].ch[1]);
			s[pt].ch[0]=s[pt].ch[1]=0;pushup(pt);
			if(s[pt].f){
				int fa=s[pt].f;
				if(s[fa].ch[0]==pt) setson(fa,0,nd);
				else setson(fa,1,nd);
			} else rt=nd,s[rt].f=0;
			for(int j=s[pt].f;j;j=s[j].f) pushup(j);
			s[pt].f=0;
			rt=merge(rt,pt);
		} else {
			int k;scanf("%d",&k);k-=pre;
			printf("%d\n",pre=query(k));
		}
	}
	return 0;
}
posted @ 2021-10-06 20:31  tzc_wk  阅读(31)  评论(0编辑  收藏  举报