线段树
线段树是一种维护区间性质的用数组模拟的数据结构,每个节点存储的是一个区间的结果,以其\(nlogn\)级的复杂度有着相当大的意义。
这里以区间加以及区间乘操作和区间和查询为例。
P3372 【模板】线段树 1
P3373 【模板】线段树 2
建树
无论怎么样,首先要进行的是建树。
#define ll long long
#define ls (now<<1)//左儿子
#define rs ((now<<1)|1)//右儿子
ll ans[N<<2],a[N],laz[N<<2],lax[N<<2];
ll mod;
void build(int now,int l,int r) {
if(l==r) {
ans[now]=a[l]%mod;
laz[now]=0;
lax[now]=1;
return;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
ans[now]=ans[ls]+ans[rs];
ans[now]%=mod;
laz[now]=0;
lax[now]=1;
return;
}
注意这里树开的是4倍空间,这主要是线段树的结构决定的。
对于满二叉树来说,树的高度大致为\(\lceil log_2 n\rceil\)
当我们有\(2^x\)个节点时,所需要的节点数是
\[\sum_{i-1}^{x}{2^i}=2^{x+1}-1
\]
只有二倍空间,但对于其他情况来说,从树高来计算,我们的底层节点数量理论为\(2^{\lfloor log_2 {n}\rfloor +1}\)个,对应的整棵树的节点也约为\(2^{\lfloor log_2 n \rfloor +1}*2\)个,也就是4倍的节点数目,所以应该开四倍空间。
区间修改
树建好之后,我们就进入了区间修改的情况了。
push_down的懒标记优化
对于每个次修改的区间,如果都要一直修改到具体的单点,每次修改的时间复杂度就是\(O(nlogn)\)了,这与我们初始设想的线段树的优秀修改性质相悖,所以我们需要提出一个\(lazy\text{_}tag\)来帮助性能优化。
单纯考虑加法
例如,我们在每一条线段加一个懒标记,表明这是需要加上但是还没进行操作的对应值,于是在每次进行区间修改时,我们只需要修改该线段对应的值,至于这条线段中具体每个点的情况,等到需要的时候再考虑。
void push_down(int now,ll l,ll r) {
ll mid=(l+r)>>1;
ans[ls]+=(mid-l+1ll)*laz[now];
laz[ls]+=laz[now];
ans[rs]+=(r-(mid+1)+1)*laz[now];
laz[rs]+=laz[now];
laz[now]=0;
}
而对于既有加法又有乘法的情况来说则要复杂得多,所以我们采用两个懒标记来进行实现。由于加法和乘法有优先级的差别,所以为了方便,我们先计算乘法,再计算加法。
乘法
void push_down(int now,ll l,ll r) {
ll mid=(l+r)>>1;
if(lax[now]!=1) {
ans[ls]=(ans[ls]*lax[now])%mod;
laz[ls]=(laz[ls]*lax[now])%mod;
lax[ls]=(lax[ls]*lax[now])%mod;
ans[rs]=(ans[rs]*lax[now])%mod;
laz[rs]=(laz[rs]*lax[now])%mod;
lax[rs]=(lax[rs]*lax[now])%mod;//*
}
if(laz[now]!=0) {
ans[ls]=(ans[ls]+(mid-l+1)*laz[now])%mod;
laz[ls]=(laz[ls]+laz[now])%mod;
ans[rs]=(ans[rs]+(r-mid)*laz[now])%mod;
laz[rs]=(laz[rs]+laz[now])%mod;//+
}
laz[now]=0ll;
lax[now]=1ll;
}
update的操作
单纯考虑加法
void update(int now,ll l,ll r,ll x,ll y,ll k) {
// printf("[%lld,%lld]\n",l,r);
if(x<=l&&r<=y) {
ans[now]+=k*(r-l+1);
laz[now]+=k;//下一级别需要加
// printf("[%lld,%lld]=%lld\n",l,r,ans[now]);
return;
}
push_down(now,l,r);
ll mid=(l+r)>>1;
if(x<=mid) {
update(ls,l,mid,x,y,k);
}
if(y>mid) {
update(rs,mid+1,r,x,y,k);
}
ans[now]=ans[ls]+ans[rs];
// printf("[%lld,%lld]==%lld\n",l,r,ans[now]);
return;
}
乘法的考虑
void update(int now,ll l,ll r,ll x,ll y,ll k,int pd) {
// printf("[%lld,%lld]\n",l,r);
if(x<=l&&r<=y) {
if(pd==2) {//+
ans[now]=(ans[now]+k*(r-l+1))%mod;
laz[now]=(laz[now]+k)%mod;//下一级别需要加
// printf("[%lld,%lld]=%lld\n",l,r,ans[now]);
}else {
ans[now]=(ans[now]*k)%mod;
// laz[now]=(laz[now]*k)%mod;++
lax[now]=(lax[now]*k)%mod;
laz[now]=(laz[now]*k)%mod;
}
return;
}
push_down(now,l,r);
ll mid=(l+r)>>1;
if(x<=mid) {
update(ls,l,mid,x,y,k,pd);
}
if(y>mid) {
update(rs,mid+1,r,x,y,k,pd);
}
ans[now]=(ans[ls]+ans[rs])%mod;
// printf("[%lld,%lld]==%lld\n",l,r,ans[now]);
return;
}
查询
查询相比之下则要简单地多。
加法
ll query(int now,ll l,ll r,ll x,ll y) {
ll ret=0;
if(x<=l&&r<=y) {
return ans[now];
}
ll mid=(l+r)>>1;
push_down(now,l,r);
if(x<=mid)
ret+=query(ls,l,mid,x,y);
if(y>mid)
ret+=query(rs,mid+1,r,x,y);
return ret;
}
加与乘
ll query(int now,ll l,ll r,ll x,ll y) {
ll ret=0;
if(x<=l&&r<=y) {
return ans[now];
}
ll mid=(l+r)>>1;
push_down(now,l,r);
if(x<=mid)
ret+=query(ls,l,mid,x,y);
if(y>mid)
ret+=query(rs,mid+1,r,x,y);
return ret%mod;
}
最后附上两道题的全部代码
加法
点击查看代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<string>
#include<map>
#include<cstdlib>
#include<set>
#include<vector>
#include<cmath>
#include<iostream>
using namespace std;
#define endl '\n'
const int N=2e5+10;
const int mod=1e9+7;
#define ll long long
#define ull unsigned long long
#define ls (now<<1)//左儿子
#define rs ((now<<1)|1)//右儿子
ll ans[N<<2],a[N],laz[N<<2];
void build(int now,int l,int r) {
if(l==r) {
ans[now]=a[l];
return;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
ans[now]=ans[ls]+ans[rs];
laz[now]=0;
return;
}
void push_down(int now,ll l,ll r) {
ll mid=(l+r)>>1;
ans[ls]+=(mid-l+1ll)*laz[now];
laz[ls]+=laz[now];
ans[rs]+=(r-(mid+1)+1)*laz[now];
laz[rs]+=laz[now];
laz[now]=0;
}
void update(int now,ll l,ll r,ll x,ll y,ll k) {
// printf("[%lld,%lld]\n",l,r);
if(x<=l&&r<=y) {
ans[now]+=k*(r-l+1);
laz[now]+=k;//下一级别需要加
// printf("[%lld,%lld]=%lld\n",l,r,ans[now]);
return;
}
push_down(now,l,r);
ll mid=(l+r)>>1;
if(x<=mid) {
update(ls,l,mid,x,y,k);
}
if(y>mid) {
update(rs,mid+1,r,x,y,k);
}
ans[now]=ans[ls]+ans[rs];
// printf("[%lld,%lld]==%lld\n",l,r,ans[now]);
return;
}
ll query(int now,ll l,ll r,ll x,ll y) {
ll ret=0;
if(x<=l&&r<=y) {
return ans[now];
}
ll mid=(l+r)>>1;
push_down(now,l,r);
if(x<=mid)
ret+=query(ls,l,mid,x,y);
if(y>mid)
ret+=query(rs,mid+1,r,x,y);
return ret;
}
void solve() {
int n,m;
cin>>n>>m;
int i;
for(i=1;i<=n;++i) {
cin>>a[i];
}
build(1,1,n);
while(m--) {
int pd;
cin>>pd;
switch(pd) {
ll x,y,k;
case 1:
cin>>x>>y>>k;
update(1,1,n,x,y,k);
break;
case 2:
cin>>x>>y;
cout<<query(1,1,n,x,y)<<endl;
break;
}
}
return;
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int t=1;
// cin>>t;
while(t--)
solve();
return 0;
}
加与乘
点击查看代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<string>
#include<map>
#include<cstdlib>
#include<set>
#include<vector>
#include<cmath>
#include<iostream>
using namespace std;
#define endl '\n'
const int N=2e5+10;
// const int mod=1e9+7;
#define ll long long
#define ull unsigned long long
#define ls (now<<1)//左儿子
#define rs ((now<<1)|1)//右儿子
ll ans[N<<2],a[N],laz[N<<2],lax[N<<2];
ll mod;
void build(int now,int l,int r) {
if(l==r) {
ans[now]=a[l]%mod;
laz[now]=0;
lax[now]=1;
return;
}
int mid=(l+r)>>1;
build(ls,l,mid);
build(rs,mid+1,r);
ans[now]=ans[ls]+ans[rs];
ans[now]%=mod;
laz[now]=0;
lax[now]=1;
return;
}
void push_down(int now,ll l,ll r) {
ll mid=(l+r)>>1;
ans[ls]=(ans[ls]*lax[now])%mod;
laz[ls]=(laz[ls]*lax[now])%mod;
lax[ls]=(lax[ls]*lax[now])%mod;
ans[rs]=(ans[rs]*lax[now])%mod;
laz[rs]=(laz[rs]*lax[now])%mod;
lax[rs]=(lax[rs]*lax[now])%mod;//*
ans[ls]=(ans[ls]+(mid-l+1)*laz[now])%mod;
laz[ls]=(laz[ls]+laz[now])%mod;
ans[rs]=(ans[rs]+(r-mid)*laz[now])%mod;
laz[rs]=(laz[rs]+laz[now])%mod;//+
laz[now]=0ll;
lax[now]=1ll;
}
void update(int now,ll l,ll r,ll x,ll y,ll k,int pd) {
// printf("[%lld,%lld]\n",l,r);
if(x<=l&&r<=y) {
if(pd==2) {//+
ans[now]=(ans[now]+k*(r-l+1))%mod;
laz[now]=(laz[now]+k)%mod;//下一级别需要加
// printf("[%lld,%lld]=%lld\n",l,r,ans[now]);
}else {
ans[now]=(ans[now]*k)%mod;
// laz[now]=(laz[now]*k)%mod;++
lax[now]=(lax[now]*k)%mod;
laz[now]=(laz[now]*k)%mod;
}
return;
}
push_down(now,l,r);
ll mid=(l+r)>>1;
if(x<=mid) {
update(ls,l,mid,x,y,k,pd);
}
if(y>mid) {
update(rs,mid+1,r,x,y,k,pd);
}
ans[now]=(ans[ls]+ans[rs])%mod;
// printf("[%lld,%lld]==%lld\n",l,r,ans[now]);
return;
}
ll query(int now,ll l,ll r,ll x,ll y) {
ll ret=0;
if(x<=l&&r<=y) {
return ans[now];
}
ll mid=(l+r)>>1;
push_down(now,l,r);
if(x<=mid)
ret+=query(ls,l,mid,x,y);
if(y>mid)
ret+=query(rs,mid+1,r,x,y);
return ret%mod;
}
void solve() {
int n,m;
cin>>n>>m>>mod;
int i;
for(i=1;i<=n;++i) {
cin>>a[i];
}
build(1,1,n);
while(m--) {
int pd;
cin>>pd;
switch(pd) {
ll x,y,k;
case 1:
cin>>x>>y>>k;
update(1,1,n,x,y,k,1);
break;
case 2:
cin>>x>>y>>k;
update(1,1,n,x,y,k,2);
break;
case 3:
cin>>x>>y;
cout<<query(1,1,n,x,y)<<endl;
break;
}
// for(ll j=1;j<=n;++j) {
// printf("[%lld] ",query(1,1,n,j,j));
// }
// puts("");
}
return;
}
int main() {
ios_base::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
int t=1;
// cin>>t;
while(t--)
solve();
return 0;
}

浙公网安备 33010602011771号