BZOJ 3509: [CodeChef] COUNTARI
Description
给定一个长度为\(N\)的数组\(A[]\),求有多少对\(i, j, k(1\leqslant i<j<k \leqslant N)\)满足\(A[k]-A[j]=A[j]-A[i]\)。
Solution
分块FFT。
每个暴力求需要\(n\)次FFT。
分块的话,FFT求块与块之间的,块内的暴力求。
复杂度\(O(n\sqrt {n}logn)\)
BZOJ上险些T了qwq...
Code
/**************************************************************
Problem: 3509
User: BeiYu
Language: C++
Result: Accepted
Time:37220 ms
Memory:7936 kb
****************************************************************/
#include <bits/stdc++.h>
using namespace std;
#define debug(a) cout<<#a<<"="<<a<<" "
#define mpr make_pair
#define r first
#define i second
typedef pair< double,double > pr;
typedef long long LL;
const int N = 1e5+50;
const int B = 2500;
const double Pi = M_PI;
pr operator + (const pr &a,const pr &b) { return mpr(a.r+b.r,a.i+b.i); }
pr operator - (const pr &a,const pr &b) { return mpr(a.r-b.r,a.i-b.i); }
pr operator * (const pr &a,const pr &b) { return mpr(a.r*b.r-a.i*b.i,a.r*b.i+a.i*b.r); }
int NN=65536;
void Rev(pr a[]) {
for(int i=0,j=0;i<NN;i++) {
if(i<j) swap(a[i],a[j]);
for(int k=NN>>1;(j^=k)<k;k>>=1);
}
}
void DFT(pr a[],int r=1) {
Rev(a);
for(int i=1;i<=NN;i<<=1) {
pr wi=mpr(cos(2.0*Pi/i),r*sin(2.0*Pi/i));
for(int j=0;j<NN;j+=i) {
pr w=mpr(1.0,0.0);
for(int k=j;k<j+i/2;k++) {
pr x=a[k],y=w*a[k+i/2];
a[k]=x+y,a[k+i/2]=x-y;
w=w*wi;
}
}
}if(r==-1) for(int i=0;i<NN;i++) a[i].r/=NN;
}
void FFT(pr a[],pr b[],pr c[]) {
DFT(a,1),DFT(b,1);
for(int i=0;i<NN;i++) c[i]=a[i]*b[i];
DFT(c,-1);
}
inline int in(int x=0,char ch=getchar()) { while(ch>'9' || ch<'0') ch=getchar();
while(ch>='0' && ch<='9') x=x*10+ch-'0',ch=getchar();return x; }
int n;LL ans;
int a[N],b[N];
int bf[N],bd[N],tp[N];
pr x1[N],x2[N],x3[N];
int main() {
n=in();
for(int i=0;i<n;i++) a[i]=in();
for(int i=0;i<n;i++) bd[a[i]]++;
for(int j=0;j<n;j+=B) {
for(int i=j;i<n && i<j+B;i++) bd[a[i]]--;
//before and behind
memset(x1,0,sizeof(x1)),memset(x2,0,sizeof(x2));
for(int i=0;i<NN/2;i++) x1[i]=mpr(bf[i],0),x2[i]=mpr(bd[i],0);
FFT(x1,x2,x3);
for(int i=j;i<n && i<j+B;i++) tp[a[i]]++;
for(int i=0;i<NN;i++) if(!(i&1)) ans+=(LL)tp[i/2]*(int)(x3[i].r+0.5);
for(int i=j;i<n && i<j+B;i++) tp[a[i]]=0;
for(int p=j;p<n && p<j+B;p++) for(int q=p+1;q<n && q<j+B;q++) {
if(2*a[p]-a[q]>=0) ans+=bf[2*a[p]-a[q]];
if(2*a[q]-a[p]>=0) ans+=bd[2*a[q]-a[p]];
}
for(int p=j;p<n && p<j+B;p++) {
for(int q=j;q<p;q++) {
if(2*a[q]-a[p]>=0) ans+=tp[2*a[q]-a[p]];
tp[a[q]]++;
}
for(int q=j;q<p;q++) tp[a[q]]--;
}
for(int i=j;i<n && i<j+B;i++) bf[a[i]]++;
}cout<<ans<<endl;
return 0;
}

浙公网安备 33010602011771号