SDOI2017 遗忘的集合

题目描述:

luogu

bzoj

题解:

生成函数+多项式ln(+反演?)

首先如果我们已知$S$,那$S$中$i$的生成函数就是$1+x^i+x^{2i}+……$。

整理一下就是$ \frac{1} {1-x^i}$。

所以设$S$的生成函数为$F(x)$,则$F(x)= \prod _{i \in S} \frac{1}{1-x^i} $

由于我们并不知道有哪些是答案,我们可以引入一个系数$a_i \in {0,1}$,表示数$i$不取/取。

所以$F(x) = \prod_{i \ge 1}(\frac{1}{1-x^i})^{a_i}$

其实都明白$i$是有上界的,因为整个式子都是模$x^{n+1}$,所以上界为$n$。

然后对上式取负对数,得$-ln(F(x)) = \sum_{i \ge 1} a_i * ln(1-x^i)$。

同时$ln(1-x^i)=\sum_{j \ge 1}- \frac{x^{ij}}{j}$

然后左右去掉负号,得$ln(F(x)) = \sum_{i \ge 1} a_i \sum_{j \ge 1} \frac{x^{ij}}{j}$

把带$x$的提前,得$ln(F(x)) = \sum_{T \ge 1}x^T \sum_{i | T} \frac{i*a_i}{T}$

所以把$F(x)$取个$ln$,然后系数就可以$O(nlnn)$筛出了。

代码:

#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef long long ll;
const int N = (1<<21)+50;
const int M = 32768;
const long double Pi = acos(-1.0);
template<typename T>
inline void read(T&x)
{
    T f = 1,c = 0;char ch=getchar();
    while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
    while(ch>='0'&&ch<='9'){c=c*10+ch-'0';ch=getchar();}
    x = f*c;
}
int n,MOD;
void Mod(int&x){if(x>=MOD)x-=MOD;}
int fastpow(int x,int y)
{
    int ret = 1;
    while(y)
    {
        if(y&1)ret=1ll*ret*x%MOD;
        x=1ll*x*x%MOD;y>>=1;
    }
    return ret;
}
struct cp
{
    long double x,y;
    cp(){}
    cp(long double x,long double y):x(x),y(y){}
    cp operator + (const cp&a)const{return cp(x+a.x,y+a.y);}
    cp operator - (const cp&a)const{return cp(x-a.x,y-a.y);}
    cp operator * (const cp&a)const{return cp(x*a.x-y*a.y,x*a.y+y*a.x);}
};
int to[N],lim,LL[N],L;
void fft(cp*a,int len,int k)
{
    for(int i=0;i<len;++i)
        if(i<to[i])swap(a[i],a[to[i]]);
    for(int i=1;i<len;i<<=1)
    {
        cp w0(cos(Pi/i),k*sin(Pi/i));
        for(int j=0;j<len;j+=(i<<1))
        {
            cp w(1,0);
            for(int o=0;o<i;o++,w=w*w0)
            {
                cp w1 = a[j+o],w2 = w*a[j+o+i];
                a[j+o] = w1+w2;
                a[j+o+i] = w1-w2;
            }
        }
    }
    if(k==-1)
        for(int i=0;i<len;++i)a[i].x/=len,a[i].y/=len;
}
cp a[N],b[N],c[N],d[N];
void get_lim(int len)
{
//    lim = 1,L = 0;
//    while(lim<=len)lim<<=1,L++;
    lim = len;L=LL[len];
    for(int i=1;i<lim;++i)to[i]=((to[i>>1]>>1)|((i&1)<<(L-1)));
}
void mtt(int*F,int*G,int*H,int len)
{
    get_lim(len<<1);
    for(register int i=0;i<lim;++i)a[i]=b[i]=cp(0,0);
    for(register int i=0;i<len;++i)
        a[i]=cp(F[i]&(M-1),F[i]>>15),b[i]=cp(G[i]&(M-1),G[i]>>15);
    fft(a,lim,1),fft(b,lim,1);
    for(register int i=0;i<lim;++i)
    {
        int j = (lim-i)&(lim-1);
        c[j]=cp(0.5*(a[i].x+a[j].x),0.5*(a[i].y-a[j].y))*b[i];
        d[j]=cp(0.5*(a[i].y+a[j].y),0.5*(a[j].x-a[i].x))*b[i];
    }
    fft(c,lim,1),fft(d,lim,1);
    for(register int i=0;i<lim;++i)
    {
        int kaa = (ll)(c[i].x/lim+0.5)%MOD;
        int kab = (ll)(c[i].y/lim+0.5)%MOD;
        int kba = (ll)(d[i].x/lim+0.5)%MOD;
        int kbb = (ll)(d[i].y/lim+0.5)%MOD;
        Mod(H[i] = ((ll)kaa+((ll)(kab+kba)<<15)%MOD+((ll)kbb<<30)%MOD)%MOD+MOD);
    }
}
int H[N],ny[N],T[N];
void get_inv(int len,int*F,int*G)
{
    if(len==1){G[0]=fastpow(F[0],MOD-2);return ;}
    get_inv(len>>1,F,G),mtt(G,G,T,len>>1),mtt(T,F,H,len);
    for(register int i=0;i<len;++i)Mod(G[i]=G[i]*2%MOD-H[i]+MOD);
}
void up(int*a,int len)
{
    for(register int i=1;i<len;++i)a[i-1]=1ll*a[i]*i%MOD;
    a[len-1] = 0;
}
void down(int*a,int len)
{
    for(register int i=len-1;i;--i)a[i]=1ll*a[i-1]*ny[i]%MOD;
    a[0]=0;
}
void get_ln(int len,int*F,int*G)
{
    get_inv(len,F,G);up(F,len);
    mtt(F,G,G,len);
    down(G,len);
}
int F[N],G[N];
int main()
{
    read(n),read(MOD);F[0]=1;
    lim=LL[2]=1;while(lim<=n)lim<<=1,LL[lim<<1]=LL[lim]+1;
    ny[1]=1;
    for(register int i=2;i<(lim<<1);++i)ny[i]=1ll*(MOD-MOD/i)*ny[MOD%i]%MOD;
    for(register int i=1;i<=n;++i)
        read(F[i]),F[i]%=MOD;
    get_ln(lim,F,G);
    int cnt = 0;
    for(register int i=1;i<=n;++i)G[i]=1ll*G[i]*i%MOD;
    for(register int i=1;i<=n;++i)
        for(register int j=i+i;j<=n;j+=i)Mod(G[j]+=MOD-G[i]);
    for(register int i=1;i<=n;++i)if(G[i])cnt++;
    printf("%d\n",cnt);
    for(register int i=1;i<=n;++i)
        if(G[i])printf("%d ",i);
    puts("");
    return 0;
}
View Code

(我还重新学了一遍MTT)

posted @ 2019-06-19 14:19  LiGuanlin  阅读(214)  评论(0编辑  收藏  举报