CF1322B Present(拆位+双指针)
CF1322B Present(拆位+双指针)
题意
给出一个长度为 \(n\) 的数列 \(a\) 。 计算对任意的 $i < j $ ,\((a_i + a_j)\) 的异或和。
思路
涉及到位运算,考虑每一位二进制位。
考虑第 \(k\) 位二进制位,如果这一位对答案有贡献,那就需要对 \(a_i + a_j\) 分析。因为有进位的影响,所以不能像普通拆位那样将 \(a_i\) 和 \(a_j\) 的第 \(k\) 位单独拎出来。我们取出它们 \(mod ~2^{k + 1}\) 后的值。
当 \(a_i + a_j\) 不产生进位时,结果在 \([2^{k},2^{k + 1} - 1]\) 时有贡献。
产生进位时,在 \([3 * 2^k,2^{k + 2} - 2]\) 时有贡献。
显然这两个区间是不相交的,因此不会有算重复的情况,分别计算满足上述条件的数对个数,如果个数为奇数,则对答案有贡献。
那只要对每一位判断一下是否对答案有贡献,这个题就解决了
现在问题来到如何 快速统计两数之和落在某区间内的数对个数
如果我们假设 \(i,j\) 没有大小关系。那么打乱数组不影响答案(因为 \(i,j\) 都任意取了)。如果设 \(b\) 为 \(a\) 取模后的数组。
不妨将 \(b\) 排序,此时可以用 双指针 统计数对个数
auto calc = [&](int L,int R) {
int ans = 0;
for(int i = n,l = 1,r = 1;i > 0;i --) {
while(l <= n && b[i] + b[l] < L) l ++; //xx[l
while(r <= n && b[i] + b[r] <= R) r ++; // xxlxx]r
ans += r - l - (r > i && l <= i); // erase (i,i)
}
return ans >> 1 & 1;
};
需要注意的是要将 \((i,i)\) 这种不合法情况减去。并且统计完后除二恢复顺序。
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<set>
#include<queue>
#include<map>
#include<string>
#include<random>
#include<iomanip>
#define yes puts("yes");
#define inf 0x3f3f3f3f
#define ll long long
#define linf 0x3f3f3f3f3f3f3f3f
#define ull unsigned long long
#define endl '\n'
#define int long long
#define rep(i,a,n) for(int i = a;i <= n;i++)
#define per(i,n,a) for(int i = n;i >= a;i--)
using namespace std;
mt19937 mrand(random_device{}());
int rnd(int x) { return mrand() % x;}
typedef pair<int,int> PII;
const int MAXN =10 + 2e5 ,mod=1e9 + 7;
void solve()
{
int n; cin >> n;
vector<int> a(n + 1);
rep(i,1,n) cin >> a[i];
vector<int> b(n + 1);
auto calc = [&](int L,int R) {
int ans = 0;
for(int i = n,l = 1,r = 1;i > 0;i --) {
while(l <= n && b[i] + b[l] < L) l ++; //xx[l
while(r <= n && b[i] + b[r] <= R) r ++; // xxlxx]r
ans += r - l - (r > i && l <= i); // erase (i,i)
}
return ans >> 1 & 1;
};
int ans = 0;
rep(i,0,25) {
// [2^k,2^(k + 1) - 1] ^ [2^k + 2^(k + 1),2^(k + 1) + 2^(k + 1) - 2]
// [2^k,2^(k + 1)-1] ^ [2^k * 3, 2^(i + 2) - 2]
rep(j,1,n) b[j] = a[j] & ((1 << i + 1) - 1);
sort(b.begin() + 1,b.end());
int t = calc((1 << i), (1 << i + 1) - 1) ^ calc(3 * (1 << i), (1 << i + 2) - 2);
ans |= (t << i);
}
cout << ans;
}
signed main()
{
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
//int T;cin>>T;
//while(T--)
solve();
return 0;
}