线段树(等级1)
发现博客园因为太花哨打不开,所以只好写这了
这个东西比树状数组功能更全面,属于全包含,但是由于这个东西常数大,不是很好写,记得也不熟,所以能打树状数组的三个基本操作最好还是打树状数组。
与树状数组的二进制来实现不同,线段树是通过二叉树来实现的
在线段树中,一个节点表示一个区间的和(也可以是乘积,最大值等等),一个节点的权值等于他左儿子权值加上右儿子权值。
由于二叉树的性质,我们可以用
\(2i\)来表示\(i\)节点的左儿子,通常写作\(i << 1\)
\(2i+1\)来表示右儿子,写作$i << 1 \ |\ 1 $
这样,我们可以通过一个结构体,来存他所表示的区间和区间的信息
struct E{ int data ; int l ,r ;}
注意,此处的\(l,r\)是他所表示的区间\([l,r]\)的信息,并不是他的左右儿子编号,左右儿子通过上文的性质来索引。
以区间和为例子
由于 \(t [i].data=t[i<<1].data+t[i<<1|1].data\)
这样,我们就可以写出建树的代码(递归实现)
void build( int x , int l , int r ){ // c代表原数列
t [x].l = l , t [x].r = r ;
if( l == r ) { t [x].data = c [l] ; return ; }
int mid = ( l + r ) >> 1 ;
build( i << 1 , l , mid ) ;
build( i << 1 | 1 , mid + 1 , r ) ;
ud( x ) ;
}
这里面,ud(x)代表更新一个节点,由于通过递归我们已经到达最底端节点,所以我们需要更新一个节点的值
void ud( int x ){ t [x].data = t [x << 1].data + t [x << 1 | 1].data ; }
接下来是一些基本操作
- 单点修改,区间查询
#include <bits/stdc++.h>
#define int long long
#define R register int
using namespace std ;
inline int read(){
int w = 0 ; char ch = getchar() ;
while( ch > '9' || ch < '0' ) ch = getchar() ;
while( ch >= '0' && ch <= '9' ){
w = ( w << 1 ) + ( w << 3 ) + ( ch - '0' ) ;
ch = getchar() ;
} return w ;
}
const int N = 1e5 + 10 ;
struct E{ int l , r ; int da ; } a [N << 2] ;
int n , m ; string s ; int c [N] ;
inline int ud ( int x ) { a [x].da = a [x << 1].da + a [x << 1 | 1].da; }
void build( int x , int l , int r ){
a [x].l = l , a [x].r = r ;
if( l == r ) { a [x].da = c [l] ; return ; }
int mid = ( l + r ) >> 1 ;
build( x << 1 , l , mid ) , build( x << 1 | 1 , mid + 1 , r ) ;
ud( x ) ;
}
inline void add( int p , int x , int k ){
if( a [p].l == a [p].r ) { a [p].da += k ; return ; }
int mid = ( a [p].l + a [p].r ) >> 1 ;
if( x <= mid ) add( p << 1 , x , k ) ;
else add( p << 1 | 1 , x , k ) ;
ud( p ) ;
}
inline int sum( int p , int l , int r ){
if( a [p].l >= l && a [p].r <= r ) return a [p].da ;
int mid = ( a [p].l + a [p].r ) >> 1 ;
int ans = 0 ;
if( l <= mid ) ans += sum( p << 1 , l , r ) ;
if( r > mid ) ans += sum( p << 1 | 1 , l , r ) ;
return ans ;
}
void sc(){
n = read() ;
for( R i = 1 ; i <= n ; i ++ ) c [i] = read() ;
if( n != 0 )
build( 1 , 1 , n ) ;
int ls , rs ;
m = read() ;
for( R i = 1 ; i <= m ; i ++ ){
cin >> s ;
if( s == "SUM" ) {
ls = read() , rs = read() ;
printf( "%lld\n" , sum( 1 , ls , rs ) ) ;
}
else {
ls = read() , rs = read() ;
add( 1 , ls , rs ) ;
/* for( R i = 1 ; i <= 4 * n ; i ++ )
printf( "%lld %lld %lld %lld\n" , i , a [i].l , a [i].r , a [i].da ) ;*/
}
}
}
signed main(){
sc () ;
return 0 ;
}
更新一个节点还是挺简单的,递归到这个节点然后把回来路上的节点都加上就可以了,由于树的深度不大于\(log(n)\)故时间复杂度也是\(\Theta(logn)\)
查询区间的思想:
- 如果被查询区间完全覆盖一个区间s,Ans+=s.data
- 如果不是,继续递归,直到完全覆盖为止
需要注意,这里的
if( l <= mid ) ans += sum( p << 1 , l , r ) ;
if( r > mid ) ans += sum( p << 1 | 1 , l , r ) ;
其实通过两行代码实现了判断是否这个与区间的左右儿子区间有交集,若有,则进行递归,直到查询到完全覆盖。
两行if不要写成else
- 区间修改,区间查询
之所以不写区间修改单点查询是因为我们可以用\([x,x]\)来表示\(x\)这个点的信息,
这个东西需要一个\(lazy\ tag\)来实现
因为如果我们还像单点修改一样把每个点和他们的区间都进行修改,这爆炸的常数甚至还不如前缀和跑得快,这个东西就没用了。
关于\(lazy \ tag\)他就是为了节省时间(因为懒)而出现的,如果一个区间已经被被查询区间完全覆盖,那么我们就更改这个区间的值,并对其打上懒惰标记,这样,时间肯定是快了,但是其实这时他的子区间也需要得到修改,以为我们懒,先不改他,等到需要的时候再改。
对于不需要查询的节点,我们先不传递标记。
对于需要查询的节点,我们加上标记,顺便把标记传递下去。
值得一提的是,由于我们懒,所以传递标记的时候我们也只传递其直系子节点,对于孙子节点,重孙子节点等,需要的时候再更改。
我们首先需要一个可以把\(lazy \ tag\)向下传递的函数\(spread\),为了方便,写成\(sp\)。
void sp( int x ){//将一个节点的标记下传
if( t [x].l == t [x].r ){//这个节点是叶子节点,没有子节点了,无需懒惰标记
t [x].lz = 0 ;
return ;
}
if( t [x].lz ){
int ls = x << 1 ;
int rs = x << 1 | 1 ;
int num = t [x].lz ;
t [x].lz = 0 ;//懒惰标记已经传递,故清空
t [ls].lz += num , t [rs].lz += num ;
t [ls].data += ( t [ls].r - t [ls].l + 1 ) * num ;
t [rs].data += ( t [rs].r - t [rs].l + 1 ) * num ;
}
}
这样,我们就可以将一个点的懒惰标记下传一代并且更新其子节点的值,节省了时间。
注意,对于任何操作,我们需要在对一个节点进行操作之前对他的懒惰标记进行传递。
对于区间修改
void add( int x , int l , int r , int w ){
if( t [x].l >= l && t [x].r <= r ){
t [x].data += ( t [x].r - t [x].l + 1 ) * w ;
t [x].lz += w ;
return ;
} sp (x) ;
int mid = ( t [x].l + t [x].r ) >> 1 ;
if( l <= mid ) add( x << 1 , l , r , w ) ;
if( r > mid ) add( x << 1 | 1 , l , r , w ) ;
ud( x ) ;
}
对于区间查询,和原来没什么区别
int sum( int x , int l , int r ){
if( t [x].l >= l && t [x].r <= r ) return t [x].data ;
sp (x) ; // 对节点操作前传递
int mid = ( t [x].l + t [x].r ) >> 1 ;
int ans = 0 ;
if( l <= mid ) ans += sum( x << 1 , l , mid ) ;
if( r > mid ) ans += sum( x << 1 | 1 , mid + 1 , r ) ;
return ans ;
}
这样就可以写出终极代码
#include <bits/stdc++.h>
#define int long long
#define R register int
using namespace std ;
inline int read(){
int w = 0 , f = 1 ; char ch = getchar() ;
while( ch < '0' || ch > '9' ) {
if( ch == '-' ) f = -1 ;
ch = getchar() ;
}
while( ch >= '0' && ch <= '9' ){
w = ( w << 1 ) + ( w << 3 ) + ( ch - '0' ) ;
ch = getchar() ;
} return f * w ;
}
const int N = 1e5 + 10 ;
int n , m ; string s ; int la , ra , w ; int c [N] ;
struct E{ int l , r ; int da ; int lz ; } a [N << 2] ;
inline void ud( int x ){ a [x].da = a [x << 1].da + a [x << 1 | 1].da ; }
void build( int x , int l , int r ){
a [x].l = l , a [x].r = r ;
if( l == r ) { a [x].da = c [l] ; return ; }
int mid = ( a [x].l + a [x].r ) >> 1 ;
build( x << 1 , l , mid ) ;
build( x << 1 | 1 , mid + 1 , r ) ;
ud( x ) ;
}
inline void sp( int x ){
if( a [x].l == a [x].r ){
a [x].lz = 0 ;
return ;
}
if( a [x].lz ){
int rs = x << 1 | 1 ;
int ls = x << 1 ;
int num = a [x].lz ;
a [x].lz = 0 ;
a [ls].da += num * ( a [ls].r - a [ls].l + 1 ) , a [rs].da += num * ( a [rs].r - a [rs].l + 1 ) ;
a [ls].lz += num , a [rs].lz += num ;
}
}
inline void add( int x , int l , int r , int w ){
if( a [x].l >= l && a [x].r <= r ){
a [x].da += w * ( a [x].r - a [x].l + 1 ) ;
a [x].lz += w ;
return ;
} sp ( x ) ;
int mid = ( a [x].l + a [x].r ) >> 1 ;
if( l <= mid ) add( x << 1 , l , r , w ) ;
if( r > mid ) add( x << 1 | 1 , l , r , w ) ;
ud( x ) ;
}
inline int sum( int x , int l , int r ){
if( a [x].l >= l && a [x].r <= r ) return a [x].da ;
sp ( x ) ;
int mid = ( a [x].l + a [x].r ) >> 1 ;
int ans = 0 ;
if( l <= mid ) ans += sum( x << 1 , l , r ) ;
if( r > mid ) ans += sum( x << 1 | 1 , l , r ) ;
return ans ;
}
void sc(){
n = read() ;// printf( "%lld\n" , n ) ;
for( R i = 1 ; i <= n ; i ++ ) c [i] = read() ;
build( 1 , 1 , n ) ;
// printf( "FIN!!\n\n\n" ) ;
m = read() ; //printf( "%lld\n" , m ) ;
for( R i = 1 ; i <= m ; i ++ ){
cin >> s ;
if( s == "SUM" ) la = read() , ra = read() , printf( "%lld\n" , sum( 1 , la , ra ) ) ;
else la = read() , ra = read() , w = read() , add( 1 , la , ra, w ) ;
}
}
signed main(){
sc() ;
return 0 ;
}
当然,线段树还可以扩展成最大值,最小值等
\(Max\)
#include <bits/stdc++.h>
#define int long long
#define R register int
using namespace std ;
inline int read(){
int w = 0 ; char ch = getchar() ;
while( ch < '0' || ch > '9' ) ch = getchar() ;
while( ch >= '0' && ch <= '9' ){
w = ( w << 1 ) + ( w << 3 ) + ( ch - '0' ) ;
ch = getchar() ;
} return w ;
}
const int N = 1e6 + 10 ;
int c [N] ; int n , m ; int ls , rs ;
struct E{ int l , r ; int da ; } a [N << 2] ;
inline void ud( int x ){ a [x].da = max( a [x << 1].da , a [x << 1 | 1].da ) ; }
void build( int x , int l , int r ){
a [x].l = l , a [x].r = r ;
if( l == r ) { a [x].da = c [l] ; return ; }
int mid = ( l + r ) >> 1 ;
build( x << 1 , l , mid ) ;
build( x << 1 | 1 , mid + 1 , r ) ;
ud( x ) ;
}
inline int sum( int x , int l , int r ){
if( a [x].l >= l && a [x].r <= r ) return a [x].da ;
int mid = ( a [x].l + a [x].r ) >> 1 ;
int ans = -0x3fffffff ;
if( l <= mid ) ans = max( ans , sum( x << 1 , l , r ) ) ;
if( r > mid ) ans = max( ans , sum( x << 1 | 1 , l , r ) ) ;
return ans ;
}
void sc(){
n = read() ; n ++ ;
for( R i = 1 ; i <= n ; i ++ ) c [i] = read() ;
build( 1 , 1 , n ) ;
/* for( R i = 1 ; i <= 4 * n ; i ++ )
printf( "%lld %lld %lld %lld\n" , i , a [i].l , a [i].r , a [i].da ) ;*/
m = read() ;
// printf( "\n%lld\n\n" , sum( 1 , 3 , 5 ) ) ;
for( R i = 1 ; i <= m ; i ++ ){
ls = read() , rs = read() ;
ls ++ , rs ++ ;
printf( "%lld\n" , sum( 1 , ls , rs ) ) ;
}
}
signed main(){
sc() ;
return 0 ;
}
\(Min\)
#include <bits/stdc++.h>
#define int long long
#define R register int
using namespace std ;
inline int read(){
int w = 0 ; char ch = getchar() ;
while( ch > '9' || ch < '0' ) ch = getchar() ;
while( ch >= '0' && ch <= '9' ){
w = ( w << 1 ) + ( w << 3 ) + ( ch - '0' ) ;
ch = getchar() ;
} return w ;
}
const int N = 1e5 + 10 ;
int n , m ; int c [N] ; int ls , rs ;
struct E{ int l , r ; int da ; } a [N << 2] ;
inline void ud( int x ){ a [x].da = min( a [x << 1].da , a [x << 1 | 1].da ) ; }
void build( int x , int l , int r ){
a [x].l = l , a [x].r = r ;
if( l == r ){ a [x].da = c [l] ; return ; }
int mid = ( l + r ) >> 1 ;
build( x << 1 , l , mid ) ;
build( x << 1 | 1 , mid + 1 , r ) ;
ud( x ) ;
}
inline int sum( int x , int l , int r ){
if( a [x].l >= l && a [x].r <= r ) return a [x].da ;
int mid = ( a [x].l + a [x].r ) >> 1 ;
int ans = 0x3f3f3f3f ;
if( l <= mid ) ans = min( ans , sum( x << 1 , l , r ) ) ;
if( r > mid ) ans = min( ans , sum( x << 1 | 1 , l , r ) ) ;
return ans ;
}
void sc(){
n = read() , m = read() ;
for( R i = 1 ; i <= n ; i ++ ) c [i] = read() ;
build( 1 , 1 , n ) ;
for( R i = 1 ; i <= m ; i ++ ){
ls = read() , rs = read() ;
printf( "%lld " , sum( 1 , ls , rs ) ) ;
}
}
signed main(){
sc() ;
return 0 ;
}
山海经
这个题的意思就是要求一个最大连续子段和

