Furik and Rubik and Sub Array题解
Description
给定一个长度为 \(N\) 只含正数的数组 \(a_i\) 求所有的连续的区间内数的总和的种类数,设数组的总和为 \(S\) 。
\(N\cdot S\leq 4\cdot10^{10}\)
Solution
因为数组里面都是整数,所以 \(N\leq 2\cdot 10^5\) 。
所以直接暴力 \(N^2\) 找,要么爆时间(指 \(2\cdot 10^4 \leq N \leq 2 \cdot 10^5\) ),要么爆空间(指 \(1 \leq N \leq 2 \cdot 10^3\) )
所以,我们就分三个类型:
1.(\(1 \leq N \leq 2 \cdot 10^3\) )
直接开什么 \(map\) 或者 \(set\) 乱搞就行,空间什么的 \(STL\) 根本没在怕。。
时间复杂度 \(O(n^2logn)\) 。
2.(\(2\cdot 10^3 \leq N \leq 2 \cdot 10^4\) )
此时的 \(S\) 最大也就 \(2\cdot10^6\) ,所以开一个 \(vis\) 硬上。。
时间复杂度 \(O(N^2)\) 。
3.(\(2\cdot 10^4 \leq N \leq 2 \cdot 10^5\) )
这时候时间这个硬伤不行了,可以用 FTT/NTT 进一步优化计算。
但我们要先想怎么把 \(N^2\) 的遍历改成 \(N^2\) 的多项式乘法。
我们求的是一段连续的区间,所以用一个数当作多项式里面的指数肯定不能达到要求。
所以我们可以考虑把所有前缀和存进去。
大概就是一正一负:
\(f_n=\sum_{i=1}^{n}x^{sum_i}\)
\(g_n=\sum_{i=1}^{n}x^{-sum_i}\)
这两个卷起来就行。
注意:多项式里指数不能是负数,所以可以集体再加一个 \(sum_n\) ,然后找的时候整体右移就行了。
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=5e6+10;
typedef ll arr[N];
const ll mod=998244353;
const ll inv3=332748118;
ll n,m,sum[N],nm,inv,lim=1,fre,id[N],Ans;
arr a,b,ans;
map<ll,bool > mp;
int vis[N];
inline ll inc(ll x,ll y){return x+y>=mod?x+y-mod:x+y;}
inline ll dec(ll x,ll y){return x-y<0?x-y+mod:x-y;}
inline ll read(){
ll s=0,w=1;
char ch=getchar();
while(ch<'0'||ch>'9'){if(ch=='-')w=-1;ch=getchar();}
while(ch>='0'&&ch<='9') s=s*10+ch-'0',ch=getchar();
return s*w;
}
inline ll ksm(ll a,ll b){
ll tmp=1;
while(b){
if(b&1)tmp=(tmp*a)%mod;
b>>=1,a=(a*a)%mod;
}
return tmp;
}
inline void Never_Tell_TLE(ll* NTT,ll sign){
for(ll i=0;i<=lim;++i)if(i<id[i]){
ll NTt=NTT[i];
NTT[i]=NTT[id[i]];
NTT[id[i]]=NTt;
}
for(ll mid=1;mid<lim;mid<<=1){
ll Unit_root;
if(sign==1)Unit_root=ksm(3,(mod-1)/(mid<<1));
else Unit_root=ksm(inv3,(mod-1)/(mid<<1));
for(ll R=mid<<1,r=0;r<lim;r+=R){
ll pw=1;
for(ll l=0;l<mid;++l,pw=(pw*Unit_root)%mod){
ll butt=NTT[l+r],rfly=(pw*NTT[l+r+mid])%mod;
NTT[l+r]=inc(butt,rfly);
NTT[l+r+mid]=dec(butt,rfly);
}
}
}
if(sign==-1)for(ll i=0;i<=lim;++i)NTT[i]=(NTT[i]*inv)%mod;
}
int main(){
n=read();
for(int i=1;i<=n;++i){
sum[i]=read()+sum[i-1];
}
if(n<=4000){
for(int i=1;i<=n;++i){
for(int j=i;j<=n;++j){
if(!mp[sum[j]-sum[i-1]]){
++Ans;
mp[sum[j]-sum[i-1]]=1;
}
}
}
printf("%lld\n",Ans-1);
return 0;
}
if(n<=20000){
for(int i=1;i<=n;++i){
for(int j=i;j<=n;++j){
++vis[sum[j]-sum[i-1]];
}
}
for(int i=1;i<=2e6;++i){
if(vis[i])++Ans;
}
printf("%lld\n",Ans-1);
return 0;
}
nm=sum[n]<<1;
for(int i=0;i<=n;++i){
a[sum[n]+sum[i]]=1;
b[sum[n]-sum[i]]=1;
}
for(;lim<=(nm<<1);lim<<=1)++fre;
inv=ksm(lim,mod-2);
for(int i=0;i<lim;++i)id[i]=(id[i>>1]>>1)|((i&1)<<(fre-1));
Never_Tell_TLE(a,1);
Never_Tell_TLE(b,1);
for(int i=0;i<=lim;++i)ans[i]=(a[i]*b[i])%mod;
Never_Tell_TLE(ans,-1);
for(int i=nm+1;i<=(nm<<1);++i)if(ans[i])++Ans;
printf("%lld\n",Ans-1);
return 0;
}
(难得考场过题呀)

浙公网安备 33010602011771号