cf241b friends

给定 \(n\)个整数 \(a_1,a_2...a_n\) ,求两两异或值前 \(k\) 大的和。

\(1\leq n\leq 50000,0\leq k\leq \frac{n(n-1)}{2},0\leq a_i\leq 10^9\)

建立 trie 树 . 考虑二分第 \(k\) 大的值 \(V\) . 每次 check 需要 \(O(n\log a)\) 的时间 .

然后对于每个数,在 trie 树异或求出 \(\leq V\) 的值的个数,因此每个节点要保留 \(cnt[x][i]\) 表示,以 \(x\) 为根的子树中, \(2^i\) 位置上为 \(1\) 的数的个数 .

时间复杂度 : \(O(n\log^2a)\)

空间复杂度 : \(O(n\log^2a)\)

思路不是很难,但是细节稍微有点多 .

还有一个点, 我的二分结构是

int low=0,high=2e9,val=1;
while(low<high){
	int mid=1ll*(low+high)/2;
	if(check(mid)>=k){
		val=max(val,mid);
		low=mid+1;
	}else high=mid;
}

但是,先计算的是 low+high ,所以此时已经溢出了.

应该这样写 (0ll+low+high) ,就不会溢出了 .

code

#include<bits/stdc++.h>
using namespace std;
char in[100005];
int iiter=0,llen=0;
inline char get(){
	if(iiter==llen)llen=fread(in,1,100000,stdin),iiter=0;
	if(llen==0)return EOF;
	return in[iiter++];
}
inline int rd(){
	char ch=get();while(ch<'0'||ch>'9')ch=get();
	int res=0;while(ch>='0'&&ch<='9')res=(res<<3)+(res<<1)+ch-'0',ch=get();
	return res;
}
inline void pr(long long res){
	if(res==0){putchar('0');return;}
	static int out[20];int len=0;
	while(res)out[len++]=res%10,res/=10;
	for(int i=len-1;i>=0;i--)putchar(out[i]+'0');
}
const int N=5e4+10;
const int mod=1e9+7,inv2=(mod+1)/2;
int n,k;
int a[N];
class node{public:int ch[2],cnt[32],num;}ts[N*32];
int cnt=1;
void ins(int val){
	int x=1;
	for(int i=31;i>=0;i--){
		int id=(val&(1<<i))?1:0;
		if(!ts[x].ch[id])ts[x].ch[id]=++cnt;
		x=ts[x].ch[id];
	}
	ts[x].num++;
}
void dfs(int x,int tp,int dep){
	for(int id=0;id<2;id++){
		int to=ts[x].ch[id];
		if(!to)continue;
		dfs(to,id,dep-1);
		ts[x].num+=ts[to].num;
		for(int i=0;i<32;i++)ts[x].cnt[i]+=ts[to].cnt[i];
	}
	if(tp==1)ts[x].cnt[dep]+=ts[x].num;
}
int qry(int val,int tar){
	int x=1,res=0,tmp=0;
	for(int k=31;k>=0;k--){
		int id=(val^tar)>>k&1;
		if(!(tar>>k&1))res+=ts[x].ch[id^1]?ts[ts[x].ch[id^1]].num:0;
		x=ts[x].ch[id];
	}
	res+=ts[x].num;
	return res;
}
long long get_ans(int val,int tar){
	int x=1,tmp=0;int res=0;
	for(int k=31;k>=0;k--){
		int id=(val^tar)>>k&1;
		if(!(tar>>k&1)){
			int to=ts[x].ch[id^1];
			if(to){
				res+=1ll*tmp*ts[to].num%mod;res%=mod;
				for(int i=0;i<=k;i++){
					res+=1ll*(1<<i)*((val&(1<<i))?(ts[to].num-ts[to].cnt[i]):ts[to].cnt[i])%mod;
					res%=mod;
				}
			}
		}
		tmp+=(tar>>k&1)<<k;
		x=ts[x].ch[id];
	}
	res+=1ll*tmp*ts[x].num%mod;res%=mod;
	return res;
}
int check(int val){
	long long res=0;
	for(int i=0;i<n;i++){
		res+=qry(a[i],val);
		if(val==0)res--;
	}
	return res/2;
}
int main(){
	n=rd();k=rd();
	for(int i=0;i<n;i++)a[i]=rd();
	if(!k){
		putchar('0');
		return 0;
	}
	for(int i=0;i<n;i++)ins(a[i]);
	dfs(1,-1,32);
	int low=0,high=2e9,val=1;
	while(low<high){
		int mid=1ll*(0ll+low+high)/2ll;
		if(check(mid)>=k){
			val=max(val,mid);
			low=mid+1;
		}else high=mid;
	}
	int ans=0,res=check(val);
	for(int i=0;i<n;i++)ans=(ans+get_ans(a[i],val))%mod;
	ans=1ll*ans*inv2%mod;
	if(res>k)ans-=1ll*val*(res-k)%mod,ans=(ans+mod)%mod;
	pr(ans);
	return 0;
}
/*inline? ll or int? size? min max?*/
posted @ 2022-02-02 16:30  xyangh  阅读(12)  评论(0)    收藏  举报