【高级·数据结构】线段树

前言

线段树是非常重要的一个数据结构!可以求解区间 $[l,r]$ 的最值问题等区间查询。

而且大部分ST 表,树状数组能解决的问题线段树都能解决。

普通线段树

P3372 【模板】线段树 1

有一个序列 $a$,现在有两个操作:

  • 1 l r k 表示将 $[l,r]$ 区间的元素全部加上 $k$。
  • 2 l r 表示查询 $[l,r]$ 区间的和。

暴力思想

暴力的话很好处理,每次暴力枚举 $i$ 从 $l$ 到 $r$ 这段区间,给每个元素加上 $k$ 即可。

然后查询的时候暴力枚举 $i$ 从 $l$ 到 $r$ 这段区间,用 $res$ 来统计元素值的和即可。

但是这样的话对于数据规模大的时候显然会超时。

线段树的实现

那么这个时候我们就可以考虑线段树求解了!首先我们要了解线段树是什么。

线段树其实就是一种二叉树,对于一个线段我们用二叉树来表示,每个节点会有自己的区间。

对于根节点来说,管理区间就是 $[1,n]$ 啦,而它的左孩子和右孩子则管理的是 $[1,mid]$ 和 $[mid+1,n]$ ,这里的 $mid$ 指的就是 $\frac{1+n}{2}$ 可以发现,节点管理的区间其实是二分下去的,这就是线段树的建树过程。

void build(ll x,ll l,ll r){
	tr[x].l=l;
	tr[x].r=r;
	if(l==r){
		tr[x].val=arr[l];
		return ;
	}
	ll m=l+r>>1;
	build(lc,l,m);
	build(rc,m+1,r);
	pushUp(x);
	return ;
}

我们定义了一个结构体,储存该节点管理的区间 $l,r$ ,当 $l=r$ 的时候表示到达了叶子节点,那么当前管理的值 $val$ 就应该等于数组中的初始值,当然如果没有初始值直接 $\text{return}$ 就好。

之后二分下去即可,但是你会发现有一个 $\text{pushUp}$ 函数,其实这个函数就是合并一个节点的左右两个节点的,本质上来说,线段树建树的过程就是递归,所以会有合并结点的过程,合并什么呢?具体看题目要求什么了,这里以区间求和为例。

void pushUp(ll x){
	tr[x].val=tr[lc].val+tr[rc].val;
	return ;
}

这个就是 $\text{pushUp}$ 函数了,非常简洁,就是合并左节点和右节点,这里的 $\text{lc}$ 和 $\text{rc}$ 指的是左孩子和右孩子。

那么建树完毕之后就应该有查询了,我们看看主函数:

int main(){
	cin>>n>>q;
	for(int i=1;i<=n;i++) cin>>arr[i];
	build(1,1,n);
	while(q--){
		cin>>op>>x>>y;
		if(op==1){
			cin>>k;
			update(1,x,y,k);
		}
		else cout<<ask(1,x,y)<<'\n';
	}
	return 0;
}

在建完树之后询问,如果当前要更新值,那么就调用 $\text{update}$ 函数,否则输出答案,调用 $\text{ask}$ 函数。

那么 $\text{update}$ 函数怎么写呢?

我们考虑,将一个区间加上 $k$,其实就是把这个区间涂上一个 $len \times k$ 的标记,这里的 $len$ 为该节点管理区间的长度,为什么呢?

该节点是不是管理 $[l,r]$ 的区间,那是不是 $[l,r]$ 的区间中的元素都要加上 $k$?所以管理一个区间的节点就要加上一个区间的 $k$ 值。那么什么时候要加呢?如果这个区间被完全包括在目标区间里面,那么这个节点就要加,这个很容易理解。

但是我要查询的时候是不是还要把这个标记传下去一直累加?所以我们引出了 $\text{lazytag}$ 也就是懒标记来维护,每次加值的时候让值加上 $len \times k$ ,然后让 $\text{lazytag}$ 加上 $k$ 即可,然后我们还要写一个函数让懒标记下传,所以我们新建一个 $\text{pushDown}$ 函数即可。

void update(ll x,ll l,ll r,ll k){
	if(l<=tr[x].l && tr[x].r<=r){
		tr[x].val+=k*tr[x].len();
		tr[x].lazy+=k;
		return ;
	}
	pushDown(x);
	ll m=tr[x].l+tr[x].r>>1;
	if(l<=m) update(lc,l,r,k);
	if(r>m) update(rc,l,r,k);
	pushUp(x);
	return ;
}

