P1903 【模板】带修莫队 / [国家集训队] 数颜色 / 维护队列 题解 树套树(线段树 套 splay tree)
特别鸣谢
感谢 CuteMurasame 大佬 帮我指出了 splay tree 的 get_rnk 函数忘了 splay 的问题。
我的思路大致是这样的(这一部分显示暴力维护一个 \(pre_i\)):
我们用 \(c_i\) 表示第 \(i\) 支画笔的颜色。
\(pre_i\) 表示第 \(i\) 支画笔前面的画笔中和它最近的画笔的位置,也就是说:
- \(pre_i\) 是所有满足 \(1 \le j \lt i\) 且 \(c_j = c_i\) 的最大的下标 \(j\)。
当然,有可能第 \(i\) 支画笔前面没有和它同一种颜色的画笔,此时我们令 \(pre_i = 0\)。
我们可以给每一个颜色都开一个 set,继而维护 \(pre_i\) 的信息。
对于每次查询 Q L R,因为每种颜色都只要算一种,所以我们只考虑 \([L, R]\) 内每种颜色第一次出现:
- 如果 \(pre_i \lt L\),则颜色 \(c_i\) 是 \([L, R]\) 内第一次出现 \(c_i\) 这个颜色;
- 否则,因为 \(L \le j = pre_i \le R\),所以 \(c_j\) 比 \(c_i\) 早出现,\(c_i\) 不是在 \([L, R]\) 内第一次出现这种颜色,就可以不算入答案。
所以此时问题就变成了求解:
- 存在多少个 \(i\) 满足 \(L \le i \le R\) 且 \(c_i \lt L\)。
暴力写法时间复杂度 \(O(mn)\),可以拿 12/13(这题数据没那么严,只有最后一组卡数据 TLE 了)。
暴力程序如下(后续的逻辑仍然会使用到这部分逻辑):
#include <bits/stdc++.h>
using namespace std;
const int maxn = 133333 + 5, maxm = 1e6 + 5;
int n, m, c[maxn], pre[maxn];
set<int> st[maxm];
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", c+i);
auto it = st[ c[i] ].insert(i).first;
if (it != st[ c[i] ].begin()) {
it--;
pre[i] = *it;
}
}
while (m--) {
char op[2];
scanf("%s", op);
if (op[0] == 'Q') {
int l, r, cnt = 0;
scanf("%d%d", &l, &r);
for (int i = l; i <= r; i++)
if (pre[i] < l)
cnt++;
printf("%d\n", cnt);
}
else { // op[0] == 'R'
int p, x;
scanf("%d%d", &p, &x);
if (x == c[p]) continue;
auto it = st[ c[p] ].find(p), it2 = it;
it2++;
if (it2 != st[ c[p] ].end()) {
if (it == st[ c[p] ].begin())
pre[*it2] = 0;
else {
it--;
pre[*it2] = *it;
}
}
st[ c[p] ].erase(p);
c[p] = x;
it = it2 = st[x].insert(p).first;
if (it != st[x].begin()) {
it--;
int q = *it;
pre[p] = q;
}
else
pre[p] = 0;
it2++;
if (it2 != st[x].end()) {
int q = *it2;
pre[q] = p;
}
}
}
return 0;
}
如果使用 树套树(线段树 套 平衡树) 的话,对于线段树上每一个包含在查询的 \([L, R]\),它都有一棵对应的平衡树。
这些平衡树都记录的是 \(pre_i\) 的信息。
我们可以 \(O(\log n)\) 查找平衡上有多少个点 \(\lt L\)。
总时间复杂度 \(O(n \log ^ 2 n)\)。
树套树 代码(线段树 套 splay tree):
#include <bits/stdc++.h>
using namespace std;
const int maxn = 133333 + 5, maxm = 1e6 + 5, inf = 1e9;
int n, m, c[maxn], pre[maxn];
set<int> st[maxm];
struct SplayTree {
struct Node {
int s[2], v, p, sz;
Node() {}
Node(int _v, int _p) {
s[0] = s[1] = 0;
v = _v;
p = _p;
sz = 1;
}
} tr[maxn*500];
int idx;
void push_up(int u) {
int l = tr[u].s[0], r = tr[u].s[1];
tr[u].sz = tr[l].sz + tr[r].sz + 1;
}
void f_s(int p, int u, int k) {
tr[p].s[k] = u;
tr[u].p = p;
}
void rot(int x) {
int y = tr[x].p, z = tr[y].p;
int k = tr[y].s[1] == x;
f_s(z, x, tr[z].s[1] == y);
f_s(y, tr[x].s[k^1], k);
f_s(x, y, k^1);
push_up(y), push_up(x);
}
void splay(int &root, int x, int k) {
while (tr[x].p != k) {
int y = tr[x].p, z = tr[y].p;
if (z != k)
(tr[y].s[1]==x)^(tr[z].s[1]==y) ? rot(x) : rot(y);
rot(x);
}
if (!k) root = x;
}
int get_rnk(int &root, int x) {
int cnt = 0, u = root, p = 0;
while (u) {
p = u;
if (tr[u].v < x) {
cnt += 1 + tr[ tr[u].s[0] ].sz;
u = tr[u].s[1];
}
else
u = tr[u].s[0];
}
if (p) splay(root, p, 0);
return cnt - 1; // 减去一个哨兵节点
}
void ins(int &root, int v) {
int u = root, p = 0, k = 0;
while (u) {
tr[u].sz++;
k = tr[u].v < v;
p = u;
u = tr[u].s[k];
}
u = ++idx;
tr[u] = Node(v, p);
if (p) tr[p].s[k] = u;
splay(root, u, 0);
}
void del(int &root, int v) {
int u = root;
while (u) {
if (tr[u].v == v) break;
else u = tr[u].s[tr[u].v < v];
}
splay(root, u, 0);
int l = tr[u].s[0], r = tr[u].s[1];
if (!l || !r) {
root = l + r;
tr[root].p = 0;
}
else {
while (tr[l].s[1]) l = tr[l].s[1];
splay(root, l, 0);
tr[l].s[1] = r;
tr[r].p = l;
push_up(l);
}
}
} splay_t;
struct SegmentTree {
int tr[maxn<<2];
#define lson l, mid, u<<1
#define rson mid+1, r, u<<1|1
void build(int l, int r, int u) {
splay_t.ins(tr[u], -inf);
splay_t.ins(tr[u], inf);
for (int i = l; i <= r; i++)
splay_t.ins(tr[u], pre[i]);
if (l == r) return;
int mid = l + r >> 1;
build(lson);
build(rson);
}
// pre[p] 删去 x,插入 y
void update(int p, int x, int y, int l, int r, int u) {
splay_t.del(tr[u], x);
splay_t.ins(tr[u], y);
if (l == r) return;
int mid = l + r >> 1;
(p <= mid) ? update(p, x, y, lson) : update(p, x, y, rson);
}
int query(int L, int R, int l, int r, int u) {
if (L <= l && r <= R)
return splay_t.get_rnk(tr[u], L);
int res = 0, mid = l + r >> 1;
if (L <= mid)
res += query(L, R, lson);
if (R > mid)
res += query(L, R, rson);
return res;
}
} seg_t;
int main() {
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%d", c+i);
auto it = st[ c[i] ].insert(i).first;
if (it != st[ c[i] ].begin()) {
it--;
pre[i] = *it;
}
}
seg_t.build(1, n, 1);
while (m--) {
char op[2];
scanf("%s", op);
if (op[0] == 'Q') {
int l, r;
scanf("%d%d", &l, &r);
printf("%d\n", seg_t.query(l, r, 1, n, 1));
}
else { // op[0] == 'R'
int p, x;
scanf("%d%d", &p, &x);
if (x == c[p]) continue;
auto it = st[ c[p] ].find(p), it2 = it;
it2++;
if (it2 != st[ c[p] ].end()) {
if (it == st[ c[p] ].begin()) {
seg_t.update(*it2, pre[*it2], 0, 1, n, 1);
pre[*it2] = 0;
}
else {
it--;
seg_t.update(*it2, pre[*it2], *it, 1, n, 1);
pre[*it2] = *it;
}
}
st[ c[p] ].erase(p);
c[p] = x;
it = it2 = st[x].insert(p).first;
if (it != st[x].begin()) {
it--;
int q = *it;
seg_t.update(p, pre[p], q, 1, n, 1);
pre[p] = q;
}
else {
seg_t.update(p, pre[p], 0, 1, n, 1);
pre[p] = 0;
}
it2++;
if (it2 != st[x].end()) {
int q = *it2;
seg_t.update(q, pre[q], p, 1, n, 1);
pre[q] = p;
}
}
}
return 0;
}
浙公网安备 33010602011771号