习题:选数(tri树)
题目
小 s 要在 \([0,2^n)\) 中选一个整数 \(x\),接着把 \(x\) 依次异或 \(m\) 个整数 \(a_1 \sim a_m\),他想要最大化 \(x\) 的最终取值。
然而问题并没有这么简单,小 r 想要干预小 s 的选择。
在小 s 选出 \(x\) 后,小 r 会选择恰好一个时刻(刚选完数时、异或一些数后或是最后),将 \(x\) 变为 \((\lfloor \frac{2x}{2^n} \rfloor +2x) \bmod 2^n\)。
小 s 想使 \(x\) 最后尽量大,而小 r 会使 \(x\) 最后尽量小。
小 s 想要求出 \(x\) 最后的最大值,以及得到最大值的初值数量。
然而小 s 太笨了不会算,请你帮帮他。
对于 \(20\%\) 的数据,\(n \leq 10,m \leq 100\);
对于 \(40\%\) 的数据,\(n \leq 10,m \leq 1000\);
对于另外 \(20\%\) 的数据,\(n \leq 30,m \leq 10\);
对于 \(100\%\) 的数据,\(n \leq 30,m \leq 100000,0 \leq a_i<2^n\)。
思路
设\(f(x)=(\lfloor\frac{2x}{2^n}+2x\rfloor)\% 2^n\)
可以发现,这就是对x进行一次向左的循环
考虑一个数\(x\),他会进行的变化为\(f(x\oplus pre_i)\oplus(pre_n\oplus pre_i)\)
其中\(pre_i\)表示前缀异或
有了对\(f(x)\)的深入理解,发现\(f\)可以拆开
\(f(x)\oplus f(pre_i)\oplus(pre_n\oplus pre_i)\)
可以发现\(f(x)\)对先手而言没有用
可以转换为\(x\oplus f(pre_i)\oplus(pre_n\oplus pre_i)\)
之后考虑\(f(pre_i)\oplus (pre_n\oplus pre_i)\)进行建tri树
对于每一个节点先手的选择就是\((0,1)\),后手的选择就是左儿子和右儿子
如果有两个儿子,那么不管怎么样都为0,所以需要两边都去搜索
只有一个儿子,就可以构造成为1
代码
#include<iostream>
#include<cstdio>
using namespace std;
#define pii pair<int,int>
#define x first
#define y second
namespace tri
{
int n;
int cnt=1;
struct node
{
int ch[2];
}tre[30*100000+5];
void insert(int s)
{
int now=1;
for(int i=n-1;i>=0;i--)
{
int t=(s>>i)&1;
if(tre[now].ch[t]==0)
tre[now].ch[t]=++cnt;
now=tre[now].ch[t];
}
}
pii ask(int now,int ans)
{
//cout<<now<<' '<<tre[now].ch[0]<<'\n';
// cout<<now<<' '<<tre[now].ch[1]<<'\n';
if(tre[now].ch[0]==0&&tre[now].ch[1]==0)
return make_pair(ans,1);
if(tre[now].ch[0]!=0&&tre[now].ch[1]!=0)
{
pii t1=ask(tre[now].ch[0],ans<<1);
pii t2=ask(tre[now].ch[1],ans<<1);
if(t1.x==t2.x)
return make_pair(t1.x,t1.y+t2.y);
else
return max(t1,t2);
}
else
{
if(tre[now].ch[0])
return ask(tre[now].ch[0],ans<<1|1);
else
return ask(tre[now].ch[1],ans<<1|1);
}
}
}
using namespace tri;
int m;
int a[100005];
int s[100005];
int calc(int s)
{
return (2*s/(1<<n)+2*s)%(1<<n);
}
int main()
{
cin>>n>>m;
for(int i=1;i<=m;i++)
{
cin>>a[i];
s[i]=s[i-1]^a[i];
}
/*for(int i=1;i<=m;i++)
cout<<s[i]<<' ';
cout<<'\n';*/
for(int i=0;i<=m;i++)
{
insert(calc(s[i])^(s[m]^s[i]));
//cout<<(calc(s[i])<<'\n';
}
pii t=ask(1,0);
cout<<t.x<<'\n'<<t.y;
return 0;
}

浙公网安备 33010602011771号