我们先看更新函数,当这个区间被完全包括在目标区间里面,那么这个节点就要加,然后下传懒标记,之后再二分递归即可,记得还要合并!

void pushDown(ll x){
	if(!tr[x].lazy) return ;
	tr[lc].lazy+=tr[x].lazy;
	tr[rc].lazy+=tr[x].lazy;
	tr[lc].val+=tr[x].lazy*tr[lc].len();
	tr[rc].val+=tr[x].lazy*tr[rc].len();
	tr[x].lazy=0;
	return ;
}

我们看下传函数,如果现在没有标记就不下传,然后左右节点的懒标记继承父节点的懒标记,因为要继续下传。然后值就加上 $len \times k$ ,我们之前讲过,然后标记传完了,归 $0$ 即可。

好了我们来看询问函数:

ll ask(ll x,ll l,ll r){
	if(l<=tr[x].l && tr[x].r<=r) return tr[x].val;
	pushDown(x);
	ll m=tr[x].l+tr[x].r>>1;
	ll ans=0;
	if(l<=m) ans+=ask(lc,l,r);
	if(r>m) ans+=ask(rc,l,r);
	return ans;
}

如果这个区间被完全包括在目标区间里面,那么直接返回值即可,然后下传一遍懒标记。

注意:答案需要在两个区间累加,可能一个区间中会被两个节点分别管理!

之后我们将函数组合到一起即可:

#include<bits/stdc++.h>
#define ll long long
#define lc x<<1
#define rc x<<1|1
using namespace std;
const int N=1e5+5;
struct node{
	ll l,r,val,lazy;
	ll len(){
		return r-l+1;
	}
}tr[N<<2];
ll arr[N];
ll n,q,x,y,k,op;
void pushUp(ll x){
	tr[x].val=tr[lc].val+tr[rc].val;
	return ;
}
void pushDown(ll x){
	if(!tr[x].lazy) return ;
	tr[lc].lazy+=tr[x].lazy;
	tr[rc].lazy+=tr[x].lazy;
	tr[lc].val+=tr[x].lazy*tr[lc].len();
	tr[rc].val+=tr[x].lazy*tr[rc].len();
	tr[x].lazy=0;
	return ;
}
void build(ll x,ll l,ll r){
	tr[x].l=l;
	tr[x].r=r;
	if(l==r){
		tr[x].val=arr[l];
		return ;
	}
	ll m=l+r>>1;
	build(lc,l,m);
	build(rc,m+1,r);
	pushUp(x);
	return ;
}
void update(ll x,ll l,ll r,ll k){
	if(l<=tr[x].l && tr[x].r<=r){
		tr[x].val+=k*tr[x].len();
		tr[x].lazy+=k;
		return ;
	}
	pushDown(x);
	ll m=tr[x].l+tr[x].r>>1;
	if(l<=m) update(lc,l,r,k);
	if(r>m) update(rc,l,r,k);
	pushUp(x);
	return ;
}
ll ask(ll x,ll l,ll r){
	if(l<=tr[x].l && tr[x].r<=r) return tr[x].val;
	pushDown(x);
	ll m=tr[x].l+tr[x].r>>1;
	ll ans=0;
	if(l<=m) ans+=ask(lc,l,r);
	if(r>m) ans+=ask(rc,l,r);
	return ans;
}
int main(){
	cin>>n>>q;
	for(int i=1;i<=n;i++) cin>>arr[i];
	build(1,1,n);
	while(q--){
		cin>>op>>x>>y;
		if(op==1){
			cin>>k;
			update(1,x,y,k);
		}
		else cout<<ask(1,x,y)<<'\n';
	}
	return 0;
}

P3870 [TJOI2009] 开关

问题简述

现有 $n$ 盏灯排成一排,从左到右依次编号为:$1$,$2$,……,$n$。然后依次执行 $m$ 项操作。

操作分为两种:

  1. 指定一个区间 $[a,b]$,然后改变编号在这个区间内的灯的状态(把开着的灯关上,关着的灯打开);
  2. 指定一个区间 $[a,b]$,要求你输出这个区间内有多少盏灯是打开的。

灯在初始时都是关着的

思路

我们发现,每次开关灯其实就是对一段区间取反,那么我们用 $val$ 来记录区间内灯开着的数量。

那么每次下传的时候该怎么修改呢?

我们发现,每个灯不是开着就是关着,所以每次去反之后 $val$ 就等于 $len-val$ 了。然后懒标记取反即可。

代码实现

