2025.7.10 线段树

2025.7.10 线段树

基本思想

线段树是算法竞赛中常用的用来维护区间信息的数据结构。

线段树可以在 \(O(\log{n})\) 的时间复杂度内实现单点修改、区间修改、区间查询(区间求和,求区间最大值,求区间最小值)等操作。eg

如上图所示,线段树的每一个根节点维护左右两个子节点的答案(区间和)。

当我们查询某个区间 \([s,t]\) 时,因为当前节点维护两个节点 \([l,mid]\&[mid+1,r]\) 的答案,所以只考虑 \([s,t]\) 是完全包含该节点还是只包含该节点的哪一部分。查询的时间复杂度为 \(O(\log{n})\)

实现

区间和为例。

构造

像上图一样建立线段树,当区间长度为 \(1\) 时停止。

void up(int p){
    tr[p]=tr[p<<1]+tr[p<<1|1];//区间和
}
void build(int l,int r,int p){
    if(l==r){//区间为1
        tr[p]=a[l];
        return ;
    }
    int mid=(l+r)/2;
    build(l,mid,p<<1),build(mid+1,r,p<<1|1);//递归处理左右子节点
    up(p);//向父节点递归
}

查询

int query(int l,int r,int p,int s,int t){
    if(s<=l&&r<=t){//完全包含
        return tr[p];
    }
    int mid=(l+r)/2,sum=0;
    if(s<=mid){//包含左子树的一部分
        sum+=query(l,mid,p<<1,s,t);
    }
    if(t>mid){//包含右子树的一部分
        sum+=query(mid+1,r,p<<1|1,s,t);
    }
    return sum;
}

例题P2023

题目描述

有一个长为 \(n\) 的数列 \({a_n}\) ,有如下三种操作形式:

  1. 格式 1 l r c ,表示把所有满足 \(l\leq{i}\leq{r}\)\(a_i\) 改为 \(a_i*c\) ;
  2. 格式 2 l r c ,表示把所有满足 \(l\leq{i}\leq{r}\)\(a_i\) 改为 \(a_i+c\) ;
  3. 格式 3 l r ,询问所有满足 \(l\leq{i}\leq{r}\)\(a_i\) 的和,由于答案可能很大,你只需输出这个数模 \(mod\) 的值。

[!NOTE]

对于全部的测试点,保证 \(0\leq p,a_i,c\leq10^9\)\(1\leq l\leq r\leq n\)

思路

因为操作 3 涉及区间求和,考虑使用线段树。

题目涉及区间修改,所以需要使用结构体对当前节点的状态进行标记:

struct node{
    long long sum,add,mul;//值,加法标记,乘法标记
}tr[N<<2];
  • 对于乘法操作,不只单纯地把数值乘 c ,之前可能的存在加法标记也要乘 c
  • 对于加法,数值和加法标记一起加 c 即可。

为了使每次查询到的节点的状态无误,在查询它的父节点时就应该状态转移:

void down(int l,int r,int p){
	long long ad=tr[p].add,ml=tr[p].mul,mid=(l+r)/2;
	tr[p<<1].sum=(tr[p<<1].sum*ml+(mid-l+1)*ad)%mod;
	tr[p<<1].add=(tr[p<<1].add*ml+ad)%mod;
	tr[p<<1].mul=(tr[p<<1].mul*ml)%mod;
	tr[p<<1|1].sum=(tr[p<<1|1].sum*ml+(r-mid)*ad)%mod;
	tr[p<<1|1].add=(tr[p<<1|1].add*ml+ad)%mod;
	tr[p<<1|1].mul=(tr[p<<1|1].mul*ml)%mod;
	tr[p].add=0,tr[p].mul=1;//用完后清空
}

AC代码

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5;
int n,m,mod,a[N];
struct node{
	long long sum,add,mul;
}tr[N<<2];
void up(int p){
	tr[p].sum=(tr[p<<1].sum+tr[p<<1|1].sum)%mod;//向父节点
}
void down(int l,int r,int p){//向子节点
	long long ad=tr[p].add,ml=tr[p].mul,mid=(l+r)/2;
	tr[p<<1].sum=(tr[p<<1].sum*ml+(mid-l+1)*ad)%mod;
	tr[p<<1].add=(tr[p<<1].add*ml+ad)%mod;
	tr[p<<1].mul=(tr[p<<1].mul*ml)%mod;
	tr[p<<1|1].sum=(tr[p<<1|1].sum*ml+(r-mid)*ad)%mod;
	tr[p<<1|1].add=(tr[p<<1|1].add*ml+ad)%mod;
	tr[p<<1|1].mul=(tr[p<<1|1].mul*ml)%mod;
	tr[p].add=0,tr[p].mul=1;//清空
}
void build(int l,int r,int p){
	tr[p].add=0,tr[p].mul=1;//初始化
	if(l==r){
		tr[p].sum=a[l];
		return ;
	}
	int mid=(l+r)/2;
	build(l,mid,p<<1),build(mid+1,r,p<<1|1);
	up(p);
}
void add(int l,int r,int p,int s,int t,int c){
	if(s<=l&&r<=t){
		tr[p].sum=(tr[p].sum+(r-l+1)*c)%mod;//加上区间长度*c
		tr[p].add=(tr[p].add+c)%mod;//add标记累加
		return ;
	}
	down(l,r,p);//向下转移
	int mid=(l+r)/2;
	if(s<=mid){
		add(l,mid,p<<1,s,t,c);
	}
	if(t>mid){
		add(mid+1,r,p<<1|1,s,t,c);
	}
	up(p);
}
void mul(int l,int r,int p,int s,int t,int c){
	if(s<=l&&r<=t){
		tr[p].sum=(tr[p].sum*c)%mod;
		tr[p].add=(tr[p].add*c)%mod;
		tr[p].mul=(tr[p].mul*c)%mod;
		return ;
	}
	down(l,r,p);
	int mid=(l+r)/2;
	if(s<=mid){
		mul(l,mid,p<<1,s,t,c);
	}
	if(t>mid){
		mul(mid+1,r,p<<1|1,s,t,c);
	}
	up(p);
}
long long query(int l,int r,int p,int s,int t){
	long long ans=0;
	if(s<=l&&r<=t){
		return tr[p].sum;
	}
	down(l,r,p);
	int mid=(l+r)/2;
	if(s<=mid){
		ans+=query(l,mid,p<<1,s,t);
	}
	if(t>mid){
		ans+=query(mid+1,r,p<<1|1,s,t);
	}
	return ans%mod;
}
int main(){
	cin>>n>>mod;
	for(int i=1;i<=n;i++){
		cin>>a[i];
		a[i]%=mod;
	}
	build(1,n,1);
	cin>>m;
	while(m--){
		int op,l,r,c;
		cin>>op>>l>>r;
		if(op==1){//乘法
			cin>>c;
			mul(1,n,1,l,r,c);
		}
		else if(op==2){//加法
			cin>>c;
			add(1,n,1,l,r,c);
		}
		else{//求和
			cout<<query(1,n,1,l,r)<<endl;
		}
	}
	return 0;
}

完结撒花!!!

posted @ 2025-07-11 21:58  liyuan2023  阅读(14)  评论(0)    收藏  举报