线段树学习笔记

线段树是一种用来维护区间的数据结构,如果所维护的条件满足加法性质那么就可以使用 STG(线段树)进行维护。

例如区间求和、区间最值、区间 \(\gcd\) 等。当然,也可以维护线段一类,比如扫描线。

线段树是一种完全二叉树,对于一个节点 \(u\),如果它有,那么有 \(ls = u * 2\)\(rs = u * 2 + 1\)。其右儿子维护的区间分别是 \([l,mid]\)\([mid+1,r]\)

他的操作格式如下:

  1. pushup() 将左右子节点的区间合并为一个区间,将其赋给当前节点,并同时维护所有需的信息,格式如下:
代码
void pushup(int u){
		t[u]=sum(t[ls],t[rs]);
}
  1. maketag() 将区间按一定方式进行修改并为区间打上标记,标记是为了更好、更快的进行修改,格式如下:
代码
void maketag(int u,int l,int r,int k){
		t[u]=do_sth(k);
		tag[u]=do_sth(k);
}
  1. pushdown 将该节点的标记下传给子节点,并取消标记,格式如下:
代码
void pushdown(int u,int l,int r){
		if(tag[u]!=null){
			maketag(ls,l,mid,tag[u]);
			maketag(rs,mid+1,r,tag[u]);
			tag[u]=null;
		}
}
  1. change() 区间修改,并维护需要的值,同时下传所经结点的标记,格式如下:
代码
void change(int u,int l,int r,int L,int R,int k){
		if(L<=l&&r<=R){
			maketag(u,l,r,k);
			return;
		}
		pushdown(u,l,r);
		if(L<=mid) change(ls,l,mid,L,R,k);
		if(R>mid) change(rs,mid+1,r,L,R,k)l;
		pushup(u);
}
  1. query() 区间查询,查询一个区间所维护的值,同时下传所经结点的标记,格式如下:
代码
int query(int u,int l,int r,int L,int R){
		if(L<=l&&r<=R) return t[u];
		pushdown(u,l,r);
		int s=null;
		if(L<=mid) s+=query(ls,l,mid,L,R,k);
		if(R>mid) s+=query(rs,mid+1,r,L,R,k)l;
		return s;
}

于是我们就可以打掉板子题了:

\(P3373\) 线段树2

AC代码
#include<iostream>
#define ll long long
#define maxn 100001
using namespace std;
ll tag[maxn<<2],mul[maxn<<2];
ll w[maxn<<2];
ll a[maxn];
ll p;
void pu(int u){
		w[u]=(w[u<<1]+w[(u<<1)+1])%p;
}
void maketag(int u,int len,ll x,ll y){
    	w[u]=(w[u]*y+x*len)%p;
    	mul[u]=(mul[u]*y)%p;
    	tag[u]=(tag[u]*y+x)%p;
}
void pd(int u,int l,int r){
		int m=(l+r)>>1;
		maketag(u<<1,m-l+1,tag[u],mul[u]);
		maketag((u<<1)+1,r-m,tag[u],mul[u]);
		tag[u]=0;
		mul[u]=1;
}
ll query(int u,int L,int R,int l,int r){
		if(l<=L&&R<=r)
			return w[u];
		else if(!(L>r||R<l)){
			int m=(L+R)>>1;
			pd(u,L,R);
			return (query(u<<1,L,m,l,r)+query((u<<1)+1,m+1,R,l,r))%p;
		}
		else return 0;
}
void update(int u,int L,int R,int l,int r,ll x,ll y){
		if(l<=L&&R<=r)
			maketag(u,R-L+1,x,y);
		else if(!(L>r||R<l)){
			int m=(L+R)>>1;
			pd(u,L,R);
			update(u<<1,L,m,l,r,x,y);
			update((u<<1)+1,m+1,R,l,r,x,y);
			pu(u);
		}
}
void build(int u,int l,int r){
		mul[u]=1;
		if(l==r){w[u]=a[l];return;}
		int m=(l+r)>>1;
		build((u<<1),l,m);
		build((u<<1)+1,m+1,r);
		pu(u);
}
int main(){
		int n,m;
		scanf("%d%d%lld",&n,&m,&p);
		for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
		build(1,1,n);
		for(int i=1;i<=m;i++){
			int t,x,y;
			ll k;
			scanf("%d",&t);
			if(t==1){
				scanf("%d%d%lld",&x,&y,&k);
				update(1,1,n,x,y,0,k);
			}
			else if(t==2){
				scanf("%d%d%lld",&x,&y,&k);
				update(1,1,n,x,y,k,1);
			}
			else if(t==3){
				scanf("%d%d",&x,&y);
				printf("%lld\n",query(1,1,n,x,y));
			}
		}
		return 0;
}

