ARC127 Sum of Min of Xor

可以发现 \(a_i \bigoplus b_i \bigoplus a_j \bigoplus b_j\)\(1\) 的位置,是 \(a_i \bigoplus a_j\)\(b_i \bigoplus b_j\) 不同的位置。
\(c_i = a_i \bigoplus b_i\),考虑根据上个性质分治,每次吧在当前位数 \(dep\)\(0\)\(c_i\) 放入一个集合中,\(1\) 的放在另一集合,那么这两个集合之间 \(a_i \bigoplus a_j\)\(b_i \bigoplus b_j\) 不同的最高位就是 \(dep\),判断他们的大小关系只需要考虑在 \(dep\) 位上的大小关系即可。
具体地说,先按 \(c_i\) 排序,然后对于位数,求出该位 \(0,1\) 的分界 \(mid\), 即 \([l,mid-1]\)\(0\), \([mid + 1, r]\)\(1\)。然后呢,对于每个 \(i\) 都知道 \(a_i\)\(b_i\)。 就可以将其分为四种情况,并按照没有 \(min\) 的方法 \((cnt_0 \times cnt_1)\) 计算贡献。
大致就是考虑 \(a_i\)\(b_i\) 的值(因为其这位相同所以只能都为 \(0\) 或都为 \(1\)),再考虑 \(a_j\)\(b_j\) 这位的情况(哪个这位为 \(1\)) 分类讨论一下即可。
那么我们分治的时候只需要考虑两个不同的集合之间的答案,单个集合的答案最终统计一下就行啦。
Tips:
对大小关系的判断可以考虑异或!
涉及 \(a_i, a_j, b_i, b_j\) 的题目可以考虑讲 \(a_i\)\(b_i\) 合并。
这种 \((i,j)\) 问题往往可以分治计算。
\(min\) 不好处理可以尝试拆掉。
代码

#include<bits/stdc++.h>
#define RG register
#define LL long long
#define U(x, y, z) for(RG int x = y; x <= z; ++x)
#define D(x, y, z) for(RG int x = y; x >= z; --x)
#define update(x, y) (x = x + y >= mod ? x + y - mod : x + y)
using namespace std;
const int mod = 998244353;
namespace FastIO {
#define il inline
const int iL = 1 << 25;
char ibuf[iL], *iS = ibuf + iL, *iT = ibuf + iL;
#define GC() (iS == iT) ? \
  (iT = (iS = ibuf) + fread(ibuf, 1, iL, stdin), (iS == iT) ? EOF : *iS++) : *iS++
void read(){}
template<typename _Tp, typename... _Tps>
void read(_Tp &x, _Tps &...Ar) {
    x = 0; char ch = GC(); bool flg = 0;
    for (; !isdigit(ch); ch = GC()) flg |= (ch == '-');
    for (; isdigit(ch); ch = GC()) x = (x << 1) + (x << 3) + (ch ^ 48);
    if (flg) x = -x;
    read(Ar...);    
}
char Out[iL], *iter = Out;
#define Flush() fwrite(Out, 1, iter - Out, stdout); iter = Out
template <class T>il void write(T x, char LastChar = '\n') {
    int c[35], len = 0;
    if (x < 0) {*iter++ = '-'; x = -x;}
    do {c[++len] = x % 10; x /= 10;} while (x);
    while (len) *iter++ = c[len--] + '0';
    *iter++ = LastChar; Flush();
}
template <typename T> inline void writeln(T n){write(n, '\n');}
template <typename T> inline void writesp(T n){write(n, ' ');}
inline char Getchar(){ char ch; for (ch = GC(); !isalpha(ch); ch = GC()); return ch;}
inline void readstr(string &s) { s = ""; static char c = GC(); while (isspace(c)) c = GC(); while (!isspace(c)) s = s + c, c = GC();}
}
using namespace FastIO;
struct modint{
    int x;
    modint(int o=0){x=o;}
    modint &operator = (int o){return x=o,*this;}
    modint &operator +=(modint o){return x=x+o.x>=mod?x+o.x-mod:x+o.x,*this;}
    modint &operator -=(modint o){return x=x-o.x<0?x-o.x+mod:x-o.x,*this;}
    modint &operator *=(modint o){return x=1ll*x*o.x%mod,*this;}
    modint &operator ^=(int b){
        if(b<0)return x=0,*this;
        b%=mod-1;
        modint a=*this,c=1;
        for(;b;b>>=1,a*=a)if(b&1)c*=a;
        return x=c.x,*this;
    }
    modint &operator /=(modint o){return *this *=o^=mod-2;}
    modint &operator +=(int o){return x=x+o>=mod?x+o-mod:x+o,*this;}
    modint &operator -=(int o){return x=x-o<0?x-o+mod:x-o,*this;}
    modint &operator *=(int o){return x=1ll*x*o%mod,*this;}
    modint &operator /=(int o){return *this *= ((modint(o))^=mod-2);}
    template<class I>friend modint operator +(modint a,I b){return a+=b;}
    template<class I>friend modint operator -(modint a,I b){return a-=b;}
    template<class I>friend modint operator *(modint a,I b){return a*=b;}
    template<class I>friend modint operator /(modint a,I b){return a/=b;}
    friend modint operator ^(modint a,int b){return a^=b;}
    friend bool operator ==(modint a,int b){return a.x==b;}
    friend bool operator !=(modint a,int b){return a.x!=b;}
    bool operator ! () {return !x;}
    modint operator - () {return x?mod-x:0;}
};
template <typename T> inline void chkmin(T &x, T y){x = x < y ? x : y;}
template <typename T> inline void chkmax(T &x, T y){x = x > y ? x : y;}
template <typename T> inline T Min(T x, T y){return x < y ? x : y;}
template <typename T> inline T Max(T x, T y){return x > y ? x : y;}
inline void FO(string s){freopen((s + ".in").c_str(), "r", stdin); freopen((s + ".out").c_str(), "w", stdout);}

