算法学习笔记(11):历史版本和线段树
历史版本和线段树
功能
可以累计所有历史版本的答案, 可以解区间子区间问题
算法
先上题目: Good Subsegments
题意: 求解 \((l, r)\) 区间里“好的”子区间个数。 “好的”定义为区间内元素排序后是连续的。
思路: 考虑没有重复的数, 所以当一个区间是“好的”相当于\(max - min + 1 = len\)。
考虑这个式子肯定可以用数据结构维护, 于是考虑扫描线将 \((l, r)\) 消掉一维。
从小到大枚举 \(r_i\) , 对于每一个右端点等于 \(r_i\) 的询问区间 \((l_j, r_i)\) 我们需要统计该区间内所有子区间是否为“好的”。
for (int r = lj; r <= ri; r++) {
for (int l = lj; l <= r; l++) {
}
}
看这个代码应该知道是哪些子区间了吧, 现在设想如果我们有一种数据结构很快求出每次外层for循环的答案, 即取代内层for循环。
并且可以累计贡献, 且支持控制范围 \((l, r)\) 。就可以解决问题了!!!
所以引出历史版本和线段树。
定义线段树第 \(i\) 叶子节点存的值为 \(i, r_i\) 的 \(max - min - len\)
Q: 为什么是 \(max - min - len\) 而不是 \(max - min + 1\)呢?
A: 可以把前面的式子化简为 \(max - min - len = -1\) , 并且考虑每个区间的 \(max - min - len\) 必大于等于 -1, 且长度为一的区间值一定为1, 所以考虑线段树维护 \(max - min - len\) 的最小值个数, 就是“好的”区间个数。
再考虑如何修改, \(r_i\)每增加1,线段树区间减, \(max\) 和 \(min\) 可以用单调栈找到每个值对于哪些区间的 \(max\) 或 \(min\)有贡献, 也是区间加。
现在考虑线段树怎么写:
有大力分讨打标记的写法, 还有吉老师的方法, 还有就是矩阵乘法, 其实是殊途同归。 这里讲一下很好理解的矩阵乘法写法。
考虑我们的标记很多比较复杂, 同时打标记必须要满足结合律, 也就是我一直打某些标记不下传, 下传之后仍要保证正确。 那我们考虑矩阵乘法的结合律, 这里这道题我们维护 \(his, len, sum\) 三个标记表示历史版本和, 区间长度, 区间和。 那么区间减, 区间加, 历史版本和操作都可以推出一个矩阵, 这里麻烦读者自己推出。
Q: 矩阵乘法太慢怎么办?
A: 考虑自己分讨打标记的过程, 一定和矩阵乘法的过程一样, 这只是不同的理解罢了, 所以我们可以观察矩阵的性质, 将矩阵乘法手动化简。 比如只维护上三角等等。
struct SegT{
struct Node{ int l, r, mn, cmn, add, tim, ans; }t[N << 2];
#define ls (p << 1)
#define rs (p << 1 | 1)
#define mid (t[p].l + t[p].r >> 1)
void upd_mn(int p, int z) { t[p].mn += z, t[p].add += z; }
void upd_tim(int p, int z) { t[p].ans += t[p].cmn * z, t[p].tim += z; }
void pushdown(int p) {
if (t[p].add) upd_mn(ls, t[p].add), upd_mn(rs, t[p].add), t[p].add = 0;
if (t[p].tim) {
if (t[ls].mn == t[p].mn) upd_tim(ls, t[p].tim);
if (t[rs].mn == t[p].mn) upd_tim(rs, t[p].tim);
t[p].tim = 0;
}
}
void pushup(int p) {
t[p].mn = min(t[ls].mn, t[rs].mn), t[p].cmn = 0;
if (t[p].mn == t[ls].mn) t[p].cmn += t[ls].cmn;
if (t[p].mn == t[rs].mn) t[p].cmn += t[rs].cmn;
t[p].ans = t[ls].ans + t[rs].ans;
}
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r, t[p].mn = l, t[p].cmn = 1;
if (l == r) return;
build(ls, l, mid); build(rs, mid + 1, r);
}
void modify(int p, int x, int y, int z) {
if (x <= t[p].l && t[p].r <= y) return upd_mn(p, z), void();
pushdown(p);
if (x <= mid) modify(ls, x, y, z);
if (y > mid) modify(rs, x, y, z);
pushup(p);
}
int query(int p, int x, int y) {
if (x <= t[p].l && t[p].r <= y) return t[p].ans;
pushdown(p);
int res = 0;
if (x <= mid) res += query(ls, x, y);
if (y > mid) res += query(rs, x, y);
pushup(p);
return res;
}
}T;
拓展
考虑如果数可以重复, 且“好的”区间定义为区间内元素排序去重后, 是连续的。
式子变为 \(max - min + 1 = cnt\) \(cnt\) 为区间内不相等的个数。
这样照样可以维护, \(r_i\)增加1时, 若上一次该元素出现在 pos$,则只把 \((pos, r)\) 区间减一就行了。
题目
P9990 [Ynoi Easy Round 2023] TEST_90
这道更简单, 还是扫描线, 我们只讲怎么打标记。 不能像之前那样直接维护 \(tim\) 了, 之前直接维护 \(tim\), 是因为从上一次下传标记到现在, 区间加操作都是针对整个区间的, 所以这个区间的最小值个数不会改变, \(ans\) 的增量一直是 \(cmn\)
但是这道题, 会区间反转, 从上一次下传标记到现在, 每一次答案累加的增量都是当前的 \(sum\), \(sum\) 是变化的, 所幸\(sum\) 的变化只有一种。 所以我们维护 \(tim0\), \(tim1\) 表示从上一次下传标记到现在 \(sum\) 的累加次数, 和 \(len - sum\)的累加次数。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e6 + 10;
int n, m, a[N];
struct qry{
int l, r, id;
bool operator < (const qry &x) const {
return r < x.r;
}
}q[N];
struct ST{
struct Node{
int l, r, sum, ans, tim0, tim1, tag;
}t[N << 2];
#define ls (p << 1)
#define rs (p << 1 | 1)
#define mid (t[p].l + t[p].r >> 1)
void upd_ans(int p, int z0, int z1) {
t[p].ans += t[p].sum * z1 + (t[p].r - t[p].l + 1 - t[p].sum) * z0;
t[p].tim0 += z0; t[p].tim1 += z1;
}
void upd_sum(int p) {
t[p].sum = t[p].r - t[p].l + 1 - t[p].sum;
swap(t[p].tim0, t[p].tim1);
t[p].tag ^= 1;
}
void pushdown(int p) {
if (t[p].tag) {
upd_sum(ls); upd_sum(rs);
t[p].tag = 0;
}
if (t[p].tim0 || t[p].tim1) {
upd_ans(ls, t[p].tim0, t[p].tim1);
upd_ans(rs, t[p].tim0, t[p].tim1);
t[p].tim0 = t[p].tim1 = 0;
}
}
void pushup(int p) {
t[p].sum = t[ls].sum + t[rs].sum;
t[p].ans = t[ls].ans + t[rs].ans;
}
void build(int p, int l, int r) {
t[p].l = l, t[p].r = r;
if (l == r) return;
build(ls, l, mid); build(rs, mid + 1, r);
}
void modify(int p, int x, int y) {
if (x <= t[p].l && t[p].r <= y) return upd_sum(p), void();
pushdown(p);
if (x <= mid) modify(ls, x, y);
if (y > mid) modify(rs, x, y);
pushup(p);
}
int query(int p, int x, int y) {
if (x <= t[p].l && t[p].r <= y) return t[p].ans;
pushdown(p);
int res = 0;
if (x <= mid) res += query(ls, x, y);
if (y > mid) res += query(rs, x, y);
return res;
}
}T;
int lst[N], ans[N];
signed main() {
scanf("%lld", &n);
for (int i = 1; i <= n; i++) scanf("%lld", &a[i]);
scanf("%lld", &m);
for (int i = 1, l, r; i <= m; i++)
scanf("%lld%lld", &l, &r), q[i] = {l, r, i};
sort(q + 1, q + 1 + m);
T.build(1, 1, n);
for (int i = 1, k = 1; i <= n; i++) {
T.modify(1, lst[a[i]] + 1, i);
lst[a[i]] = i;
T.upd_ans(1, 0, 1);
while (q[k].r == i && k <= m)
ans[q[k].id] = T.query(1, q[k].l, i), k++;
}
for (int i = 1; i <= m; i++) printf("%lld\n", ans[i]);
return 0;
}

浙公网安备 33010602011771号