cf241b friends
给定 \(n\)个整数 \(a_1,a_2...a_n\) ,求两两异或值前 \(k\) 大的和。
\(1\leq n\leq 50000,0\leq k\leq \frac{n(n-1)}{2},0\leq a_i\leq 10^9\)
建立 trie 树 . 考虑二分第 \(k\) 大的值 \(V\) . 每次 check 需要 \(O(n\log a)\) 的时间 .
然后对于每个数,在 trie 树异或求出 \(\leq V\) 的值的个数,因此每个节点要保留 \(cnt[x][i]\) 表示,以 \(x\) 为根的子树中, \(2^i\) 位置上为 \(1\) 的数的个数 .
时间复杂度 : \(O(n\log^2a)\)
空间复杂度 : \(O(n\log^2a)\)
思路不是很难,但是细节稍微有点多 .
还有一个点, 我的二分结构是
int low=0,high=2e9,val=1;
while(low<high){
int mid=1ll*(low+high)/2;
if(check(mid)>=k){
val=max(val,mid);
low=mid+1;
}else high=mid;
}
但是,先计算的是 low+high ,所以此时已经溢出了.
应该这样写 (0ll+low+high) ,就不会溢出了 .
code
#include<bits/stdc++.h>
using namespace std;
char in[100005];
int iiter=0,llen=0;
inline char get(){
if(iiter==llen)llen=fread(in,1,100000,stdin),iiter=0;
if(llen==0)return EOF;
return in[iiter++];
}
inline int rd(){
char ch=get();while(ch<'0'||ch>'9')ch=get();
int res=0;while(ch>='0'&&ch<='9')res=(res<<3)+(res<<1)+ch-'0',ch=get();
return res;
}
inline void pr(long long res){
if(res==0){putchar('0');return;}
static int out[20];int len=0;
while(res)out[len++]=res%10,res/=10;
for(int i=len-1;i>=0;i--)putchar(out[i]+'0');
}
const int N=5e4+10;
const int mod=1e9+7,inv2=(mod+1)/2;
int n,k;
int a[N];
class node{public:int ch[2],cnt[32],num;}ts[N*32];
int cnt=1;
void ins(int val){
int x=1;
for(int i=31;i>=0;i--){
int id=(val&(1<<i))?1:0;
if(!ts[x].ch[id])ts[x].ch[id]=++cnt;
x=ts[x].ch[id];
}
ts[x].num++;
}
void dfs(int x,int tp,int dep){
for(int id=0;id<2;id++){
int to=ts[x].ch[id];
if(!to)continue;
dfs(to,id,dep-1);
ts[x].num+=ts[to].num;
for(int i=0;i<32;i++)ts[x].cnt[i]+=ts[to].cnt[i];
}
if(tp==1)ts[x].cnt[dep]+=ts[x].num;
}
int qry(int val,int tar){
int x=1,res=0,tmp=0;
for(int k=31;k>=0;k--){
int id=(val^tar)>>k&1;
if(!(tar>>k&1))res+=ts[x].ch[id^1]?ts[ts[x].ch[id^1]].num:0;
x=ts[x].ch[id];
}
res+=ts[x].num;
return res;
}
long long get_ans(int val,int tar){
int x=1,tmp=0;int res=0;
for(int k=31;k>=0;k--){
int id=(val^tar)>>k&1;
if(!(tar>>k&1)){
int to=ts[x].ch[id^1];
if(to){
res+=1ll*tmp*ts[to].num%mod;res%=mod;
for(int i=0;i<=k;i++){
res+=1ll*(1<<i)*((val&(1<<i))?(ts[to].num-ts[to].cnt[i]):ts[to].cnt[i])%mod;
res%=mod;
}
}
}
tmp+=(tar>>k&1)<<k;
x=ts[x].ch[id];
}
res+=1ll*tmp*ts[x].num%mod;res%=mod;
return res;
}
int check(int val){
long long res=0;
for(int i=0;i<n;i++){
res+=qry(a[i],val);
if(val==0)res--;
}
return res/2;
}
int main(){
n=rd();k=rd();
for(int i=0;i<n;i++)a[i]=rd();
if(!k){
putchar('0');
return 0;
}
for(int i=0;i<n;i++)ins(a[i]);
dfs(1,-1,32);
int low=0,high=2e9,val=1;
while(low<high){
int mid=1ll*(0ll+low+high)/2ll;
if(check(mid)>=k){
val=max(val,mid);
low=mid+1;
}else high=mid;
}
int ans=0,res=check(val);
for(int i=0;i<n;i++)ans=(ans+get_ans(a[i],val))%mod;
ans=1ll*ans*inv2%mod;
if(res>k)ans-=1ll*val*(res-k)%mod,ans=(ans+mod)%mod;
pr(ans);
return 0;
}
/*inline? ll or int? size? min max?*/

浙公网安备 33010602011771号