bzoj3771

http://www.lydsy.com/JudgeOnline/problem.php?id=3771

生成函数。。。

其实就是多项式乘法。。。lrj书上有一个通俗的解释。。。

然后就是这个样子,我们构造一个多项式,a[x]=1,表示这个水果存在,那么我们乘一下就求出对应的大小了。但是可能会有重复的,所以要用容斥减去。

这个式子大概是这个样子的(a^3-3*a*b+2*c)/6+(a^2-b)/2+a,为什么呢?

a表示每种斧头选1次,b2次,c3次 

a^3是所有随便选的情况,但是这里会有重复,可能一个东西选了两次还有三次,还有选了一种情况的排列。

那么我们就要用容斥减去,首先我们讨论选了三把斧头的情况,假设我们选了a斧头和b斧头,那么我们一把斧头选择了两次的集合是(a,a,b),但是原先的a^3包括了(a,a,b),(a,b,a),(b,a,a),重复了三次,a*b只包含(a,a,b)的情况,那么我们就要减去3*a*b,但是3*a*b减去了(a,a,a)这种情况,还减了三次,我们希望减一次就好了,那么再加上2*c就行了,除以6是因为排列的情况。

如果我们选择了两次,那么只有(a,a)要减去,直接减去就行了,除以2是排列。一次直接加上。。。

然后因为多项式的点值表示可以直接相加,因为每个多项式我们带进去的东西都是一样的,所以可以把x提出来,系数相减就行了。。。

最后化成系数表达式,每个系数对应的就是方案数。。。

#include<bits/stdc++.h>
using namespace std;
#define pi acos(-1)
const int N = 300010;
int n, m, lim, l;
int r[N];
complex<double> a[N], b[N], c[N], t[N], t1[N], t2[N];
void fft(complex<double> *a, int f)
{
    for(int i = 0; i <= n; ++i) if(i < r[i]) swap(a[i], a[r[i]]);
    for(int i = 1; i < n; i <<= 1)
    {
        complex<double> wn(cos(pi / i), f * sin(pi / i));
        for(int p = i << 1, j = 0; j < n; j += p)
        {
            complex<double> w(1, 0);
            for(int k = 0; k < i; ++k, w *= wn)
            {
                complex<double> x = a[j + k], y = w * a[j + k + i];
                a[j + k] = x + y; a[j + k + i] = x - y;
            }
        }
    }
    if(f == -1) for(int i = 0; i <= n; ++i) a[i] /= n;
}
int main()
{
    scanf("%d", &n); --n;
    for(int i = 0; i <= n; ++i)
    {
        int x; scanf("%d", &x); 
        a[x] = b[x * 2] = c[x * 3] = 1;
        lim = max(lim, x);
    } 
    m = 3 * lim;
    for(n = 1; n <= m; n <<= 1) ++l;
    for(int i = 0; i <= n; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    fft(a, 1); fft(b, 1); fft(c, 1);
    //t[i]:a^3 t1[i]:3*b*a t2[i]:a^2;    
    for(int i = 0; i <= n; ++i) a[i] = (a[i] * a[i] * a[i] - 3.0 * b[i] * a[i] + 2.0 * c[i]) / 6.0 + (a[i] * a[i] - b[i]) / 2.0 + a[i];    
//    for(int i = 0; i <= n; ++i) a[i] = a[i] + (t[i] - t1[i] + 2.0 * c[i]) / 6.0 + (t2[i] - b[i]) / 2.0;  
    fft(a, -1);
    for(int i = 0; i <= m; ++i) if((int)(a[i].real() + 0.5) > 0)
        printf("%d %d\n", i, (int)(a[i].real() + 0.5));
    return 0;
}
View Code

 

posted @ 2017-05-18 21:36  19992147  阅读(121)  评论(0编辑  收藏  举报