P6151 [集训队作业2019] 青春猪头少年不会梦到兔女郎学姐

先考虑在序列上怎么做。

定义 \(f(a,b)\) 表示对于所有将 \(a\) 个无标号求分为 \(b\) 段的方案,求所有段长度乘积之和。

考虑 \(f(a,b)\) 的组合意义,先在 \(a-1\) 一个空位中插入 \(b-1\) 个隔板,再在每两个隔板中间选出一个数。我们不妨把选的这个数也看做隔板,这样就是在 \(a+b-1\) 个空位中选出 \(2\times b-1\) 个。即:

\[f(a,b)=\binom {a+b-1} {2\times b -1}=\binom {a+b-1} {a-b} \]

考虑求出第 \(i\) 个数被分成 \(b_i\) 段的方案数,使用容斥原理,枚举有 \(c_i\) 段相邻。

\[\sum _ c (\prod\binom {b_i-1} {c_i}(-1)^{c_i})\frac {(\sum b_i-c_i)!}{\prod (b_i-c_i)!} \]

考虑枚举 \(b_i-c_i\),相当于枚举 \(c_i\) 表示第 \(i\) 个数至多分为 \(c_i\) 段,原式变为:

\[\sum _c(\prod \binom {b_i-1}{c_i-1}(-1)^{b_i-c_i})\frac {(\sum c_i)!}{\prod c_i!} \]

写出答案的式子:

\[\sum _b \sum _c(\prod \binom {b_i-1}{c_i-1}(-1)^{b_i-c_i}f(a_i,b_i))\frac {(\sum c_i)!}{\prod c_i!} \]

发现后面的 \(\frac {(\sum c_i)!}{\prod c_i!}\) 很像 \(EGF\) 的形式,考虑对于每种数字,写出 \(c_i\)\(EGF\)。答案的 \(EGF\) 就是每个 \(EGF\) 的乘积,可以使用分治 \(FTT\)\(O(n\log^2n)\) 的时间里计算。

\[F_i(x)=\sum _c \frac {x^c}{c!} \sum _{b\geq c} \binom {b-1} {c-1} (-1)^{b-c}f(a_i,b) \]

计算 \(F_i(x)\) 可以考虑提出和 \(b\) 有关以及和 \(b-c\) 有关的项,并将和 \(b-c\) 有关的翻转,就是一个卷积的形式了。

再回来考虑一个环的情况。

我们不妨钦定第一个位置为 \(1\) ,最后一个位置不为 \(1\)。计算这个只需要用开头为 \(1\) 的方案 \(-\) 开头结尾都为 \(1\) 的方案数即可。

如果我们钦定开头为 \(1\),那么答案的式子就会变为:

\[\sum _b \sum _c(\prod \binom {b_i-1}{c_i-1}(-1)^{b_i-c_i}f(a_i,b_i))\frac {(c_1-1+\sum _ {c=2} ^n c_i)!}{(c_1-1)!\prod _ {c=2} ^n c_i!} \]

所以我们只需要将 \(F_1(x)\) 系数向下平移一位即可(此处系数不包含 \(\frac 1{c!}\))。

同理,计算开头和结尾都是 \(1\) 的答案只需要将 \(F_1(x)\) 系数向下平移两位即可。

现在还有一个问题,对于一个周期为 \(T\) 的环,能够对应 \(T\) 个不同的排列,我们希望它能够被恰好计算 \(T\) 次。

\(m=\sum a_i\)。根据我们刚才的算法,设它有 \(b\)\(1\),那么它会被计算 \(b/(m/T)\) 次。我们只需要在 \(F_1(x)\) 中含有 \(b\) 的项先除掉 \(b\) ,再对答案 \(\times m\) 即可。

代码

