BZOJ 3513: [MUTC2013]idiots

3513: [MUTC2013]idiots

Time Limit: 20 Sec  Memory Limit: 128 MB
Submit: 476  Solved: 162
[Submit][Status][Discuss]

Description

给定n个长度分别为a_i的木棒,问随机选择3个木棒能够拼成三角形的概率。

Input

第一行T(T<=100),表示数据组数。
接下来若干行描述T组数据,每组数据第一行是n,接下来一行有n个数表示a_i。
3≤N≤10^5,1≤a_i≤10^5

Output

T行,每行一个整数,四舍五入保留7位小数。

Sample Input

2
4
1 3 3 4
4
2 3 3 4

Sample Output

0.5000000
1.0000000

HINT

T<=20

N<=100000

Source

By sbullet

分析:

求出不合法的概率然后用1减去...

$dp[i]$代表选取两个木棍之和小于等于$i$的方案数...$dp[i]=\sum num[k]num[i-k]$

FFT一下...

代码:

#include<algorithm>
#include<iostream>
#include<cstring>
//#include<complex>
#include<cstdio>
#include<cmath>
//by NeighThorn
#define pi acos(-1)
using namespace std;
//typedef complex<double> M;

const int maxn=400000+5;

struct complex{
	
	double r,i;
	
	inline complex(double a=0,double b=0): r(a),i(b) {};
	
	inline complex operator + (const complex &a){
		return complex(r+a.r,i+a.i);
	}
	
	inline complex operator - (const complex &a){
		return complex(r-a.r,i-a.i);
	}
	
	inline complex operator * (const complex &a){
		return complex(r*a.r-i*a.i,r*a.i+i*a.r);
	}
	
}a[maxn],b[maxn],dp[maxn];

int n,N,m,L,cas,Max,R[maxn],num[maxn];
long long ans,sum,tot;
double res;

inline void FFT(complex *a,int f){
	for(int i=0;i<N;i++)
		if(i>R[i]) swap(a[i],a[R[i]]);
	for(int i=1;i<N;i<<=1){
		complex wn(cos(pi/i),f*sin(pi/i));
		for(int j=0;j<N;j+=i<<1){
			complex w(1,0);
			for(int k=0;k<i;k++,w=w*wn){
				complex x=a[j+k],y=w*a[j+k+i];
				a[j+k]=x+y,a[j+k+i]=x-y;
			}
		}
	}
	if(f==-1){
		for(int i=0;i<N;i++)
			a[i].r=a[i].r/N;
	}
}

signed main(void){
	scanf("%d",&cas);
	while(cas--){
		scanf("%d",&n);Max=0;ans=0;
		memset(a,0,sizeof(a));
		memset(b,0,sizeof(b));
		memset(num,0,sizeof(num));
		for(int i=1,x;i<=n;i++)
			scanf("%d",&x),num[x]++,Max=max(Max,x);
		for(int i=1;i<=Max;i++)
			a[i].r=num[i],b[i].r=num[i];
		m=Max<<1;L=0;
		for(N=1;N<=m;N<<=1) L++;
		for(int i=0;i<N;i++)
			R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
		FFT(a,1);FFT(b,1);
		for(int i=0;i<N;i++)
			dp[i]=a[i]*b[i];
		FFT(dp,-1);sum=0;tot=1LL*n*(n-1)*(n-2)/6;
		for(int i=1;i<=Max;i++){
			sum+=dp[i].r+0.1;
			if((i&1)==0) sum-=num[i>>1];
			ans+=1LL*num[i]*sum;
		}
		ans>>=1;res=1.0-1.0*ans/(1.0*tot);
		printf("%.7f\n",res);
	}
	return 0;
}

  


By NeighThorn

 

posted @ 2017-03-16 20:07  NeighThorn  阅读(280)  评论(0编辑  收藏  举报