线段树区间历史、区间最值操作
前置知识:只有朴素线段树板子.
很喜欢传奇硬派学长的一句话:
线段树历史最值是 NOIP 考点.
还真是!而且线段树维护一切东西其实都可以是 NOIP 考点.
所以我们来看看一些线段树的区间最值操作和区间历史查询问题.
P6242 【模板】线段树 3(区间最值操作、区间历史最值)
操作 1,3,4 都是最基础的线段树操作,维护区间长度(或左右子树节点),区间最大值,区间和,以及加法标记即可,这里不再展开讲.
区间历史最值操作
可以先考虑历史最大值怎么做. 类似于线段树板子的,我们希望用某种 \(tag\) 来对其进行维护.
那么考虑历史最大值什么时候会发生变化. 现在我们已经维护了区间加法标记,随着加法操作不断进行,设现在的加法标记 \(t_1=k\),发现:
- \(k>0\),历史最大值会增加.
- \(k\le0\),历史最大值不会改变.
然而当我们查询时,历史最大值实际上是上次查询之后的 \(\max\{k\}\). 也就是说我们需要额外维护一个 \(t_2\) 来存 \(t_1\) 的最大值. 下传时,将 \(t_2\) 更新给历史最值即可. 这就是 \(B\) 序列区间最值的维护.
区间最值操作
注意:这个操作涉及到一些均摊、势能分析技巧. 如果没听过就感性理解.
这个操作最困难的点就在于它作用于很多值不同的数上,每种值的变化量不同,每种值受到作用的数还不止一个,并且限定了作用范围.
于是我们不妨来思考更加简易的操作:
-
如果区间取 \(\min\) 只对最大值生效怎么做?那么区间中的每个最大值都减去了一个确定的数,我们只用维护区间最大值与区间最大值的个数即可.
-
只对最大值生效的条件是什么?我们不妨设区间对 \(x\) 取 \(\min\),那么一个充要的条件则是 \(x\) 小于区间最大值,大于区间严格次大值. 区间次大值的维护也并不困难,但是 \(x\) 如果不直接满足这个条件应该怎么做呢?一个不难证明的性质是随着线段树向下递归,最大值和次大值是单调不增的,所以查询时向下递归同样也是单调不增的. 所以我们就可以在符合条件时向下递归,直到 \(x\) 满足上述条件就可以直接根据最大值的个数来进行修改了.
但是这样看起来十分的暴力,为什么能够保证复杂度正确呢?根据势能分析,单独的区间取最值按照上面的方法是均摊 \(O(n\log n)\) 的,而区间取最值同时维护区间加的均摊复杂度为 \(O(n\log^2 n)\),非常的优秀. 证明用到的一些关键性质是最值个数单调不增以及线段树高 \(O(\log n)\). 具体证明详见吉如一老师的集训队论文.
主要讲一下上传、下传标记的细节. 先看代码:
inline void upd(int u) {
tr[u].sum = tr[ls].sum + tr[rs].sum;
tr[u].mxa = max(tr[ls].mxa, tr[rs].mxa), tr[u].mxb = max(tr[ls].mxb, tr[rs].mxb);
if(tr[ls].mxa == tr[rs].mxa) tr[u].se = max(tr[ls].se, tr[rs].se), tr[u].cnt = tr[ls].cnt + tr[rs].cnt;
if(tr[ls].mxa < tr[rs].mxa) tr[u].se = max(tr[ls].mxa, tr[rs].se), tr[u].cnt = tr[rs].cnt;
if(tr[ls].mxa > tr[rs].mxa) tr[u].se = max(tr[rs].mxa, tr[ls].se), tr[u].cnt = tr[ls].cnt;
return;
}
inline void change(int u, int t1, int t2, int t3, int t4) {
tr[u].sum += 1ll * t1 * tr[u].cnt + 1ll * t2 * (tr[u].len - tr[u].cnt);
tr[u].mxb = max(tr[u].mxa + t3, tr[u].mxb), tr[u].mxa += t1; if(tr[u].se != -inf) tr[u].se += t2;
tr[u].t3 = max(tr[u].t3, tr[u].t1 + t3), tr[u].t4 = max(tr[u].t4, tr[u].t2 + t4);
tr[u].t1 += t1, tr[u].t2 += t2;
return;
}
inline void pushdown(int u) {
int mx = max(tr[ls].mxa, tr[rs].mxa);
if(tr[ls].mxa == mx) change(ls, tr[u].t1, tr[u].t2, tr[u].t3, tr[u].t4);
else change(ls, tr[u].t2, tr[u].t2, tr[u].t4, tr[u].t4);
if(tr[rs].mxa == mx) change(rs, tr[u].t1, tr[u].t2, tr[u].t3, tr[u].t4);
else change(rs, tr[u].t2, tr[u].t2, tr[u].t4, tr[u].t4);
tr[u].t1 = tr[u].t2 = tr[u].t3 = tr[u].t4 = 0;
return;
}
我们维护了区间和 sum,区间最大值 mxa,区间历史最大值 mxb,区间次大值 se,区间最大值个数 cnt,区间长度 len,以及四个标记 t1,t2,t3,t4 分别表示区间最大值的加法标记,区间非最大值的加法标记,区间历史最大值的加法标记,区间历史非最大值的加法标记. 为了同时维护上述两个操作,这些标记是必要的.
先从 upd() 函数讲起. 其它标记是容易理解的,最主要的问题就集中在 se,cnt 的维护上. 由于跟最大值具体的值相关,而最大值又取左右儿子中较大的,所以要分讨左右儿子最大值的大小关系来进行维护.
接着是 pushdown() 和辅助下传函数 change(). 先看 change(),由于标记区分了最大值和非最大值,所以加法贡献也要拆成最大值标记和非最大值标记两部分来算. 其余可以根据上文分析进一步理解. 然后看 pushdown(),有一个基本事实是:历史最大值一定是由左右儿子中的较大值更新得来的. 这也就意味着左右儿子中的较小值只能用非最大值标记来更新,于是分讨即可.
最后有几个细节:
- 涉及
se的更新,都需要判断其是否存在. 所以初值要赋成极小值. - 修改函数也可以借助
change()函数更新.
具体实现涉及大量分讨,但是根据定义还是不难理解的,多写写熟练就好.
点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn = 5e5 + 10; const ll inf = 2e18;
int n, m;
#define ls (u << 1)
#define rs (u << 1 | 1)
struct sgt {
struct node{ll sum, t1, t2, t3, t4, mxa, se, mxb, cnt, len;} tr[maxn << 2];
inline void upd(int u) {
tr[u].sum = tr[ls].sum + tr[rs].sum;
tr[u].mxa = max(tr[ls].mxa, tr[rs].mxa), tr[u].mxb = max(tr[ls].mxb, tr[rs].mxb);
if(tr[ls].mxa == tr[rs].mxa) tr[u].se = max(tr[ls].se, tr[rs].se), tr[u].cnt = tr[ls].cnt + tr[rs].cnt;
if(tr[ls].mxa < tr[rs].mxa) tr[u].se = max(tr[ls].mxa, tr[rs].se), tr[u].cnt = tr[rs].cnt;
if(tr[ls].mxa > tr[rs].mxa) tr[u].se = max(tr[rs].mxa, tr[ls].se), tr[u].cnt = tr[ls].cnt;
return;
}
inline void change(int u, int t1, int t2, int t3, int t4) {
tr[u].sum += 1ll * t1 * tr[u].cnt + 1ll * t2 * (tr[u].len - tr[u].cnt);
tr[u].mxb = max(tr[u].mxa + t3, tr[u].mxb), tr[u].mxa += t1; if(tr[u].se != -inf) tr[u].se += t2;
tr[u].t3 = max(tr[u].t3, tr[u].t1 + t3), tr[u].t4 = max(tr[u].t4, tr[u].t2 + t4);
tr[u].t1 += t1, tr[u].t2 += t2;
return;
}
inline void pushdown(int u) {
int mx = max(tr[ls].mxa, tr[rs].mxa);
if(tr[ls].mxa == mx) change(ls, tr[u].t1, tr[u].t2, tr[u].t3, tr[u].t4);
else change(ls, tr[u].t2, tr[u].t2, tr[u].t4, tr[u].t4);
if(tr[rs].mxa == mx) change(rs, tr[u].t1, tr[u].t2, tr[u].t3, tr[u].t4);
else change(rs, tr[u].t2, tr[u].t2, tr[u].t4, tr[u].t4);
tr[u].t1 = tr[u].t2 = tr[u].t3 = tr[u].t4 = 0;
return;
}
inline void build(int u, int l, int r) {
tr[u].len = r - l + 1;
if(l == r) {int x; cin >> x; tr[u].sum = tr[u].mxa = tr[u].mxb = x, tr[u].se = -inf, tr[u].cnt = 1; return;}
int mid = l + r >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
return upd(u), void(0);
}
inline void qadd(int u, int l, int r, int ql, int qr, int x) {
if(ql <= l && r <= qr) {return change(u, x, x, x, x), void(0);}
pushdown(u);
int mid = l + r >> 1;
if(ql <= mid) qadd(ls, l, mid, ql, qr, x);
if(mid < qr) qadd(rs, mid + 1, r, ql ,qr, x);
return upd(u), void(0);
}
inline void qmin(int u, int l, int r, int ql, int qr, int x) {
if(x >= tr[u].mxa) return;
if(ql <= l && r <= qr && tr[u].se < x) {return change(u, x - tr[u].mxa, 0, 0, 0), void(0);}
pushdown(u);
int mid = l + r >> 1;
if(ql <= mid) qmin(ls, l, mid, ql, qr, x);
if(mid < qr) qmin(rs, mid + 1, r, ql ,qr, x);
return upd(u), void(0);
}
inline ll ask_sum(int u, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) return tr[u].sum;
pushdown(u);
ll res = 0; int mid = l + r >> 1;
if(ql <= mid) res += ask_sum(ls, l, mid, ql, qr);
if(mid < qr) res += ask_sum(rs, mid + 1, r, ql, qr);
return res;
}
inline ll ask_mxa(int u, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) return tr[u].mxa;
pushdown(u);
ll res = -inf, mid = l + r >> 1;
if(ql <= mid) res = max(res, ask_mxa(ls, l, mid, ql, qr));
if(mid < qr) res = max(res, ask_mxa(rs, mid + 1, r, ql, qr));
return res;
}
inline ll ask_mxb(int u, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) return tr[u].mxb;
pushdown(u);
ll res = -inf, mid = l + r >> 1;
if(ql <= mid) res = max(res, ask_mxb(ls, l, mid, ql, qr));
if(mid < qr) res = max(res, ask_mxb(rs, mid + 1, r, ql, qr));
return res;
}
}t;
int main() {
ios :: sync_with_stdio(false); cin.tie(0); cout.tie(0);
cin >> n >> m; t.build(1, 1, n);
for(int i = 1, op, l, r, x; i <= m; i++) {
cin >> op >> l >> r; if(op == 1 || op == 2) cin >> x;
if(op == 1) t.qadd(1, 1, n, l, r, x);
if(op == 2) t.qmin(1, 1, n, l, r, x);
if(op == 3) cout << t.ask_sum(1, 1, n, l, r) << endl;
if(op == 4) cout << t.ask_mxa(1, 1, n, l, r) << endl;
if(op == 5) cout << t.ask_mxb(1, 1, n, l, r) << endl;
}
return 0;
}
再看一道例题:
P4314 CPU 监控
Hint:考虑如何维护区间覆盖与区间加法对最大值的影响.
直接来看维护标记的几个函数:
inline void do_cov(int u, int x, int hx) {
if(tr[u].iscov) tr[u].hcov = max(tr[u].hcov, hx);
else tr[u].iscov = true, tr[u].hcov = hx;
tr[u].hmx = max(tr[u].hmx, hx), tr[u].cov = tr[u].mx = x, tr[u].add = 0;
return;
}
inline void do_add(int u, int x, int hx) {
if(tr[u].iscov) do_cov(u, tr[u].cov + x, tr[u].cov + hx);
else {
tr[u].hmx = max(tr[u].hmx, tr[u].mx + hx), tr[u].hadd = max(tr[u].hadd, tr[u].add + hx);
tr[u].mx += x, tr[u].add += x;
} return;
}
inline void pushdown(int u) {
do_add(ls, tr[u].add, tr[u].hadd), do_add(rs, tr[u].add, tr[u].hadd);
tr[u].add = tr[u].hadd = 0;
if(tr[u].iscov) {
do_cov(ls, tr[u].cov, tr[u].hcov), do_cov(rs, tr[u].cov, tr[u].hcov);
tr[u].cov = tr[u].hcov = tr[u].iscov = 0;
} return;
}
类似于前面提到的维护历史最值的方法,我们需要维护覆盖标记的最大值 hcov 和是否有覆盖标记的标记 iscov. 一个基本事实是,一个区间一旦被覆盖过至少一次,加法操作就等价于覆盖操作了,否则正常下传即可. 而为了囊括无初值的情况,do_cov 也需要针对是否有过覆盖操作进行讨论. pushdown() 下传应该优先传加法标记,再传覆盖标记,正确性显然.
完整代码:
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 10;
const int inff = 2147483647;
int n, q, a[maxn];
#define ls (u << 1)
#define rs (u << 1 | 1)
struct sgt{
struct node{int add, hadd, cov, hcov, mx, hmx; bool iscov;} tr[maxn << 2];
inline void upd(int u) {tr[u].mx = max(tr[ls].mx, tr[rs].mx), tr[u].hmx = max(tr[u].hmx, tr[u].mx); return;}
inline void build(int u, int l, int r) {
tr[u].mx = tr[u].hmx = -inff;
if(l == r) {tr[u].mx = tr[u].hmx = a[l]; return void(0);}
int mid = l + r >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
return upd(u), void(0);
}
inline void do_cov(int u, int x, int hx) {
if(tr[u].iscov) tr[u].hcov = max(tr[u].hcov, hx);
else tr[u].iscov = true, tr[u].hcov = hx;
tr[u].hmx = max(tr[u].hmx, hx), tr[u].cov = tr[u].mx = x, tr[u].add = 0;
return;
}
inline void do_add(int u, int x, int hx) {
if(tr[u].iscov) do_cov(u, tr[u].cov + x, tr[u].cov + hx);
else {
tr[u].hmx = max(tr[u].hmx, tr[u].mx + hx), tr[u].hadd = max(tr[u].hadd, tr[u].add + hx);
tr[u].mx += x, tr[u].add += x;
} return;
}
inline void pushdown(int u) {
do_add(ls, tr[u].add, tr[u].hadd), do_add(rs, tr[u].add, tr[u].hadd);
tr[u].add = tr[u].hadd = 0;
if(tr[u].iscov) {
do_cov(ls, tr[u].cov, tr[u].hcov), do_cov(rs, tr[u].cov, tr[u].hcov);
tr[u].cov = tr[u].hcov = tr[u].iscov = 0;
} return;
}
inline void qadd(int u, int l, int r, int ql, int qr, int x) {
if(ql <= l && r <= qr) return do_add(u, x, x), void(0);
pushdown(u); int mid = l + r >> 1;
if(ql <= mid) qadd(ls, l, mid, ql, qr, x);
if(mid < qr) qadd(rs, mid + 1, r, ql, qr, x);
return upd(u), void(0);
}
inline void qcov(int u, int l, int r, int ql, int qr, int x) {
if(ql <= l && r <= qr) return do_cov(u, x, x), void(0);
pushdown(u); int mid = l + r >> 1;
if(ql <= mid) qcov(ls, l, mid, ql, qr, x);
if(mid < qr) qcov(rs, mid + 1, r, ql, qr, x);
return upd(u), void(0);
}
inline int ask_mx(int u, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) return tr[u].mx;
pushdown(u); int mid = l + r >> 1, res = -inff;
if(ql <= mid) res = max(res, ask_mx(ls, l, mid, ql, qr));
if(mid < qr) res = max(res, ask_mx(rs, mid + 1, r, ql, qr));
return res;
}
inline int ask_hmx(int u, int l, int r, int ql, int qr) {
if(ql <= l && r <= qr) return tr[u].hmx;
pushdown(u); int mid = l + r >> 1, res = -inff;
if(ql <= mid) res = max(res, ask_hmx(ls, l, mid, ql, qr));
if(mid < qr) res = max(res, ask_hmx(rs, mid + 1, r, ql, qr));
return res;
}
} t;
int main() {
ios :: sync_with_stdio(false); cin.tie(0); cout.tie(0);
cin >> n;
for(int i = 1; i <= n; i++) cin >> a[i]; t.build(1, 1, n);
cin >> q;
while(q--) {
char op; int l, r, x; cin >> op >> l >> r; if(op == 'P' || op == 'C') cin >> x;
if(op == 'Q') cout << t.ask_mx(1, 1, n, l, r) << endl;
if(op == 'A') cout << t.ask_hmx(1, 1, n, l, r) << endl;
if(op == 'P') t.qadd(1, 1, n, l, r, x);
if(op == 'C') t.qcov(1, 1, n, l, r, x);
}
return 0;
}

浙公网安备 33010602011771号