题目链接
- 题解好抽象……费了好大劲儿才(自认为)勉强看懂代码,也不打算自己重写了,下附自己写的log平方(我的实现可能是立方)代码和题解代码
- 首先,朴素模拟的复杂度似乎不对,但通过打表找规律可以发现一个优良性质——对于每次合并完的结果,至多有两个长度相同的互异段,于是朴素模拟的复杂度的确是log方的
- 其次,继续找规律可以发现,对于长度为\(2^j\)的一个分段,它要么是[xsum,(xsum|((1ll<<j)-1))],要么是[(xsum|(1ll<<j)),(xsum|((1ll<<j+1)-1))],xsum为当前所有的x的异或和保留j+1位及更高位得到的结果,而去掉末j+1位可以通过先右移再左移实现
- 动态规划。用下标的0/1区分左端点的第s位是否为1,dp0/dp1区分是段长<\(2^s\)的总和还是<=\(2^s\)的总和
- 感受为什么要用前缀和优化动态规划?因为对于某个特定的段,向比它小(大)的段转移,操作都是类似的、有规律可循的
- s记录了当前已经更新的x的所有段长的总和,这些段长都小于\(2^j\)
- 转移时,新加进来的x的贡献显然是s,但,原来的呢?
- 考虑在朴素转移中,对于当前所有lowbit(x)<\(2^j\)的x,我们用(cur[i].first^(x-lowbit(x)))去掉末j+1位更新区间左端点
- 所以根据当前x第j位是否为1来确定
- 最后还要把长度=2^j的段加入dp1中
- 2的逆元可以用(p+1)/2求(p是奇数)
- 用左闭右开区间
点击查看代码
#include <bits/stdc++.h>
using namespace std;
const int mod=998244353;
const long long maxn=(1ll<<60)-1;
long long a[1000005];
vector<pair<long long,long long> >cur,tmp;
vector<long long>cnt,res;
long long lowbit(long long n)
{
return n&(-n);
}
void split(long long x)
{
tmp.clear();
res.clear();
int n=cur.size();
while(x)
{
if(!n)
{
tmp.push_back(make_pair(x-lowbit(x),x));
res.push_back(1);
}
else
{
for(int i=0;i<n;i++)
{
long long len=max(cur[i].second-cur[i].first,lowbit(x));
long long val=min(cur[i].second-cur[i].first,lowbit(x))%mod;
long long L=(cur[i].first^(x-lowbit(x)))&(maxn-(len-1)),R=L+len;
bool f=false;
for(int j=0;j<tmp.size();j++)
{
if(L==tmp[j].first&&R==tmp[j].second)
{
f=true;
res[j]+=(val*cnt[i]%mod);
res[j]%=mod;
break;
}
}
if(f==false)
{
tmp.push_back(make_pair(L,R));
res.push_back(val*cnt[i]%mod);
}
}
}
x-=lowbit(x);
}
cur=tmp;
cnt=res;
}
long long getlen(long long l,long long r,long long L,long long R)
{
if(l>=L&&r<=R)
{
return (r-l+1)%mod;
}
if(l<=L&&r>=R)
{
return (R-L+1)%mod;
}
if(r<L||l>R)
{
return 0;
}
if(r>=L&&r<=R)
{
return (r-L+1)%mod;
}
return (R-l+1)%mod;
}
int n,m;
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int T;
cin>>T;
while(T--)
{
cin>>n>>m;
cur.clear();
cnt.clear();
for(int i=1;i<=n;i++)
{
cin>>a[i];
a[i]++;
split(a[i]);
}
for(int i=1;i<=m;i++)
{
long long l,r;
cin>>l>>r;
long long ans=0;
for(int j=0;j<cur.size();j++)
{
long long L=cur[j].first,R=cur[j].second;
R--;
ans=ans+getlen(l,r,L,R)*cnt[j]%mod;
ans%=mod;
}
cout<<ans<<endl;
}
}
return 0;
}
点击查看代码
#include<bits/stdc++.h>
using namespace std;
const int p=998244353;
const int inv=p+1>>1;
long long dp0[70][2],dp1[70][2];
int main()
{
ios::sync_with_stdio(false);
cin.tie(0);
int T;
cin>>T;
while(T--)
{
int n,m;
cin>>n>>m;
memset(dp0,0,sizeof(dp0));
memset(dp1,0,sizeof(dp1));
for(int j=0;j<60;j++)
{
dp0[j][0]=dp1[j][0]=1;
}
long long xsum=0;
for(int i=1;i<=n;i++)
{
long long x;
cin>>x;
x++;
xsum^=x;
long long s=0;
for(int j=0;j<60;j++)
{
long long d0=dp0[j][0],d1=dp0[j][1];
dp0[j][x>>j&1]=d0*s%p;
dp0[j][~x>>j&1]=d1*s%p;
d0=dp1[j][0],d1=dp1[j][1];
dp1[j][x>>j&1]=d0*s%p;
dp1[j][~x>>j&1]=d1*s%p;
if(x>>j&1)
{
s+=1ll<<j;
s%=p;
dp1[j][0]=(dp1[j][0]+(1ll<<j)%p*d0)%p;
dp1[j][1]=(dp1[j][1]+(1ll<<j)%p*d1)%p;
}
}
}
while(m--)
{
long long L,R;
cin>>L>>R;
long long res=0;
for(long long s=0,pw=1;s<60;s++,pw=pw*inv%p)
{
long long t=xsum;
xsum>>=s+1;xsum<<=s+1;
long long len1=min(R,xsum|((1ll<<s)-1))-max(L,xsum)+1;
long long len2=min(R,xsum|((1ll<<s+1)-1))-max(L,xsum|(1ll<<s))+1;
len1%=p;
len2%=p;
res+=(len1>0)*(dp1[s][0]-dp0[s][0])*len1%p*pw%p;
res+=(len2>0)*(dp1[s][1]-dp0[s][1])*len2%p*pw%p;
xsum=t;
}
cout<<(res%p+p)%p<<endl;
}
}
return 0;
}