线段树

线段树是一种维护区间性质的用数组模拟的数据结构,每个节点存储的是一个区间的结果,以其\(nlogn\)级的复杂度有着相当大的意义。
这里以区间加以及区间乘操作和区间和查询为例。
P3372 【模板】线段树 1
P3373 【模板】线段树 2

建树

无论怎么样,首先要进行的是建树。

#define ll long long
#define ls (now<<1)//左儿子
#define rs ((now<<1)|1)//右儿子
ll ans[N<<2],a[N],laz[N<<2],lax[N<<2];
ll mod;
void build(int now,int l,int r) {
    if(l==r) {
        ans[now]=a[l]%mod;
        laz[now]=0;
        lax[now]=1;
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    ans[now]=ans[ls]+ans[rs];
    ans[now]%=mod;
    laz[now]=0;
    lax[now]=1;
    return;
}

注意这里树开的是4倍空间,这主要是线段树的结构决定的。
对于满二叉树来说,树的高度大致为\(\lceil log_2 n\rceil\)
当我们有\(2^x\)个节点时,所需要的节点数是

\[\sum_{i-1}^{x}{2^i}=2^{x+1}-1 \]

只有二倍空间,但对于其他情况来说,从树高来计算,我们的底层节点数量理论为\(2^{\lfloor log_2 {n}\rfloor +1}\)个,对应的整棵树的节点也约为\(2^{\lfloor log_2 n \rfloor +1}*2\)个,也就是4倍的节点数目,所以应该开四倍空间。

区间修改

树建好之后,我们就进入了区间修改的情况了。

push_down的懒标记优化

对于每个次修改的区间,如果都要一直修改到具体的单点,每次修改的时间复杂度就是\(O(nlogn)\)了,这与我们初始设想的线段树的优秀修改性质相悖,所以我们需要提出一个\(lazy\text{_}tag\)来帮助性能优化。

单纯考虑加法

例如,我们在每一条线段加一个懒标记,表明这是需要加上但是还没进行操作的对应值,于是在每次进行区间修改时,我们只需要修改该线段对应的值,至于这条线段中具体每个点的情况,等到需要的时候再考虑。

void push_down(int now,ll l,ll r) {
    ll mid=(l+r)>>1;

    ans[ls]+=(mid-l+1ll)*laz[now];
    laz[ls]+=laz[now];

    ans[rs]+=(r-(mid+1)+1)*laz[now];
    laz[rs]+=laz[now];

    laz[now]=0;
}

而对于既有加法又有乘法的情况来说则要复杂得多,所以我们采用两个懒标记来进行实现。由于加法和乘法有优先级的差别,所以为了方便,我们先计算乘法,再计算加法。

乘法

void push_down(int now,ll l,ll r) {
    ll mid=(l+r)>>1;
	if(lax[now]!=1) {
    	ans[ls]=(ans[ls]*lax[now])%mod;
    	laz[ls]=(laz[ls]*lax[now])%mod;
    	lax[ls]=(lax[ls]*lax[now])%mod;

    	ans[rs]=(ans[rs]*lax[now])%mod;
    	laz[rs]=(laz[rs]*lax[now])%mod;
    	lax[rs]=(lax[rs]*lax[now])%mod;//*
	}
	if(laz[now]!=0) {
    	ans[ls]=(ans[ls]+(mid-l+1)*laz[now])%mod;
    	laz[ls]=(laz[ls]+laz[now])%mod;

    	ans[rs]=(ans[rs]+(r-mid)*laz[now])%mod;
    	laz[rs]=(laz[rs]+laz[now])%mod;//+
	}
    laz[now]=0ll;
    lax[now]=1ll;
}

update的操作

单纯考虑加法

void update(int now,ll l,ll r,ll x,ll y,ll k) {
    // printf("[%lld,%lld]\n",l,r);
    if(x<=l&&r<=y) {
        ans[now]+=k*(r-l+1);
        laz[now]+=k;//下一级别需要加
        // printf("[%lld,%lld]=%lld\n",l,r,ans[now]);
        return;
    }
    push_down(now,l,r);
    ll mid=(l+r)>>1;
    if(x<=mid) {
        update(ls,l,mid,x,y,k);
    }
    if(y>mid) {
        update(rs,mid+1,r,x,y,k);
    }
    ans[now]=ans[ls]+ans[rs];
    // printf("[%lld,%lld]==%lld\n",l,r,ans[now]);
    return;
}

乘法的考虑

void update(int now,ll l,ll r,ll x,ll y,ll k,int pd) {
    // printf("[%lld,%lld]\n",l,r);
    if(x<=l&&r<=y) {
        if(pd==2) {//+
            ans[now]=(ans[now]+k*(r-l+1))%mod;
            laz[now]=(laz[now]+k)%mod;//下一级别需要加
        // printf("[%lld,%lld]=%lld\n",l,r,ans[now]);
        }else {
            ans[now]=(ans[now]*k)%mod;
            // laz[now]=(laz[now]*k)%mod;++
            lax[now]=(lax[now]*k)%mod;
            laz[now]=(laz[now]*k)%mod;
        }
        return;
    }
    push_down(now,l,r);
    ll mid=(l+r)>>1;
    if(x<=mid) {
        update(ls,l,mid,x,y,k,pd);
    }
    if(y>mid) {
        update(rs,mid+1,r,x,y,k,pd);
    }
    ans[now]=(ans[ls]+ans[rs])%mod;
    // printf("[%lld,%lld]==%lld\n",l,r,ans[now]);
    return;
}

查询

查询相比之下则要简单地多。

加法

ll query(int now,ll l,ll r,ll x,ll y) {
    ll ret=0;
    if(x<=l&&r<=y) {
        return ans[now];
    }
    ll mid=(l+r)>>1;
    push_down(now,l,r);
    if(x<=mid)
        ret+=query(ls,l,mid,x,y);
    if(y>mid)
        ret+=query(rs,mid+1,r,x,y);
    return ret;
}

加与乘

ll query(int now,ll l,ll r,ll x,ll y) {
    ll ret=0;
    if(x<=l&&r<=y) {
        return ans[now];
    }
    ll mid=(l+r)>>1;
    push_down(now,l,r);
    if(x<=mid)
        ret+=query(ls,l,mid,x,y);
    if(y>mid)
        ret+=query(rs,mid+1,r,x,y);
    return ret%mod;
}

最后附上两道题的全部代码

加法

点击查看代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<string>
#include<map>
#include<cstdlib>
#include<set>
#include<vector>
#include<cmath>
#include<iostream>
using namespace std;
#define endl '\n'
const int N=2e5+10;
const int mod=1e9+7;
#define ll long long 
#define ull unsigned long long
#define ls (now<<1)//左儿子
#define rs ((now<<1)|1)//右儿子
ll ans[N<<2],a[N],laz[N<<2];
void build(int now,int l,int r) {
    if(l==r) {
        ans[now]=a[l];
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    ans[now]=ans[ls]+ans[rs];
    laz[now]=0;
    return;
}
void push_down(int now,ll l,ll r) {
    ll mid=(l+r)>>1;

    ans[ls]+=(mid-l+1ll)*laz[now];
    laz[ls]+=laz[now];

    ans[rs]+=(r-(mid+1)+1)*laz[now];
    laz[rs]+=laz[now];

    laz[now]=0;
}
void update(int now,ll l,ll r,ll x,ll y,ll k) {
    // printf("[%lld,%lld]\n",l,r);
    if(x<=l&&r<=y) {
        ans[now]+=k*(r-l+1);
        laz[now]+=k;//下一级别需要加
        // printf("[%lld,%lld]=%lld\n",l,r,ans[now]);
        return;
    }
    push_down(now,l,r);
    ll mid=(l+r)>>1;
    if(x<=mid) {
        update(ls,l,mid,x,y,k);
    }
    if(y>mid) {
        update(rs,mid+1,r,x,y,k);
    }
    ans[now]=ans[ls]+ans[rs];
    // printf("[%lld,%lld]==%lld\n",l,r,ans[now]);
    return;
}
ll query(int now,ll l,ll r,ll x,ll y) {
    ll ret=0;
    if(x<=l&&r<=y) {
        return ans[now];
    }
    ll mid=(l+r)>>1;
    push_down(now,l,r);
    if(x<=mid)
        ret+=query(ls,l,mid,x,y);
    if(y>mid)
        ret+=query(rs,mid+1,r,x,y);
    return ret;
}
void solve() {
    int n,m;
    cin>>n>>m;
    int i;
    for(i=1;i<=n;++i) {
        cin>>a[i];
    }
    build(1,1,n);
    while(m--) {
        int pd;
        cin>>pd;
        switch(pd) {
            ll x,y,k;
            case 1:
                cin>>x>>y>>k;
                update(1,1,n,x,y,k);
                break;
            case 2:
                cin>>x>>y;
                cout<<query(1,1,n,x,y)<<endl;
                break;
        }
    }
    return;
}
int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int t=1;
    // cin>>t;
    while(t--)
        solve();
    return 0;
}

加与乘

点击查看代码
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<string>
#include<map>
#include<cstdlib>
#include<set>
#include<vector>
#include<cmath>
#include<iostream>
using namespace std;
#define endl '\n'
const int N=2e5+10;
// const int mod=1e9+7;
#define ll long long 
#define ull unsigned long long
#define ls (now<<1)//左儿子
#define rs ((now<<1)|1)//右儿子
ll ans[N<<2],a[N],laz[N<<2],lax[N<<2];
ll mod;
void build(int now,int l,int r) {
    if(l==r) {
        ans[now]=a[l]%mod;
        laz[now]=0;
        lax[now]=1;
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    ans[now]=ans[ls]+ans[rs];
    ans[now]%=mod;
    laz[now]=0;
    lax[now]=1;
    return;
}
void push_down(int now,ll l,ll r) {
    ll mid=(l+r)>>1;

    ans[ls]=(ans[ls]*lax[now])%mod;
    laz[ls]=(laz[ls]*lax[now])%mod;
    lax[ls]=(lax[ls]*lax[now])%mod;

    ans[rs]=(ans[rs]*lax[now])%mod;
    laz[rs]=(laz[rs]*lax[now])%mod;
    lax[rs]=(lax[rs]*lax[now])%mod;//*

    ans[ls]=(ans[ls]+(mid-l+1)*laz[now])%mod;
    laz[ls]=(laz[ls]+laz[now])%mod;

    ans[rs]=(ans[rs]+(r-mid)*laz[now])%mod;
    laz[rs]=(laz[rs]+laz[now])%mod;//+

    laz[now]=0ll;
    lax[now]=1ll;
}
void update(int now,ll l,ll r,ll x,ll y,ll k,int pd) {
    // printf("[%lld,%lld]\n",l,r);
    if(x<=l&&r<=y) {
        if(pd==2) {//+
            ans[now]=(ans[now]+k*(r-l+1))%mod;
            laz[now]=(laz[now]+k)%mod;//下一级别需要加
        // printf("[%lld,%lld]=%lld\n",l,r,ans[now]);
        }else {
            ans[now]=(ans[now]*k)%mod;
            // laz[now]=(laz[now]*k)%mod;++
            lax[now]=(lax[now]*k)%mod;
            laz[now]=(laz[now]*k)%mod;
        }
        return;
    }
    push_down(now,l,r);
    ll mid=(l+r)>>1;
    if(x<=mid) {
        update(ls,l,mid,x,y,k,pd);
    }
    if(y>mid) {
        update(rs,mid+1,r,x,y,k,pd);
    }
    ans[now]=(ans[ls]+ans[rs])%mod;
    // printf("[%lld,%lld]==%lld\n",l,r,ans[now]);
    return;
}
ll query(int now,ll l,ll r,ll x,ll y) {
    ll ret=0;
    if(x<=l&&r<=y) {
        return ans[now];
    }
    ll mid=(l+r)>>1;
    push_down(now,l,r);
    if(x<=mid)
        ret+=query(ls,l,mid,x,y);
    if(y>mid)
        ret+=query(rs,mid+1,r,x,y);
    return ret%mod;
}
void solve() {
    int n,m;
    cin>>n>>m>>mod;
    int i;
    for(i=1;i<=n;++i) {
        cin>>a[i];
    }
    build(1,1,n);
    while(m--) {
        int pd;
        cin>>pd;
        switch(pd) {
            ll x,y,k;
            case 1:
                cin>>x>>y>>k;
                update(1,1,n,x,y,k,1);
                break;
            case 2:
                cin>>x>>y>>k;
                update(1,1,n,x,y,k,2);
                break;
            case 3:
                cin>>x>>y;
                cout<<query(1,1,n,x,y)<<endl;
                break;
        }
        // for(ll j=1;j<=n;++j) {
        //     printf("[%lld] ",query(1,1,n,j,j));
        // }
        // puts("");
    }
    return;
}
int main() {
    ios_base::sync_with_stdio(0);
    cin.tie(0);
    cout.tie(0);
    int t=1;
    // cin>>t;
    while(t--)
        solve();
    return 0;
}
posted @ 2024-05-09 18:12  WE-R  阅读(16)  评论(0)    收藏  举报