P6242 【模板】线段树 3
P6242 【模板】线段树 3
分析
一看到要求,历史最大值。我们就要来看看操作,看看能不能将所有操作化成一个操作。这样就好操作一些。
总共两个操作,区间加,区间取min。
区间加好说,我们来看看如何解决区间取min
操作。
我们其实无法做到快速的去区间取min
,暴力的去求,我们只能递归到叶节点才能更新。
但是这样我们的时间复杂度,最坏情况下,单次修改将会是\(O(nlogn)\)的。
我们考虑换个思路,将区间内大于v
的数都减去一个数,将这些数都变为v
。
因为不同的数要减去不同的数才能等于v
,显然我们不能设很多的tag
来表示区间内不同的数要减去的数。
那我们退一步来说,如果区间内大于v
的数只有一种,我们可以怎么做呢?
只需要维护一个tag
就好了,我们只用知道这一个数需要减的值就好了。
那怎样才能使得区间内大于v
的数只有一种呢?
只需要多递归几次就行了。
我们在线段树每个节点设以下三个变量,分别为:
mx
:该区间的最大值se
:该区间的严格次大值cnt
:该区间的最大值的个数
因为只有在区间只有一种数大于k
的时候我们才能快速更新,即只有在满足\(se<k<mx\)的节点上更新。
具体的,我们在进行操作2时,会有以下3
种情况。
- \(mx\leq v\):说明这个区间的最大值小于等于v,直接返回即可
- \(se<v<mx\):说明这个区间的最大值会全部被修改为
v
,但其他数和最大值的个数不变。修改后打上标记返回。 - \(v\leq se\):无法更新,接着往下递归。
时间复杂度\(O(mlog^2n)\)
接下来,我们就可以来解决问题啦。
代码实现
因为题目比较复杂,我们把每一个部分都拆解出来。
我们先来看看,我们需要维护的值有哪些。
sum
:区间的和mx
:区间最值hismx
:区间历史最值se
:区间严格次大值cnt
:区间最值个数
然后来看看我们配套懒标记。
mxadd
:区间最大值的区间加标记add
:区间除最大值外其余的值的区间加标记hismxadd
:为了算历史最值,我们需要统计一下mxadd
的最大值。hisadd
:为了算历史最值,我们需要统计一下add
的最大值。
这些懒标记,还算是比较好理解。如果不理解,可以看看接下来关于代码部分的解释。
我们只讲modifyadd
,modifymin
,pushdown
。其余部分只有一些细节需要关注,大部分内容都和普通线段树无异。
pushdown
我们直接来讲重头戏pushdown
。
我们来思考一下,我们的四个标记对子区间的所有维护的元素(包括维护的值与配套的懒标记)的影响。
-
mxadd
它会影响,子区间的
sum
,mx
,mxadd
。 -
add
它会影响,子区间的
sum
,se
,add
。 -
hismxadd
它会影响,子区间的
hismx
,hismxadd
。 -
hisadd
它会影响,子区间的
hisadd
。
具体怎样操作,我会在代码中加上适当的注释帮助理解。
void change(Node &u,int mxadd,int add,int hismxadd,int hisadd)
{
u.sum += 1ll*mxadd*u.cnt + 1ll*add*(u.r - u.l + 1 - u.cnt);
u.hismx = max(u.hismx,u.mx + hismxadd);//历史最值更新是,用当前的最值+(mxadd的最大值)
u.mx += mxadd;//区间最值变化为u.mx+mxadd
if(u.se!=-inf) u.se += add;//若存在次小值,区间次大值变为u.se + add
u.hismxadd = max(u.hismxadd,u.mxadd + hismxadd);//hismxadd用当前区间此时的u.mxadd+hismxadd来更新
u.hisadd = max(u.hisadd,u.add + hisadd);//hisadd用当前区间此时的u.add+hisadd来更新
u.mxadd += mxadd,u.add += add;//接下来,更新当前区间此时的u.mxadd与u.add
}
void pushdown(int u)
{
auto &root = tr[u],&left = tr[u<<1],&right = tr[u<<1|1];
if(root.mxadd||root.add||root.hismxadd||root.hisadd){
int mx = max(left.mx,right.mx);
//如果子区间的最值等于此时区间的最值,则用来更新的子区间的mxadd标记就用root.mxadd,hismxadd就用root.hismxadd
//否则就,用root.add与root.hisadd。
if(left.mx==mx) change(left,root.mxadd,root.add,root.hismxadd,root.hisadd);
else change(left,root.add,root.add,root.hisadd,root.hisadd);
if(right.mx==mx) change(right,root.mxadd,root.add,root.hismxadd,root.hisadd);
else change(right,root.add,root.add,root.hisadd,root.hisadd);
root.mxadd = root.add = root.hismxadd = root.hisadd = 0;
}
}
modify
其中分为两部分,直接看代码中的注释吧。
void modfiyadd(int u,int l,int r,int k)
{
if(tr[u].l>r||tr[u].r<l) return ;
//更新当前区间,更新标记为mxadd = k,add = k,hismxadd = k,hisadd = k;
if(l<=tr[u].l&&tr[u].r<=r) return change(tr[u],k,k,k,k);
pushdown(u);
modfiyadd(u<<1,l,r,k),modfiyadd(u<<1|1,l,r,k);
pushup(u);
}
void modfiymin(int u,int l,int r,int k)
{
if(l>tr[u].r||tr[u].l>r||k>=tr[u].mx) return ;//若此时区间的最大值已经小于等于k,直接返回
//只有当满足k>tr[u].se时才更新
//更新当前区间,更新标记为mxadd = k-tr[u].mx,add = 0,hismxadd = k-tr[u].mx,hisadd = 0;
//这里只对最大值操作,因为,只有当se<k<mx时,我们才对区间内所有的最大值操作,让他减去(tr[u].mx-k)
if(l<=tr[u].l&&tr[u].r<=r&&tr[u].se<k) return change(tr[u],k-tr[u].mx,0,k-tr[u].mx,0);
pushdown(u);
modfiymin(u<<1,l,r,k),modfiymin(u<<1|1,l,r,k);
pushup(u);
}
细节问题
- 注意,求和的时候记得开
long long
。 - 然后就是注意,
modifymin
的修改条件。
最后附上完全体代码
AC_code
#include<bits/stdc++.h>
#define ios ios::sync_with_stdio(false); cin.tie(0), cout.tie(0)
using namespace std;
using ll = long long;
const int N = 5e5 + 10,inf = 2e9;
/**
* @brief
* sum 区间和
* mx 区间最大值,se区间严格次大值,cnt区间最大值个数,hismx区间历史最大值
* mxadd 区间最大值的懒标记 add 区间非最大值的懒标记 hismxadd 区间最大值的懒标记的历史最大值 hisadd 区间非最大值的懒标记的历史最大值
*/
struct Node
{
int l,r;
ll sum;
int mx,se,cnt,hismx;
int mxadd,add,hismxadd,hisadd;
}tr[N<<2];
void pushup(int u)
{
auto &root = tr[u],&left = tr[u<<1],&right = tr[u<<1|1];
root.sum = left.sum + right.sum;
root.mx = max(left.mx,right.mx);
root.hismx = max(left.hismx,right.hismx);
if(right.mx==left.mx)
{
root.se = max(left.se,right.se);
root.cnt = left.cnt + right.cnt;
}
else if(left.mx>right.mx)
{
root.se = max(left.se,right.mx);
root.cnt = left.cnt;
}
else
{
root.se = max(left.mx,right.se);
root.cnt = right.cnt;
}
}
void build(int u,int l,int r)
{
tr[u] = {l,r};
if(l==r)
{
int x;cin>>x;
tr[u].sum = tr[u].mx = tr[u].hismx = x;
tr[u].cnt = 1;tr[u].se = -inf;
return ;
}
int mid = l + r >> 1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
pushup(u);
}
void change(Node &u,int mxadd,int add,int hismxadd,int hisadd)
{
u.sum += 1ll*mxadd*u.cnt + 1ll*add*(u.r - u.l + 1 - u.cnt);
u.hismx = max(u.hismx,u.mx + hismxadd);//历史最值更新是,用当前的最值+(mxadd的最大值)
u.mx += mxadd;//区间最值变化为u.mx+mxadd
if(u.se!=-inf) u.se += add;//若存在次小值,区间次大值变为u.se + add
u.hismxadd = max(u.hismxadd,u.mxadd + hismxadd);//hismxadd用当前区间此时的u.mxadd+hismxadd来更新
u.hisadd = max(u.hisadd,u.add + hisadd);//hisadd用当前区间此时的u.add+hisadd来更新
u.mxadd += mxadd,u.add += add;//接下来,更新当前区间此时的u.mxadd与u.add
}
void pushdown(int u)
{
auto &root = tr[u],&left = tr[u<<1],&right = tr[u<<1|1];
if(root.mxadd||root.add||root.hismxadd||root.hisadd){
int mx = max(left.mx,right.mx);
//如果子区间的最值等于此时区间的最值,则用来更新的子区间的mxadd标记就用root.mxadd,hismxadd就用root.hismxadd
//否则就,用root.add与root.hisadd。
if(left.mx==mx) change(left,root.mxadd,root.add,root.hismxadd,root.hisadd);
else change(left,root.add,root.add,root.hisadd,root.hisadd);
if(right.mx==mx) change(right,root.mxadd,root.add,root.hismxadd,root.hisadd);
else change(right,root.add,root.add,root.hisadd,root.hisadd);
root.mxadd = root.add = root.hismxadd = root.hisadd = 0;
}
}
void modfiyadd(int u,int l,int r,int k)
{
if(tr[u].l>r||tr[u].r<l) return ;
//更新当前区间,更新标记为mxadd = k,add = k,hismxadd = k,hisadd = k;
if(l<=tr[u].l&&tr[u].r<=r) return change(tr[u],k,k,k,k);
pushdown(u);
modfiyadd(u<<1,l,r,k),modfiyadd(u<<1|1,l,r,k);
pushup(u);
}
void modfiymin(int u,int l,int r,int k)
{
if(l>tr[u].r||tr[u].l>r||k>=tr[u].mx) return ;//若此时区间的最大值已经小于等于k,直接返回
//只有当满足k>tr[u].se时才更新
//更新当前区间,更新标记为mxadd = k-tr[u].mx,add = 0,hismxadd = k-tr[u].mx,hisadd = 0;
//这里只对最大值操作,因为,只有当se<k<mx时,我们才对区间内所有的最大值操作,让他减去(tr[u].mx-k)
if(l<=tr[u].l&&tr[u].r<=r&&tr[u].se<k) return change(tr[u],k-tr[u].mx,0,k-tr[u].mx,0);
pushdown(u);
modfiymin(u<<1,l,r,k),modfiymin(u<<1|1,l,r,k);
pushup(u);
}
ll querysum(int u,int l,int r)
{
if(tr[u].l>r||tr[u].r<l) return 0;
if(l<=tr[u].l&&tr[u].r<=r) return tr[u].sum;
pushdown(u);
return querysum(u<<1,l,r) + querysum(u<<1|1,l,r);
}
int querymx(int u,int l,int r)
{
if(tr[u].l>r||tr[u].r<l) return -inf;
if(l<=tr[u].l&&tr[u].r<=r) return tr[u].mx;
pushdown(u);
return max(querymx(u<<1,l,r), querymx(u<<1|1,l,r));
}
int queryhismx(int u,int l,int r)
{
if(tr[u].l>r||tr[u].r<l) return -inf;
if(l<=tr[u].l&&tr[u].r<=r) return tr[u].hismx;
pushdown(u);
return max(queryhismx(u<<1,l,r), queryhismx(u<<1|1,l,r));
}
int n,m;
int main()
{
ios;
cin>>n>>m;
build(1,1,n);
while(m--)
{
int op,l,r,k;cin>>op>>l>>r;
if(op==1)
{
cin>>k;
modfiyadd(1,l,r,k);
}
else if(op==2)
{
cin>>k;
modfiymin(1,l,r,k);
}
else if(op==3) cout<<querysum(1,l,r)<<'\n';
else if(op==4) cout<<querymx(1,l,r)<<'\n';
else cout<<queryhismx(1,l,r)<<'\n';
}
return 0;
}