接下来看几道例题:

\(P4145\) 上帝造题的七分钟 2 / 花神游历各国

这是一道很明显的线段树题,题目概述如下:

需要编写一个数据结构并支持以下操作:

  1. \([l,r]\) 区间进行开方运算并向下取整。
  2. \(\sum\limits_{i=l}\limits^{r}a[i]\) 并输出。

我们可以用线段树来维护区间求和,然后暴力单点开方。但这样会T。

区间求和已经没法再优化了(至少不在我的能力范围之内),考虑优化开方操作。

由于区间开方无法通过打标记的方式优化,所以我们只能从开方的性质入手。

我们已知 \(\sqrt{1}=1\),那么我们可以利用这个性质来优化线段树。

当一个区间的最大值都小于等于 \(1\) 时,这个区间内就只有 \(1\)(开方不可能开出负数),这是我们就可以直接放弃对这一个区间的修改。

由于一个小于等于 \({10}^{12}\) 的数最多被开方 \(6\) 次后向下取整就变成了 \(1\),而每一次查询的复杂度都为 \(O(n\log_2{n})\)

至于修改操作,在最坏情况下复杂度为 \(O(n\log_2{n})\),总复杂度为 \(O(6n\log_2{n})\)

那么这种做法的总复杂度即为 \(O(m\log_2{n}+6n\log_2{n})\)。对于本题是能够通过的。

AC代码
#include<iostream>
#include<cmath>
#define ll long long
using namespace std;
const int N(1e5+3);
int n,m;
struct node{
	ll mx,sum;
} t[N<<2];
#define ls (u<<1)
#define rs (u<<1|1)
#define mid ((l+r)>>1)
void pushup(int u){
	t[u].sum=t[ls].sum+t[rs].sum;
	t[u].mx=max(t[ls].mx,t[rs].mx);
}
void build(int u,int l,int r){
	if(l==r){
		scanf("%lld",&t[u].mx);t[u].sum=t[u].mx;
		return ;
	}
	build(ls,l,mid);build(rs,mid+1,r);
	pushup(u);
}
void change(int u,int l,int r,int L,int R){
	if(l==r){
		t[u].mx=sqrt(t[u].mx);t[u].sum=sqrt(t[u].sum);
		return ;
	}
	if(L<=mid&&t[ls].mx>1) change(ls,l,mid,L,R);
	if(R>mid&&t[rs].mx>1) change(rs,mid+1,r,L,R);
	pushup(u);
}
ll qry(int u,int l,int r,int L,int R){
	if(L<=l&&r<=R) return t[u].sum;
	ll s=0;
	if(L<=mid) s+=qry(ls,l,mid,L,R);
	if(R>mid) s+=qry(rs,mid+1,r,L,R);
	return s;
}
int main(){
	#ifdef ytxy
	freopen("in.txt","r",stdin);
	#endif
	scanf("%d",&n);
	build(1,1,n);
	scanf("%d",&m);
	while(m--){
		int k,l,r;
		scanf("%d%d%d",&k,&l,&r);
		if(l>r) swap(l,r);
		if(k==0){
			change(1,1,n,l,r);
		}
		else{
			printf("%lld\n",qry(1,1,n,l,r));
		}
	}
}

\(P2572\) [SCOI2010] 序列操作:

题面很清楚,不多赘述。

很有挑战性的一道题,看了小粉兔的题解后恍然大悟,原来线段树还能这么写……

这一题就难在 maketag()pushup() 不过我们可以用另一种方式完成 pushup() 的工作。

我们可以考虑 \(t[u]=merge(t[ls],t[rs])\),即合并 \(u\) 的儿子区间赋给 \(u\) 节点(\(t[u]\) 表示当前节点结构体)。

merge() 的代码实现并不难:

inline node merge(node x,node y){
	return node(//构造函数
		x.w+y.w,x.b+y.b,
		(x.b?x.lw:x.w+y.lw),(x.w?x.lb:x.b+y.lb),
		(y.b?y.rw:y.w+x.rw),(y.w?y.rb:y.b+x.rb),
		max(max(x.mw,y.mw),x.rw+y.lw),
		max(max(x.mb,y.mb),x.rb+y.lb)
	);
}

区间赋值并不难,区间 \(xor\) 才是最难的。我们可以考虑同时维护一个区间内的 \(0\) 的总数、最长连续、前缀长度、后缀长度,对于区间内的 \(1\) 也维护同样的值。这样一来在进行区间 \(xor\) 时就只需要对维护的 \(0\)\(1\) 对应的值进行交换即可。

