ARC194C - Cost to Flip 题解
首先把所有的物品分类,可以得到四种不同的种类:
- \(0 \to 0\),不用管
- \(0 \to 1\),需要变为 \(1\)
- \(1 \to 0\),需要变为 \(0\)
- \(1 \to 1\),有两种情况:
- 不改变
- \(1 \to 0 \to 1\)
因此,考虑一个显然的贪心,从大到小先翻转需要 \(1 \to 0\) 的比特,再从小到大翻转需要 \(0 \to 1\) 的比特,但是 \(1 \to 1\) 怎么办呢?
我们发现我们可以贪心的从大到小逐渐插入 \(1 \to 0 \to 1\) 的部分,也就是说如果有一个 \(1 \to 1\) 需要被翻转两次,只需要把这个比特同时加入两个组 \(1 \to 0\) 和 \(0 \to 1\) 内即可。
通过前缀和操作不难想到线性做法。
参考代码(可参照注释)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<ll, ll> PII;
const int N = 2e5 + 10;
int n, n1, n2, n3;
struct Node {
ll a, b, c;
}bit[N];
ll zero_to_one[N], one_to_zero[N], is_one[N];
ll pre1[N], pre2[N];
bool cmp(Node x, Node y) {
if (x.c == y.c) return x.a < y.a;
return x.c < y.c;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
cin >> n;
for (int i = 1; i <= n; i ++ ) cin >> bit[i].a;
for (int i = 1; i <= n; i ++ ) cin >> bit[i].b;
for (int i = 1; i <= n; i ++ ) cin >> bit[i].c;
for (int i = 1; i <= n; i ++ ) {
// 1->0
if (bit[i].b < bit[i].a) one_to_zero[ ++ n1] = bit[i].c;
// 0->1
if (bit[i].a < bit[i].b) zero_to_one[ ++ n2] = bit[i].c;
// 1->1
if (bit[i].a && bit[i].b) is_one[ ++ n3] = bit[i].c;
}
// 1->0
sort(one_to_zero + 1, one_to_zero + n1 + 1);
// 0->1
sort(zero_to_one + 1, zero_to_one + n2 + 1);
// 1->0->1, we should first choose bigger one
sort(is_one + 1, is_one + n3 + 1);
for (int i = 1; i <= n1; i ++ ) pre1[i] = pre1[i - 1] + one_to_zero[i];
for (int i = 1; i <= n2; i ++ ) pre2[i] = pre2[i - 1] + zero_to_one[i];
// res means probable-answer, ans means real-answer, sum_one_one means the sum of is_one
ll res = 0, ans = 0, sum_one_one = 0;
// if we don't make 1->0->1, any 1 will be add for each time
for (int i = 1; i <= n3; i ++ ) sum_one_one += is_one[i];
res += (n1 + n2) * sum_one_one;
// we first make 1->0, bigger one should be flipped first
// caution! don't add the biggest one
for (int i = 1; i < n1; i ++ ) res += pre1[i];
// then make 0->1, smaller one should be flipped first
for (int i = 1; i <= n2; i ++ ) res += pre2[i];
ans = res;
// now we consider trying to make some 1->0->1 from 1 to 0 and check if it has lower answer
// let l be a point of zero_to_one, r be a point of one_to_zero
int l = n1, r = n2;
// if we choose c[i] which is 1->0->1, we should do as what we do before to make 1->0 and 0->1
for (int i = n3; i; i -- ) {
// the smaller elements in one_to_zero will be add again
while (l && one_to_zero[l] > is_one[i]) l -- ;
// the smaller elements in zero_to_one will be add again
while (r && zero_to_one[r] > is_one[i]) r -- ;
// remove is_one[i]
sum_one_one -= is_one[i];
// is_one[i] should be added n1+n2+2*(n3-i) times, but now n1-l+n2-r+1+2*(n3-i) times
res -= is_one[i] * (l + r - 1);
// rest is_one will be added 2 more times
res += sum_one_one * 2;
// smaller 1->0 will be added 1 more time, smaller 0->1 will be added 1 more time
res += pre1[l] + pre2[r];
// remember update ans
ans = min(ans, res);
}
cout << ans << "\n";
return 0;
}