CDQ分治 学习笔记
我们先回顾一下逆序对问题:
在本题中,我们需要求所有 \(i < j\) 且 \(a_i > a_j\) 的对数。我们一般有两种方法:一种是树状数组求解,一种是归并求解。其实,令 \(a_i = i\),\(b_i =\) 输入的数字,本题实际上就是求满足 \(a_i < a_j\) 且 \(b_i > b_j\) 的对数,这是一道典型的二维偏序问题,只不过默认了 \(a_i = i\)。
想一下我们是怎么求解的:我们一个一个地把 \(b_i\) 插入到树状数组里,设当前插入到了第 \(i\) 个位置,实际上就是满足了 \(a_j < a_i\)。此时我们再利用树状数组求所有已插入的 \(b_j\) 中大于 \(b_i\) 的数量,由于前提已经满足了 \(a_j < a_i\),所以此时所有 \(b_j > b_i\) 一定都满足条件。反过来,若是要求所有 \(a_i < a_j, b_i < b_j\) 的对数(即顺序对),只需在统计答案时稍作改变即可。
但如果我们需要求关于 \(a_i < a_j, b_i < b_j, c_i < c_j\) 的问题呢?这便是 \(\text{CDQ}\) 分治能解决的经典问题,即三维偏序问题。
题意:设 \(f_i\) 表示 \(\displaystyle\sum_{j = 1, j \neq i}^n [a_j \leq a_i, b_j \leq b_i, c_j \leq c_i]\),对于每个 \(d \in [0, n)\),我们需要求出 \(f_i = d\) 的数量。
我们运用刚才的思想,我们首先以 \(a_i\) 为第一关键字,\(b_i\)为第二关键字,\(c_i\)为第三关键字升序排序。注意到若对于两个不同的 \(i\),三个数有可能完全相同,我们需要额外处理一下,稍后再谈。下设 \(cnt_i\) 表示 \((a_i, b_i, c_i)\) 出现的次数,\(ans_i\) 表示第 \(i\) 个数对的答案。显然地,此时对于每个 \(i\),对 \(f_i\) 的贡献只可能来自于 \(j < i\)。
我们考虑归并:对于一个区间 \([l, r]\),我们先分别计算 \([l, mid]\) 及 \([mid + 1, r]\) 中可能产生的贡献。此时两个区间都是按照 \(b_i\) 升序排序的,不一定按照 \(a_i\) 了,因为无论左边和右边的 \(a_i\) 多大,由于前面已经排序过了,所以只要左边的 \(b_i\) 和 \(c_i\) 满足条件,就一定可以对右边产生贡献,而不用考虑 \(a_i\)。而原本的 \(a_i\) 只对原本的小区间有影响,但原本的小区间里产生的贡献已经算过了,所以此时的 \(a_i\) 便不重要了。然后,由于只有 \([l, mid]\) 可能对 \([mid + 1, r]\) 产生贡献,我们按照 \(b_i\) 像归并排序一样升序排序,但要额外统计:
- 当 \(b_{lptr} \leq b_{rptr}\) 时,我们把 \(c_{lptr}\) 在树状数组中加上 \(cnt_{lptr}\)(显然地,一个对数合法了,相同的也一定合法),并增加 \(lptr\);
- 当 \(b_{lptr} > b_{rptr}\) 时,我们对 \(ans_{rptr}\) 加上树状数组中小于 \(c_{rptr}\) 的前缀和。由于对 \(b_i\) 进行排序,所以此时加上的一定是合法的数量。
排序完后,我们需要再将原本的树状数组每个 \(c_{l \leq i \leq mid}\) 中减去对应的 \(cnt_i\),这是为了消除这个区间的影响,因为归并到大区间及求之后的小区间还需要用到这个树状数组,若不减去则会重复计算。
对整个区间归并后,每个 \(ans_i\) 就是对第 \(i\) 个数对不算入这个数对的重复而有算入其他数对的重复的答案了。于是,对于每个 \(ans_i\),应该对 \(f_{ans_i + cnt_i - 1}\) 加上 \(cnt_i\)。
不清楚的可以看代码理解。
AC Code:
#include <bits/stdc++.h>
#define lowbit(x) ((x) & -(x))
#define mid (l + r >> 1)
using namespace std;
struct P {
int a, b, c, id, cnt;
const bool operator<(const P &rhs) const {
if (a != rhs.a)
return a < rhs.a;
else if (b != rhs.b)
return b < rhs.b;
else
return c < rhs.c;
}
const bool operator==(const P &rhs) const {
return a == rhs.a && b == rhs.b && c == rhs.c;
}
};
int n, k, tot, ans[100005], f[100005], sum[200005];
P p[100005];
map<P, int> ma;
void add(int id, int d) {
while (id <= k) {
sum[id] += d;
id += lowbit(id);
}
}
int query(int id) {
int res = 0;
while (id) {
res += sum[id];
id -= lowbit(id);
}
return res;
}
void mergeSort(int l, int r) {
if (l == r)
return;
mergeSort(l, mid), mergeSort(mid + 1, r);
int lptr = l, rptr = mid + 1, t = 0;
P *tmp = new P[r - l + 2];
while (lptr <= mid && rptr <= r)
if (p[lptr].b > p[rptr].b) {
tmp[++t] = p[rptr];
ans[p[rptr].id] += query(p[rptr].c);
++rptr;
} else {
tmp[++t] = p[lptr];
add(p[lptr].c, p[lptr].cnt);
++lptr;
}
while (lptr <= mid) {
tmp[++t] = p[lptr];
add(p[lptr].c, p[lptr].cnt);
++lptr;
}
while (rptr <= r) {
tmp[++t] = p[rptr];
ans[p[rptr].id] += query(p[rptr].c);
++rptr;
}
for (int i = l; i <= mid; ++i)
add(p[i].c, -p[i].cnt);
for (int i = l; i <= r; ++i)
p[i] = tmp[i - l + 1];
delete tmp;
}
int main() {
scanf("%d %d", &n, &k);
for (int i = 1; i <= n; ++i) {
++tot;
scanf("%d %d %d", &p[tot].a, &p[tot].b, &p[tot].c);
if (ma[p[tot]]) {
++p[ma[p[tot]]].cnt;
--tot;
} else {
ma[p[tot]] = tot;
p[tot].cnt = 1;
p[tot].id = tot;
}
}
sort(p + 1, p + tot + 1);
mergeSort(1, tot);
for (int i = 1; i <= tot; ++i)
f[ans[p[i].id] + p[i].cnt - 1] += p[i].cnt;
for (int i = 0; i < n; ++i)
printf("%d\n", f[i]);
return 0;
}