BZOJ 3513: [MUTC2013]idiots FFT
Description
给定n个长度分别为a_i的木棒,问随机选择3个木棒能够拼成三角形的概率。
Input
第一行T(T<=100),表示数据组数。
接下来若干行描述T组数据,每组数据第一行是n,接下来一行有n个数表示a_i。
3≤N≤10^5,1≤a_i≤10^5
Output
T行,每行一个整数,四舍五入保留7位小数。
题解:
三角形的三条边要满足最小边与次小边之和要小于最长边之和
令 $f_{i}$ 表示两边之和为 $i$ 的数量.
那么合法的三角形数量应为 $\sum_{i=1}^{Max}f_{i}\times g_{i-1}$ ($g_{i}$ 表示长度小于等于 $i$ 的数量)
然而这样做其实十分麻烦,因为 $g_{i-1}$ 中与 $f_{i}$ 中是会有重复元素的
我们变一下,令 $g_{i}$ 表示长度大于等于 $i$ 的数量
那么不合法的情况为 $\sum_{i=1}^{Max}f_{i}\times g_{i}$,可以用总数量减掉不合法数量来求合法数量
构造生成函数 $A=\sum_{i=1}^{Max}a_{i}x^i$, $a_{i}$ 表示长度为 $i$ 的边有多少个
那么 $f=A^2$ 就是两边结合的情况,用 $FFT$ 来加速
要注意当 $i$ 为偶数时,相同的边也会贡献一次,所以要先减掉这些相同边
然后发现我们这么结合时有序的,而实际上边应该是无序的,所以还需要 $/2$
得到正确的 $f$ 后一次枚举每一个 $i$,与 $g_{i}$ 结合即可
#include<bits/stdc++.h>
#define setIO(s) freopen(s".in","r",stdin)
#define maxn 400003
#define ll long long
using namespace std;
namespace IO
{
inline int read()
{
int ans=0;
char ch=getchar();
while(!isdigit(ch))ch=getchar();
while(isdigit(ch))ans=(ans<<3)+(ans<<1)+(ch^48),ch=getchar();
return ans;
}
};
const double pi=acos(-1.0);
struct cpx
{
double x,y;
cpx(double a=0,double b=0){ x=a,y=b; }
cpx operator+(const cpx b) { return cpx(x+b.x, y+b.y); }
cpx operator-(const cpx b) { return cpx(x-b.x, y-b.y); }
cpx operator*(const cpx b) { return cpx(x*b.x-y*b.y,x*b.y+y*b.x); }
}A[maxn],B[maxn];
inline void FFT(cpx *a,int n,int flag)
{
for(int i=0,k=0;i<n;++i)
{
if(i>k) swap(a[i], a[k]);
for(int j=(n>>1);(k^=j)<j;j>>=1);
}
for(int mid=1;mid<n;mid<<=1)
{
cpx wn(cos(pi/mid), flag*sin(pi/mid)),x,y;
for(int i=0;i<n;i+=(mid<<1))
{
cpx w(1,0);
for(int j=0;j<mid;++j)
{
x=a[i+j],y=w*a[i+j+mid];
a[i+j]=x+y, a[i+j+mid]=x-y;
w=w*wn;
}
}
}
if(flag==-1) for(int i=0;i<n;++i) a[i].x/=(double)n;
}
int f[maxn], arr[maxn], g[maxn];
ll answer[maxn];
inline void solve()
{
int n,Max=0,len;
n=IO::read();
for(int i=1;i<=n;++i) arr[i]=IO::read(), ++f[arr[i]], Max=max(Max, arr[i]);
for(int i=Max;i>=1;--i) g[i]=f[i]+g[i+1];
for(int i=1;i<=Max;++i) A[i].x=(double)f[i];
for(len=1;len<=(Max<<1);len<<=1);
FFT(A,len,1);
for(int i=0;i<len;++i) A[i]=A[i]*A[i];
FFT(A,len,-1);
for(int i=1;i<len;++i)
{
answer[i]=(ll)(A[i].x+0.5);
if(!answer[i]) continue;
if(i%2==0) answer[i]-=f[i>>1];
answer[i]>>=1;
}
ll up,down;
up=down=(ll)n*(n-1)*(n-2)/6;
for(int i=0;i<=len;++i) up-=answer[i]*(ll)g[i];
printf("%.7f\n",(double)up/(double)down);
memset(A,0,sizeof(A)), memset(f,0,sizeof(f)), memset(g,0,sizeof(g));
}
int main()
{
// setIO("input");
int T;
T=IO::read();
while(T--) solve();
return 0;
}

浙公网安备 33010602011771号