#include<bits/stdc++.h>
#define ll long long
#define lc x<<1
#define rc x<<1|1
using namespace std;
const int N=5e5;
struct node{
	ll val,l,r,lazy;
	ll len(){
		return r-l+1;
	}
}tr[N<<2];
ll n,q,op;

void pushUp(ll x){
	tr[x].val=tr[lc].val+tr[rc].val;
}

void pushDown(ll x){
	if(!tr[x].lazy) return ;
	tr[lc].val=tr[lc].len()-tr[lc].val;
	tr[rc].val=tr[rc].len()-tr[rc].val;
	tr[lc].lazy^=1;
	tr[rc].lazy^=1;
	tr[x].lazy=0;
}

void buildTree(ll x,ll l,ll r){
	tr[x].l=l;
	tr[x].r=r;
	if(l==r){
		tr[x].val=0;
		return ;
	}
	ll m=(l+r)>>1;
	buildTree(lc,l,m);
	buildTree(rc,m+1,r);
	pushUp(x);
}

void update(ll x,ll l,ll r){
	if(l<=tr[x].l && tr[x].r<=r){
		tr[x].val=tr[x].len()-tr[x].val;
		tr[x].lazy^=1;
		return ;
	}
	pushDown(x);
	ll m=(tr[x].l+tr[x].r)>>1;
	if(l<=m) update(lc,l,r);
	if(r>m) update(rc,l,r);
	pushUp(x);
}

ll query(ll x,ll l,ll r){
	if(l<=tr[x].l && tr[x].r<=r) return tr[x].val;
	pushDown(x);
	ll m=(tr[x].l+tr[x].r)>>1,sum=0;
	if(l<=m) sum+=query(lc,l,r);
	if(r>m) sum+=query(rc,l,r);
	return sum;
}
int main(){
	cin>>n>>q;
	buildTree(1,1,n);
	while(q--){
		ll l,r;
		cin>>op>>l>>r;
		if(op==0){
			update(1,l,r);
		}
		else{
			cout<<query(1,l,r)<<endl;
		}
	}
	return 0;
}

多懒标记下放线段树

P1253 扶苏的问题

问题简述

给定一个长度为 $n$ 的序列 $a$,要求支持如下三个操作:

  1. 给定区间 $[l, r]$,将区间内每个数都修改为 $x$。
  2. 给定区间 $[l, r]$,将区间内每个数都加上 $x$。
  3. 给定区间 $[l, r]$,求区间内的最大值。

思路

其实多懒标记下放,就是有很多个懒标记,在 $\text{pushdown}$ 函数中下方多个懒标记,就比如这道题。

首先我们看第二个操作和第三个操作,都很简单,所以我们直接维护区间最大值即可。

void pushUp(ll x){
	tr[x].maxn=max(tr[lc].maxn,tr[rc].maxn);
}

void update_add(ll x,ll l,ll r,ll L,ll R,ll k){
	if(l<=L && R<=r){
		pushDown_set(x);
		tr[x].maxn+=k;
		tr[x].lazy_max+=k;
		return ;
	}
	pushDown(x);
	ll m=L+R>>1;
	if(l<=m) update_add(lc,l,r,L,m,k);
	if(r>m) update_add(rc,l,r,m+1,R,k);
	pushUp(x);
	return ;
}

而第一个操作怎么做呢?我们产尝试维护一个 $\text{lazy_set}$ 的懒标记,然后在 $\text{pushdown}$ 函数的时候需要覆盖 $\text{lazy_max}$ 懒标记,直接设为 $0$ 即可,那么最大值咋办?直接设为 $\text{lazy_set}$ 就好了呀,然后就正常操作即可。

void pushDown_set(ll x){
	if(tr[x].lazy_set!=-1e18){
//		cout<<"HHH"<<endl;
		tr[lc].lazy_max=0;
		tr[rc].lazy_max=0;
		tr[lc].maxn=tr[x].lazy_set;
		tr[rc].maxn=tr[x].lazy_set;
		tr[lc].lazy_set=tr[x].lazy_set;
		tr[rc].lazy_set=tr[x].lazy_set;
		tr[x].lazy_set=-1e18;
	}
}
void pushDown(ll x){
	pushDown_set(x);
	if(tr[x].lazy_max){
		pushDown_set(x);
		tr[lc].lazy_max+=tr[x].lazy_max;
		tr[rc].lazy_max+=tr[x].lazy_max;
		tr[lc].maxn+=tr[x].lazy_max;
		tr[rc].maxn+=tr[x].lazy_max;
		tr[x].lazy_max=0;
		return ;
	}
	return ;
}

