CF1416C XOR Inverse
个人记录第十九篇。
题目描述
给定长度为 \(n\) \((1\le n\le3\times 10^5)\) 的数列 \(\{a_n\}\) \((0\le a_n\le 10^9)\),请求出最小的整数 \(x\) 使 \(\{a_n\oplus x\}\) 的逆序对数最少,其中 \(\oplus\) 是按位异或。
输出最少逆序对数和最小的 \(x\)。
Solution
看到数列还有异或就下意识想到 0/1 trie。
首先把每一个数都插进 0/1 trie 里。由于要求逆序对,所以我们也还要记录当前这个数的下标。由于逆序对这个东西是对所有数都有效的(就是某个数不会只针对某个数产生逆序对),所以如果 trie 上 \(p\) 节点被当前的数经过,那么我们就在 \(p\) 上插入这个数的下标。
现在来看怎么算最少的逆序对数。
由于你是在一棵 trie 树上统计,所以可以考虑在 trie 上树形 dp。显然,对于一个节点 \(p\),其左子树的值小于右子树,所以我们只需要找到左子树里有多少个数的下标大于右子树里的。令这个个数为 \(c_p\)。设 \(f_{i,0/1}\) 表示是否在第 \(i\) 个二进制位进行异或(也就是 \(x\) 的第 \(i\) 个二进制位是否填 \(1\)),那么方程容易得出:
\[f_{i,0}=\sum c_p
\]
\[f_{i,1}=\sum siz_{l}\cdot siz_{r}-c_p
\]
其中 \(siz_i\) 表示 \(i\) 的子树内下标的个数。因为你在异或之后 trie 上所有逆序对都变成了顺序对,所有顺序对都变成了逆序对,所以贡献是 \(siz_l\cdot siz_r-c_p\)。
最后求答案的时候就是遍历每一个二进制位,贪心的去选择更优的的方案。
Code
//蒟蒻一枚 RP++
#include <bits/stdc++.h>
#define Mem(a,b) memset((a),(b),sizeof((a)))
#define eb emplace_back
#define pb push_back
using namespace std;
using i64 = long long;
using uint = unsigned int;
using ui64 = unsigned long long;
using i128 = __int128;
constexpr int N = 3e5 + 15, mod = 998244353;
constexpr int inf = 1e9;
constexpr double eps = 1e-6, PI = acos (-1);
namespace FAST_IO {
#define IOSIZE 1048576
char ibuf[IOSIZE], obuf[IOSIZE], *XXH1 = ibuf, *XXH2 = ibuf, *XXH3 = obuf;
#define getchar() ((XXH1==XXH2)&&(XXH2=(XXH1=ibuf)+fread(ibuf,1,IOSIZE,stdin),XXH1==XXH2)?(EOF):(*XXH1++))
#define putchar(x) ((XXH3==obuf+IOSIZE)&&(fwrite(obuf,XXH3-obuf,1,stdout),XXH3=obuf),*XXH3++=x)
#define isdigit(ch) (ch>47&&ch<58)
#define isspace(ch) (ch<33)
template <typename T> inline void read (T &x) { x = 0; T f = 1;char ch = getchar ();while (!isdigit (ch)) {if (ch == '-') f = -1; ch = getchar ();}while (isdigit (ch)) {x = x * 10 + (ch ^ '0'); ch = getchar ();} x *= f;}
template <> inline void read (double &x) { x = 0; int f = 1;char ch = getchar ();while (!isdigit (ch)) { if (ch == '-') f = -1; ch = getchar ();} while (isdigit (ch)) x = x * 10 + (ch - '0'), ch = getchar ();if (ch == '.') {ch = getchar (); for (double t = 0.1; isdigit (ch); t *= 0.1) x += t * (ch - '0'), ch = getchar ();}x *= f;}
inline void read(char &x) { char ch;while ((ch = getchar()) != EOF && isspace(ch));if (ch != EOF) x = ch; }
inline bool read(char *s) { char ch; while (ch = getchar(), isspace(ch)); if (ch == EOF) return false; while (!isspace(ch)) *s++ = ch, ch = getchar(); *s = '\0'; return true; }
inline bool read(string& s) { s = ""; char ch; while (ch = getchar(), isspace(ch)); if (ch == EOF) return false; while (!isspace(ch)) s += ch, ch = getchar(); return true; }
template <typename T, typename ...Args> inline void read (T &x, Args &...args) {read(x); read(args...);}
template <typename T> inline void write (T x) { if (x < 0) putchar ('-'), x = -x;static short stk[50], top(0);do stk[++ top] = x % 10, x /= 10;while (x);while (top)putchar (stk[top --] | 48);}
inline void write(string x) { for (int i = 0, n = x.size(); i < n; i++) putchar(x[i]); }
inline void write(char *x) { while (*x) putchar(*x++); }
inline void write(const char *x) { while (*x) putchar(*x++); }
inline void write(char x) { putchar(x); }
struct fio { ~fio () {if (XXH3 != obuf) { fwrite(obuf, 1, XXH3 - obuf, stdout);XXH3 = obuf;}}} io;
template <typename T> fio& operator >> (fio &io, T &x) {return read (x), io;}
template <typename T> fio& operator << (fio &io, const T &x) {return write (x), io;}
#define cin io
#define cout io
#define endl '\n'
}using namespace FAST_IO;
bool MS;
int n;
int a[N];
int trie[N * 30][2], idx;
vector <int> ed[N * 30];
i64 dp[50][2];
inline void insert (int x, int id) {
int p = 0;
for (int i = 30; i >= 0; -- i) {
int c = (x >> i) & 1;
if (!trie[p][c]) trie[p][c] = ++ idx;
p = trie[p][c];
ed[p].eb (id);
}
}
void DP (int p, int id) {
int ls = trie[p][0], rs = trie[p][1];
if (!ls && !rs) return ;
if (id < 0) return ;
if (ls) DP (ls, id - 1);
if (rs) DP (rs, id - 1);
int now = 0;
i64 sum = 0;
for (int pos : ed[ls]) {
while (now < ed[rs].size () && pos > ed[rs][now]) ++ now;
sum += now;
}
dp[id][0] += sum;
dp[id][1] += 1ll * ed[ls].size () * ed[rs].size () - sum;
}
bool MT;
int main () {
cin >> n;
for (int i = 1; i <= n; ++ i) {
cin >> a[i];
insert (a[i], i);
}
DP (0, 30);
i64 ans1 = 0, ans2 = 0;
for (int p = 30; p >= 0; -- p) {
if (dp[p][0] < dp[p][1]) {
ans1 += dp[p][0];
} else if (dp[p][0] > dp[p][1]) {
ans1 += dp[p][1];
ans2 |= 1ll << p;
} else ans1 += dp[p][0];
}
cout << ans1 << " " << ans2 << endl;
cerr << "Memory:" << (&MS - &MT) / 1048576.0 << "MB Time:" << clock() << "ms" << endl;
return 0;
}
/*
1. 该开 long long 的地方你开了吗?
2. 你的数组开够了吗?
3. 该取模的地方你取模了吗?
4. 你有没有理解错题意?
5. 想到了再写
*/

浙公网安备 33010602011771号