P10235 [yLCPC2024] C. 舞萌基本练习 题解
题目传送门
大致题意:
多组测试数据,每组数据给定一个长度为 \(n\) 的序列和一个参数 \(k\),要求将此区间划分成不超过 \(k\) 段,使这些区间中的逆序对数量的最大值最小。
思路:
对于求“最大值最小”这类问题,很容易想到二分。显然,本题的答案是具有单调性的,即当划分的段数减少时,区间中逆序对数量的最大值只会增大不会减小。
所以我们考虑将二分答案转化为二分判定,具体为:我们二分一个 \(limit\) 值,表示当前的解,然后将当前这个 \(limit\) 代入计算。
若当前解合法,则说明区间 \([limit,r]\) 的解肯定都是合法的,因为此时 \(limit\) 也可能作为最后答案,所以令 \(r = mid\),向左扩展答案即可。
若当前解不合法,则说明区间 \([l,limit]\) 的解肯定都不合法,所以此时令 \(l = mid + 1\) 向右扩展答案即可。
二分的问题解决了,那么怎么判断这个解是否合法呢?
一个很简单的想法:扫描整个序列,一开始整个序列只有一段,将序列中的数一个一个地加入这个段中,当此段逆序对数量大于 \(limit\) 时,就重新开辟新的一段并使段数 \(cnt\) 增加 \(1\)。将整个序列划分完后,若 \(cnt \le k\),则此解合法,否则不合法。
对于求逆序对数量,用树状数组可以很好解决。
建立一个权值树状数组,每次在某段末尾加入一个数时,只需计算该段中大于它的数的个数,这就是新增的逆序对数。
然而这里要注意两个点:
- \(-10^9 \le a_i \le 10^9\),所以需要离散化;
- 在重新开辟一段时,之前那段的数要全部从树状数组中抹去,这样才能让它正确地求出后面段的逆序对数。
\(\texttt{Some Tips}\):
由于长度为 \(n\) 的序列最大逆序对数量为 \(\frac{n(n-1)}{2}\),所以我的二分上界取了 \(10^{10}\)。
在每次二分中,序列中的所有数只会进入一次、出一次树状数组,所以 \(\operatorname{check()}\) 的时间复杂度为 \(O(n\log n)\)。
整个程序时间复杂度 \(O(n\log n\log 10^{10})\),空间复杂度 \(O(n)\)。
\(\texttt{Code}:\)
#include <iostream>
#include <vector>
#include <algorithm>
#include <cstdio>
#define lowbit(x) x & -x
using namespace std;
const int N = 100010;
int T, n, k;
int a[N], c[N];
int nums[N];
int tt;
int fnd[N];
int find(int x) {
return lower_bound(nums + 1, nums + tt + 1, x) - nums;
}
int ask(int x) {
int res = 0;
for(; x; x -= lowbit(x)) res += c[x];
return res;
}
void add(int x, int y) {
for(; x <= n; x += lowbit(x)) c[x] += y;
}
bool check(long long limit) {
int cnt = 1; //段数
long long f = 0; //目前处理的段的逆序对数
int L = 1; //目前处理的段的左端点
for(int i = 1; i <= n; i++) {
int tmp = ask(tt) - ask(fnd[i]); //计算新增的逆序对数
if(f + tmp > limit) {
cnt++; //段数 + 1
f = 0; //重置逆序对数
for(int j = L; j <= i - 1; j++)
add(fnd[j], -1); //清除上一区间的贡献
L = i; //更新左端点
}
else f += tmp;
add(fnd[i], 1); //加入树状数组
}
for(int i = L; i <= n; i++) add(fnd[i], -1); //不要忘了最后一段也要抹去
return cnt > k;
}
int main() {
scanf("%d", &T);
while(T--) {
scanf("%d%d", &n, &k);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
nums[++tt] = a[i];
}
sort(nums + 1, nums + tt + 1);
tt = unique(nums + 1, nums + tt + 1) - nums - 1;
for(int i = 1; i <= n; i++) fnd[i] = find(a[i]);
long long l = 0, r = 1e10;
while(l < r) {
long long mid = l + r >> 1;
if(check(mid)) l = mid + 1;
else r = mid;
}
printf("%lld\n", l);
for(int i = 1; i <= tt; i++) nums[i] = 0;
tt = 0;
}
return 0;
}

浙公网安备 33010602011771号