最后你可能发现,修改函数也分为 $\text{add}$ 和 $\text{set}$,区别就在懒标记的修改:

	if(l<=L && R<=r){
		pushDown_set(x);
		tr[x].maxn+=k;
		tr[x].lazy_max+=k;
		return ;
	}

	if(l<=L && R<=r){
		tr[x].maxn=k;
		tr[x].lazy_max=0;
		tr[x].lazy_set=k;
		return ;
	}

上面是加函数,下面是修改函数,发现修改的不同,因为修改的时候会覆盖之前的操作,所以会不一样,然后我们就可以写出代码了:

#include<bits/stdc++.h>
#define ll long long
#define lc x<<1
#define rc x<<1|1
using namespace std;
const int N=1e7+5;
struct node{
	ll maxn,lazy_max,lazy_set;
}tr[N<<2];
ll arr[N];
ll n,q,op,l,r,k,res;
inline ll read(){
	ll x=0,f=1;
	char ch=getchar();
	while (!isdigit(ch)){if (ch=='-') f=-1;ch=getchar();}
	while (isdigit(ch)){x=x*10+ch-48;ch=getchar();}
	return x*f;
}
void pushUp(ll x){
	tr[x].maxn=max(tr[lc].maxn,tr[rc].maxn);
}
void pushDown_set(ll x){
	if(tr[x].lazy_set!=-1e18){
//		cout<<"HHH"<<endl;
		tr[lc].lazy_max=0;
		tr[rc].lazy_max=0;
		tr[lc].maxn=tr[x].lazy_set;
		tr[rc].maxn=tr[x].lazy_set;
		tr[lc].lazy_set=tr[x].lazy_set;
		tr[rc].lazy_set=tr[x].lazy_set;
		tr[x].lazy_set=-1e18;
	}
}
void pushDown(ll x){
	pushDown_set(x);
	if(tr[x].lazy_max){
		pushDown_set(x);
		tr[lc].lazy_max+=tr[x].lazy_max;
		tr[rc].lazy_max+=tr[x].lazy_max;
		tr[lc].maxn+=tr[x].lazy_max;
		tr[rc].maxn+=tr[x].lazy_max;
		tr[x].lazy_max=0;
		return ;
	}
	return ;
}
void build(ll x,ll l,ll r){
	if(l==r){
		tr[x].maxn=arr[l];
		tr[x].lazy_set=-1e18;
		tr[x].lazy_max=0;
		return ;
	}
	ll m=l+r>>1;
	build(lc,l,m);
	build(rc,m+1,r);
	pushUp(x);
	return ;
}
void update_add(ll x,ll l,ll r,ll L,ll R,ll k){
	if(l<=L && R<=r){
		pushDown_set(x);
		tr[x].maxn+=k;
		tr[x].lazy_max+=k;
		return ;
	}
	pushDown(x);
	ll m=L+R>>1;
	if(l<=m) update_add(lc,l,r,L,m,k);
	if(r>m) update_add(rc,l,r,m+1,R,k);
	pushUp(x);
	return ;
}
void update_set(ll x,ll l,ll r,ll L,ll R,ll k){
	if(l<=L && R<=r){
		tr[x].maxn=k;
		tr[x].lazy_max=0;
		tr[x].lazy_set=k;
		return ;
	}
	pushDown(x);
	ll m=L+R>>1;
	if(l<=m) update_set(lc,l,r,L,m,k);
	if(r>m) update_set(rc,l,r,m+1,R,k);
	pushUp(x);
	return ;
}
ll ask(ll x,ll l,ll r,ll L,ll R){
	if(l<=L && R<=r) return tr[x].maxn;
	pushDown(x);
	ll m=L+R>>1;
	ll ans=-1e18;
	if(l<=m) ans=max(ans,ask(lc,l,r,L,m));
	if(r>m) ans=max(ans,ask(rc,l,r,m+1,R));
	return ans;
}
int main(){
	n=read();
	q=read();
	for(int i=1;i<=n;i++) arr[i]=read();
	build(1,1,n);
	for(int i=1;i<=n*4;i++) tr[i].lazy_set=-1e18;
	while(q--){
		op=read();
		l=read();
		r=read();
		if(op==1){
			k=read();
			update_set(1,l,r,1,n,k);
		}
		else if(op==2){
			k=read();
			update_add(1,l,r,1,n,k);
		}
		else{
			cout<<ask(1,l,r,1,n)<<endl;
		}
	}
	return 0;
}

多区间合并线段树

P4513 小白逛公园

题目简述

给定一个序列,每次有两次操作:

1 l r 求区间 $[l,r]$ 的最大子段和