那么我们可以开两个数组来保存赋值标记与异或标记。打标记就变得很简单了:

inline void mt(int u,int ty,int len){
	if(ty==0) t1[u]=0,t2[u]=0,
		t[u]=node(0,len,0,len,0,len,0,len);
	else if(ty==1) t1[u]=1,t2[u]=0,
		t[u]=node(len,0,len,0,len,0,len,0);
	else if(ty==2) t2[u]^=1,
		swap(t[u].w,t[u].b),swap(t[u].lw,t[u].lb),
		swap(t[u].rw,t[u].rb),swap(t[u].mw,t[u].mb);
}

在查询时我们可以将 \([l,r]\) 区间用先前的 merge() 合并出来,这样就可以将 \(3,4\) 操作合在一起,在输出时进行判断。

总的复杂的即为 \(O(m\log{n})\)

AC代码
#include<iostream>
using namespace std;
const int N(1e5+5);
int n,m;
struct node{
	int w,b,lw,lb,rw,rb,mw,mb;
	node(int w=0,int b=0,int lw=0,int lb=0,//构造函数方便合并区间
	int rw=0,int rb=0,int mw=0,int mb=0):
		w(w),b(b),lw(lw),lb(lb),
		rw(rw),rb(rb),mw(mw),mb(mb){}
} t[N<<2];
int t1[N<<2],t2[N<<2];
inline node merge(node x,node y){//合并区间
	return node(
		x.w+y.w,x.b+y.b,
		(x.b?x.lw:x.w+y.lw),(x.w?x.lb:x.b+y.lb),//考虑整个子区间都是1或0的情况
		(y.b?y.rw:y.w+x.rw),(y.w?y.rb:y.b+x.rb),
		max(max(x.mw,y.mw),x.rw+y.lw),//考虑左、右两节点与横跨两个子区间的
		max(max(x.mb,y.mb),x.rb+y.lb)//同上
	);
}
inline void mt(int u,int ty,int len){
	if(ty==0) t1[u]=0,t2[u]=0,
		t[u]=node(0,len,0,len,0,len,0,len);
	else if(ty==1) t1[u]=1,t2[u]=0,
		t[u]=node(len,0,len,0,len,0,len,0);
	else if(ty==2) t2[u]^=1,
		swap(t[u].w,t[u].b),swap(t[u].lw,t[u].lb),//交换0与1的对应值
		swap(t[u].rw,t[u].rb),swap(t[u].mw,t[u].mb);
}
#define ls (u<<1)
#define rs (u<<1|1)
#define mid (l+r>>1)
void pushdown(int u,int l,int r){
	if(t1[u]!=-1)mt(ls,t1[u],mid-l+1),mt(rs,t1[u],r-mid);
	if(t2[u]) mt(ls,2,mid-l+1),mt(rs,2,r-mid);
	t1[u]=-1,t2[u]=0;
}
void build(int u,int l,int r){
	t1[u]=-1;
	if(l==r){
		int a;scanf("%d",&a);
		t[u]=node(a,a^1,a,a^1,a,a^1,a,a^1);return ;
	}
	build(ls,l,mid);build(rs,mid+1,r);
	t[u]=merge(t[ls],t[rs]);
}
void change(int u,int l,int r,int L,int R,int ty){
	if(r<L||R<l) return ;
	if(L<=l&&r<=R){mt(u,ty,r-l+1);return ;}
	pushdown(u,l,r);
	change(ls,l,mid,L,R,ty);change(rs,mid+1,r,L,R,ty);
	t[u]=merge(t[ls],t[rs]);
}
node qry(int u,int l,int r,int L,int R){
	if(r<L||R<l) return node();
	if(L<=l&&r<=R) return t[u];
	pushdown(u,l,r);
	return merge(qry(ls,l,mid,L,R),qry(rs,mid+1,r,L,R));
}
int main(){
	#ifdef ytxy
	freopen("in.txt","r",stdin);
	#endif
	scanf("%d%d",&n,&m);
	build(1,1,n);
	while(m--){
		int opt,l,r;
		scanf("%d%d%d",&opt,&l,&r);
		//题目有点坑,测试点的区间是[0,n-1],所以需要加1来方便线段树维护
		if(opt<3) change(1,1,n,l+1,r+1,opt);
		else{
			node x=qry(1,1,n,l+1,r+1);
			printf("%d\n",(opt==3?x.w:x.mw));
		}
	}
}
posted @ 2022-07-09 12:57  JR_ytxy  阅读(56)  评论(0)    收藏  举报