#include <bits/stdc++.h>
using namespace std;
#define N 600010
#define Mod 998244353
const int G=3,invG=332748118;
inline int read() {
    int x=0;
    char ch=getchar();
    while (!isdigit(ch)) ch=getchar();
    while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();
    return x;
}
inline int Pow(int a,int b,int p=Mod) {
    int res=1;
    for (;b;b>>=1,a=1LL*a*a%p)
        if (b&1) res=1LL*res*a%p;
    return res;
}
int a[N],fac[N],inv[N];
inline void init(int n) {
    fac[0]=fac[1]=inv[0]=inv[1]=1;
    for (int i=2;i<=n;i++) fac[i]=1LL*fac[i-1]*i%Mod;
    for (int i=2;i<=n;i++) inv[i]=1LL*(Mod-Mod/i)*inv[Mod%i]%Mod;
    for (int i=2;i<=n;i++) inv[i]=1LL*inv[i-1]*inv[i]%Mod;
}
inline int C(int n,int m) {
    if (n<m) return 0;
    return 1LL*fac[n]*inv[m]%Mod*inv[n-m]%Mod;
}
int GPow[2][20][1<<19];
inline void InitG() {
    for (int p=1;p<=19;p++) {
        int buf1=Pow(G,(Mod-1)/(1<<p));
        int buf0=Pow(invG,(Mod-1)/(1<<p));
        GPow[1][p][0]=GPow[0][p][0]=1;
        for (int i=1;i<(1<<p);i++)
            GPow[1][p][i]=1LL*GPow[1][p][i-1]*buf1%Mod,
            GPow[0][p][i]=1LL*GPow[0][p][i-1]*buf0%Mod;
    }
}
int A[N],B[N],rev[N];
inline void NTT(int *a,int len,int f) {
    int k=1; while ((1<<k)<len) k++;
    for (int i=0;i<len;i++) {
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
        if (i<rev[i]) swap(a[i],a[rev[i]]);
    }
    for (int l=2,cnt=1;l<=len;l<<=1,++cnt) {
        for (int i=0,m=l>>1;i<len;i+=l) {
            int *buf=GPow[f^1][cnt];
            for (int j=0;j<m;j++,buf++) {
                int p=a[i+j],q=1LL*(*buf)*a[i+j+m]%Mod;
                a[i+j]=(p+q)%Mod,a[i+j+m]=(p-q+Mod)%Mod;
            }
        }
    }
}
inline void DFT(int *a,int len) {NTT(a,len,0);}
inline void IDFT(int *a,int len) {
    NTT(a,len,1); int inv=Pow(len,Mod-2);
    for (int i=0;i<len;i++) a[i]=1LL*a[i]*inv%Mod;
}
inline vector<int> Mul(vector<int> f,vector<int> g) {
    int len=1; while (len<f.size()+g.size()) len<<=1;
    for (int i=0;i<f.size();i++) A[i]=f[i]; 
    for (int i=f.size();i<len;i++) A[i]=0; DFT(A,len);
    for (int i=0;i<g.size();i++) B[i]=g[i];
    for (int i=g.size();i<len;i++) B[i]=0; DFT(B,len);
    for (int i=0;i<len;i++) A[i]=1LL*A[i]*B[i]%Mod; IDFT(A,len);
    vector<int> res(f.size()+g.size());
    for (int i=0;i<res.size();i++) res[i]=A[i];
    return res;
}
vector<int> F[N];
inline vector<int> solve(int l,int r) {
    if (l==r) {
        vector<int> res=F[l];
        for (int i=0;i<res.size();i++) res[i]=1LL*res[i]*inv[i]%Mod;
        return res;
    }
    return Mul(solve(l,(l+r)>>1),solve(((l+r)>>1)+1,r));
}
int main() {
    int n=read(),m=0; InitG();
    for (int i=1;i<=n;i++) m+=(a[i]=read());
    init(m<<1);
    for (int i=1;i<=n;i++) {
        vector<int> f(a[i]+1),g(a[i]+1);
        for (int j=1;j<=a[i];j++) f[j]=1LL*fac[j-1]*C(a[i]+j-1,a[i]-j)%Mod;
        for (int j=1;j<=a[i];j++) g[j]=(a[i]-j)&1?Mod-inv[a[i]-j]:inv[a[i]-j];
        if (i==1) for (int j=1;j<=a[i];j++) f[j]=1LL*f[j]*fac[j-1]%Mod*inv[j]%Mod;
        f=Mul(f,g),F[i].resize(a[i]+1);
        for (int j=1;j<=a[i];j++) F[i][j]=1LL*f[j+a[i]]*inv[j-1]%Mod;
    }
    vector<int> g=solve(2,n),f(a[1]); int ans=0;
    for (int i=0;i<f.size();i++) f[i]=1LL*F[1][i+1]*inv[i]%Mod; f=Mul(f,g); 
    for (int i=0;i<f.size();i++) (ans+=1LL*f[i]*fac[i]%Mod)%=Mod;
    if (a[1]>1) {
        vector<int> f(a[1]-1);
        for (int i=0;i<f.size();i++) f[i]=1LL*F[1][i+2]*inv[i]%Mod; f=Mul(f,g);
        for (int i=0;i<f.size();i++) (ans+=Mod-1LL*f[i]*fac[i]%Mod)%=Mod;
    }
    printf("%d\n",1LL*ans*m%Mod);
    return 0;
}
posted @ 2020-10-09 07:18  bo1949  阅读(162)  评论(0)    收藏  举报