P1637 题解
一道绿写 2.5 h,我是什么效率哥。
Solution
提供一种不使用线段树 / 树状数组的方法。前置知识:分治,二分,前缀和。
考虑分治。我们假设有一个分治函数 solve(l, r) 可以统计区间 \([l, r]\) 中的 thair。
对于一个区间 \([l, r]\) 中的 thair \(= \{a_i, a_j, a_k | i<j<k\) 且 \(a_i <a_j <a_k\}\),有以下几种情况:
- \(a_i, a_j, a_k\) 均在左半区间 \([l, mid]\) 中,可以通过
solve(l, mid)统计; - \(a_i, a_j, a_k\) 均在右半区间 \([mid + 1, r]\) 中,可以通过
solve(mid + 1, r)统计。 - \(a_i, a_j\) 在左半区间,而 \(a_k\) 在右半。
- \(a_i\) 在左半,\(a_j, a_k\) 在右半。
后两种情况的求解需要对两个子区间中的「顺序对」进行统计。(「顺序对」的定义类比 P1908 逆序对 中的「逆序对」)
定义 \(fst_i\) 表示以 \(a_i\) 为开头的「顺序对」个数,\(scd_i\) 表示以 \(a_i\) 为结尾的「顺序对」个数。给两个子区间升序排序后,我们可以这样操作:
- 枚举右半区间,二分找到左半区间内小于当前枚举的数的最右位置 \(pos\),则 \(ans \gets ans + \Sigma_{i = l}^{pos} scd_i\);同时更新当前位置的 \(scd\),使其加上 \(pos - l + 1\)。
- 对左半区间进行类似操作,即:枚举左半区间,二分找到右半区间内大于当前枚举的数的最左位置 \(pos\),则 \(ans \gets ans + \Sigma_{i = pos} ^ {r} fst_i\);同时更新当前位置的 \(fst\),使其加上 \(r - pos + 1\)。
求 \(\Sigma\) 可以用前缀和优化。
处理完上述操作以后,我们已经统计到了上文中情况 3、4 的答案,此时给该区间升序排序并不会影响结果。所以直接排序。
按照上述步骤分治即可。处理每个分治区间的时间复杂度为 \(O(n \log n)\),总复杂度为 \(O(n \log^2 n)\),可以通过本题。
#include <iostream>
#include <cstdio>
#include <algorithm>
using namespace std;
#define ll long long
const int N = 30010;
int n, a[N];
ll fst[N], scd[N];
struct sss {
int a, fst, scd;
}num[N];
ll s[N];
ll ans;
bool cmp (sss x, sss y) {
return x.a < y.a;
}
void solve(int l, int r) {
if (l >= r) {
fst[l] = scd[l] = 0;
return;
}
int mid = (l + r) >> 1;
solve(l, mid);
solve(mid + 1, r);
//分别枚举右半、左半区间统计答案。先统计再更新 fst、scd 的原因是怕更新后干扰答案统计。
for (int i = l; i <= r; i++) s[i] = 0;
for (int i = l; i <= mid; i++) if (i == l) s[i] = scd[i];
else s[i] = s[i - 1] + scd[i];
for (int i = mid + 1; i <= r; i++) {
int pos = lower_bound(a + l, a + mid + 1, a[i]) - a - 1;
if (pos < l) continue; //小细节,防越界出 bug。
ans += s[pos]; }
for (int i = l; i <= r; i++) s[i] = 0;
for (int i = mid + 1; i <= r; i++)
if (i == mid + 1) s[i] = fst[i];
else s[i] = s[i - 1] + fst[i];
for (int i = l; i <= mid; i++) {
int pos = upper_bound(a + mid + 1, a + r + 1, a[i]) - a;
ans += s[r] - s[pos - 1];
}
//更新 fst、scd。
for (int i = mid + 1; i <= r; i++) {
int pos = lower_bound(a + l, a + mid + 1, a[i]) - a - 1;
scd[i] += pos - l + 1; //这里就不用防了,算算就知道越界之后贡献为 0。
}
for (int i = l; i <= mid; i++) {
int pos = upper_bound(a + mid + 1, a + r + 1, a[i]) - a;
fst[i] += r - pos + 1;
}
//排序。记得连着 fst、scd 一起排。
for (int i = l; i <= r; i++) num[i].a = a[i], num[i].fst = fst[i], num[i].scd = scd[i];
sort(num + l, num + r + 1, cmp);
for (int i = l; i <= r; i++) a[i] = num[i].a, fst[i] = num[i].fst, scd[i] = num[i].scd;
}
int main() {
scanf("%d", &n);
for (int i = 1; i <= n; i++) scanf("%d", &a[i]);
solve(1, n);
printf("%lld", ans);
return 0;
}

浙公网安备 33010602011771号