2 p s 把第 $p$ 个元素修改为 $s$

思路

嗯,单点修改就不说了,直接秒了,但是 $\text{update}$ 不能直接套模板,要看看最大子段和的求法。

那么最大子段和怎么求呢?我们先看一下基本的结构体:

struct node{
	ll l,r,val,lmax,rmax,sum;
}tr[N<<2];

其中 $sum$ 就是区间和,$lmax$ 是左儿子的最大子段和,$rmax$ 相反,最后 $val$ 为区间 $[l,r]$ 的最大子段和。

那么你可能会疑惑:直接求最大子段和不就行了吗,为什么还要求左右两边的呢?

因为我们可以发现,区间 $[l,r]$ 的最大子段和不等于 $[l,mid]$ 的最大子段和加上 $[mid+1,r]$ 的最大子段和,因为还有跨越两个区间的最大子段和,所以我们要创这两个变量。

那么我们要怎么更新呢?我们看下面的函数:

node operator + (node x,node y){
	node ans;
	ans.l=x.l;
	ans.r=y.r;
	ans.val=max(x.rmax+y.lmax,max(x.val,y.val));
	ans.lmax=max(x.lmax,x.sum+y.lmax);
	ans.rmax=max(y.rmax,y.sum+x.rmax);
	ans.sum=x.sum+y.sum;
	return ans;
}
void pushUp(ll x){
	tr[x]=tr[lc]+tr[rc];
}

这里我直接重定义了加号,免得再写到函数里面。

我们定义 $v_{[l,r]}$ 表示区间 $[l,r]$ 的最大子段和,那么这里区间最大子段和的更新就是 $v_{[l,mid]} + v_{[mid+1,r]}$ 了,这里我还取了个最大值,以防万一,但是也可以不要。

然后来到重点,对于 $v_{[l,mid]}$ 的更新,我们发现存在两种情况:第一种就是最大子段和都在左边区间,也就是 $ans$ 的 $v_{[l,mid]}$;第二种则是左边都有,右边有一些,所以是上面那一坨(不是)。

然后对于 $v_{[mid+1,r]}$ 的更新,和上面都差不多,最后更新一下区间和即可。

#include<bits/stdc++.h>
#define ll long long
#define lc x<<1
#define rc x<<1|1
using namespace std;
const int N=5e5+5;
struct node{
	ll l,r,val,lmax,rmax,sum;
}tr[N<<2];
ll n,q,op,x,y;
ll arr[N];
node operator + (node x,node y){
	node ans;
	ans.l=x.l;
	ans.r=y.r;
	ans.val=max(x.rmax+y.lmax,max(x.val,y.val));
	ans.lmax=max(x.lmax,x.sum+y.lmax);
	ans.rmax=max(y.rmax,y.sum+x.rmax);
	ans.sum=x.sum+y.sum;
	return ans;
}
void pushUp(ll x){
	tr[x]=tr[lc]+tr[rc];
}
void build(ll x,ll l,ll r){
	tr[x].l=l;
	tr[x].r=r;
	if(l==r){
		tr[x].val=tr[x].sum=arr[l];
		tr[x].lmax=tr[x].rmax=arr[l];
		return ;
	}
	ll m=l+r>>1;
	build(lc,l,m);
	build(rc,m+1,r);
	pushUp(x);
	return ;
}
void modify(ll x,ll p,ll k){
	if(tr[x].l==tr[x].r){
		tr[x].lmax=tr[x].rmax=k;
		tr[x].sum=tr[x].val=k;
		return ;
	}
	ll m=tr[x].l+tr[x].r>>1;
	if(p<=m) modify(lc,p,k);
	else modify(rc,p,k);
	pushUp(x);
	return ;
}
node ask(ll x,ll l,ll r){
	if(l<=tr[x].l && tr[x].r<=r) return tr[x];
	ll m=tr[x].l+tr[x].r>>1;
	if(r<=m) return ask(lc,l,r);
	if(l>m) return ask(rc,l,r);
	return ask(lc,l,r)+ask(rc,l,r);
}
int main(){
	cin>>n>>q;
	for(int i=1;i<=n;i++) cin>>arr[i];
	build(1,1,n);
	while(q--){
		cin>>op>>x>>y;
		if(op==1){
			if(x>y) swap(x,y);
			cout<<ask(1,x,y).val<<endl;
		} 
		else modify(1,x,y);
	}
	return 0;
}

推荐题目

P10463 Interval GCD

P2894 [USACO08FEB] Hotel G

posted @ 2025-10-07 13:44  一只小何  阅读(6)  评论(0)    收藏  举报