这个题思路不是很简单,代码也不是很好实现
我的思路主要就是建3棵线段树,一棵和左边合并用,一棵和右边合并用,一个存自己的最值
然后还开了很多来记录路径
这个题由于全程是自己做的,就不写那么详细了(好像第一次一点题解都没看做出来这么麻烦的题呢)
code
view code
#include<bits/stdc++.h>
#define int long long
#define R register int
using namespace std ;
inline int read(){
int w = 0 ; bool fg = 0 ; char ch = getchar() ;
while( ch < '0' || ch > '9' ) fg |= ( ch == '-' ) , ch = getchar() ;
while( ch >= '0' && ch <= '9' ) w = ( w << 1 ) + ( w << 3 ) + ( ch - '0' ) , ch = getchar() ;
return fg ? -w : w ;
}
const int N = 1e5 + 10 ;
const int INF = 0x3f3f3f3f ;
int n , m ;
int c [N] ;
struct E{
int l , r ;
int left , right , maxs , sum ;
int leftc , rightc ;
int maxscl , maxscr ;
} a [N << 2] , b [N << 2] ;
int cnt ;
bool cmp( E as , E bs ){
return as.l < bs.l ;
}
void debug(){
// for( R i = 1 ; i <= n * 4 ; i ++ )
// printf( "TEST %lld %lld %lld %lld %lld %lld %lld %lld %lld %lld\n" , a [i].l , a [i].r , a [i].left , a [i].leftc , a [i].right , a [i].rightc , a [i].maxs , a [i].maxscl , a [i].maxscr , a [i].sum ) ;
// for( R i = 1 ; i <= cnt ; i ++ )
// printf( "TEST %lld %lld %lld %lld %lld %lld %lld %lld %lld %lld\n" , b [i].l , b [i].r , b [i].left , b [i].leftc , b [i].right , b [i].rightc , b [i].maxs , b [i].maxscl , b [i].maxscr , b [i].sum ) ;
}
//8 1 5 -6 3 -1 8 -7 9 5 4 7
void build( int x , int l , int r ){
a [x].l = l , a [x].r = r ;
if( a [x].l == a [x].r ){
a [x].left = c [l] ;
a [x].right = c [l] ;
a [x].maxs = c [l] ;
a [x].sum = c [l] ;
a [x].leftc = l ;
a [x].rightc = l ;
a [x].maxscl = l ;
a [x].maxscr = l ;
return ;
}
int mid = ( a [x].l + a [x].r ) >> 1 ;
build( x << 1 , l , mid ) ;
build( x << 1 | 1 , mid + 1 , r ) ;
int ls = x << 1 , rs = x << 1 | 1 ;
a [x].sum = a [ls].sum + a [rs].sum ;
if( a [ls].left >= a [ls].sum + a [rs].left ) a [x].left = a [ls].left , a [x].leftc = a [ls].leftc ;
else a [x].left = a [ls].sum + a [rs].left , a [x].leftc = a [rs].leftc ;
//因为要尽量靠左边决策,所以相等时只取左边
if( a [rs].sum + a [ls].right >= a [rs].right ) a [x].right = a [rs].sum + a [ls].right , a [x].rightc = a [ls].rightc ;
else a [x].right = a [rs].right , a [x].rightc = a [rs].rightc ;
//因为要尽量靠左边决策,所以相等时要尽量左偏,选择左右混合
bool fg = 0 ; int num = a [rs].maxs ;
if( a [ls].maxs >= a [rs].maxs ) fg = 1 , num = a [ls].maxs ; //相等时靠左决策 ,1表示选择左边
if( num == a [ls].right + a [rs].left ){
a [x].maxs = num ;
if( fg ){
if( a [ls].maxscl == a [ls].rightc ){
a [x].maxscl = a [ls].maxscl ;
a [x].maxscr = min( a [ls].maxscr , a [rs].leftc ) ;
}
else if( a [ls].maxscl < a [ls].rightc ) a [x].maxscl = a [ls].maxscl , a [x].maxscr = a [ls].maxscr ;
else if( a [ls].maxscl > a [ls].rightc ) a [x].maxscl = a [ls].rightc , a [x].maxscr = a [rs].leftc ;
}
else a [x].maxscl = a [ls].rightc , a [x].maxscr = a [rs].leftc ;
}
else if( num < a [ls].right + a [rs].left ){
a [x].maxs = a [ls].right + a [rs].left ;
a [x].maxscl = a [ls].rightc , a [x].maxscr = a [rs].leftc ;
}
else if( num > a [ls].right + a [rs].left ){
a [x].maxs = num ;
if( fg ) a [x].maxscl = a [ls].maxscl , a [x].maxscr = a [ls].maxscr ;
else a [x].maxscl = a [rs].maxscl , a [x].maxscr = a [rs].maxscr ;
}
}
inline void finds( int x , int l , int r ){
if( a [x].l >= l && a [x].r <= r ) {
b [++ cnt] = a [x] ;
// printf( "%lld %lld\n" , a [x].l , a [x].r ) ;
E T ;
int ls = cnt - 1 , rs = cnt ;
T.sum = b [ls].sum + b [rs].sum ;
if( b [ls].left >= b [ls].sum + b [rs].left ) T.left = b [ls].left , T.leftc = b [ls].leftc ;
else T.left = b [ls].sum + b [rs].left , T.leftc = b [rs].leftc ;
//因为要尽量靠左边决策,所以相等时只取左边
if( b [rs].sum + b [ls].right >= b [rs].right ) T.right = b [rs].sum + b [ls].right , T.rightc = b [ls].rightc ;
else T.right = b [rs].right , T.rightc = b [rs].rightc ;
//因为要尽量靠左边决策,所以相等时要尽量左偏,选择左右混合
bool fg = 0 ; int num = b [rs].maxs ;
if( b [ls].maxs >= b [rs].maxs ) fg = 1 , num = b [ls].maxs ; //相等时靠左决策 ,1表示选择左边
if( num == b [ls].right + b [rs].left ){
T.maxs = num ;
if( fg ){
if( b [ls].maxscl == b [ls].rightc ){
T.maxscl = b [ls].maxscl ;
T.maxscr = min( b [ls].maxscr , b [rs].leftc ) ;
}
else if( b [ls].maxscl < b [ls].rightc ) T.maxscl = b [ls].maxscl , T.maxscr = b [ls].maxscr ;
else if( b [ls].maxscl > b [ls].rightc ) T.maxscl = b [ls].rightc , T.maxscr = b [rs].leftc ;
}
else T.maxscl = b [ls].rightc , T.maxscr = b [rs].leftc ;
}
else if( num < b [ls].right + b [rs].left ){
T.maxs = b [ls].right + b [rs].left ;
T.maxscl = b [ls].rightc , T.maxscr = b [rs].leftc ;
}
else if( num > b [ls].right + b [rs].left ){
T.maxs = num ;
if( fg ) T.maxscl = b [ls].maxscl , T.maxscr = b [ls].maxscr ;
else T.maxscl = b [rs].maxscl , T.maxscr = b [rs].maxscr ;
}
b [rs].l = b [ls].l ;
b [rs].left = T.left ;
b [rs].right = T.right ;
b [rs].leftc = T.leftc ;
b [rs].rightc = T.rightc ;
b [rs].maxs = T.maxs ;
b [rs].maxscl = T.maxscl ;
b [rs].maxscr = T.maxscr ;
b [rs].sum = T.sum ;
return ;
}
int mid = ( a [x].l + a [x].r ) >> 1 ;
if( l <= mid ) finds( x << 1 , l , r ) ;
if( r > mid ) finds( x << 1 | 1 , l , r ) ;
}
inline void Ans( int l , int r ){
cnt = 0 ;
finds( 1 , l , r ) ;
// sort( b + 1 , b + 1 + cnt , cmp ) ;
// debug() ;
printf( "%lld %lld %lld\n" , b [cnt].maxscl , b [cnt].maxscr , b [cnt].maxs ) ;
// printf( "%lld\n" , 0 ) ;
}
void sc(){
n = read() , m = read() ;
for( R i = 1 ; i <= n ; i ++ ) c [i] = read() ;
build( 1 , 1 , n ) ; //debug() ;
b [0].left = b [0].right = b [0].maxs = -INF ;
int lam , ram ;
for( R i = 1 ; i <= m ; i ++ ) lam = read() , ram = read() , Ans( lam , ram ) ;
}
signed main(){
sc() ;
return 0 ;
}
貌似还有很多操作,所以说待进阶,之后学了再更新吧(挖坑带师上线)

浙公网安备 33010602011771号