const int N = 2.5e5 + 10;
int n, m;
int a[N], b[N], c[N], id[N];
LL cnt[2], ans;

inline void solve(int l, int r, int dep) {
	// cerr << l << " " << r << ' ' << dep << "\n";
	if (l >= r) return ;
	if (dep < 0) {
		// cerr << l << " " << r << " " << c[id[l]] << " " << c[id[r]] << "\n";
		U(i, 0, 17) {
			cnt[0] = cnt[1] = 0;
			U(j, l, r)
				cnt[(a[id[j]] >> i) & 1]++;
			ans += cnt[0] * cnt[1] * (1 << i);
		}
		cerr << ans << "\n";
		return ;
	}
	int mid = l;
	while (mid <= r && !((c[id[mid]] >> dep) & 1)) mid++;
	if (mid <= r && mid > l)  U(i, 0, 17) {
		cnt[0] = cnt[1] = 0;
		U(j, l, mid - 1)
			if (!(a[id[j]] >> dep & 1))
				cnt[b[id[j]] >> i & 1]++;
		U(j, mid, r) 
			if (a[id[j]] >> dep & 1) 
				ans += cnt[(b[id[j]] >> i & 1) ^ 1] << i;

		cnt[0] = cnt[1] = 0;
		U(j, l, mid - 1)
			if (!(a[id[j]] >> dep & 1))
				cnt[a[id[j]] >> i & 1]++;
		U(j, mid, r)
			if (b[id[j]] >> dep & 1)
				ans += cnt[(a[id[j]] >> i & 1) ^ 1] << i;

		cnt[0] = cnt[1] = 0;
		U(j, l, mid - 1)
			if (a[id[j]] >> dep & 1)
				cnt[b[id[j]] >> i & 1]++;
		U(j, mid, r)
			if (b[id[j]] >> dep & 1)
				ans += cnt[(b[id[j]] >> i & 1) ^ 1] << i;
		cnt[0] = cnt[1] = 0;
		U(j, l, mid - 1)
			if (a[id[j]] >> dep & 1)
				cnt[a[id[j]] >> i & 1]++;
		U(j, mid, r)
			if (a[id[j]] >> dep & 1)
				ans += cnt[(a[id[j]] >> i & 1) ^ 1] << i;
		// if (dep == 1) cerr << ans << "\n";
	}
	cerr << "!!! " << dep << " " << l << " " << mid << " " << r << " " << ans << "\n";
	solve(l, mid - 1, dep - 1), solve(mid, r, dep - 1);
}

int main(){
	//FO("");
	read(n);
	U(i, 1, n) read(a[i]);
	U(i, 1, n) read(b[i]);
	U(i, 1, n) {
		c[i] = a[i] ^ b[i];
		id[i] = i;
	}

	sort(id + 1, id + n + 1, [&](int x, int y){return c[x] < c[y];});
	solve(1, n, 17);
	writeln(ans);
	return 0;
}
posted @ 2022-11-24 22:46  Southern_Way  阅读(22)  评论(0)    收藏  举报