洛谷 P5401 [CTS 2019] 珍珠 题解

题目链接

\(c_i\)表示第i种颜色的珍珠的数量,显然我们最多能装的瓶数是\(\sum \lfloor \frac {c_i}2 \rfloor\)。也就是说,\(c_i\)为奇数的\(i\)的数量不能太多,这个数量要在某个范围内才行,且这个范围是很好求的。考虑对每个n求出\(f_n\)表示刚好有n个\(c_i\)为奇数的方案数。

如何指定某些颜色一定选奇数个,并把所有不同的颜色排列组合成一行n个?考虑使用指数型生成函数(EGF)。复习一下EGF相乘的组合意义:\(A(x)=\sum_{i\ge 0}a_i\frac {x^i}{i!},B(x)=\sum_{i\ge 0}b_i\frac {x^i}{i!}\),其中\(a_i\)表示i个元素的集合做A操作的方案数,\(b_i\)表示i个元素的集合做B操作的方案数。若\(C(x)=A(x)\cdot B(x)=\sum_{i\ge 0}c_i\frac {x^i}{i!}\),则\(c_i\)表示i个元素的集合分成两个子集,其中第一个做A操作,第二个做B操作的总方案数。那么可以把i个元素的集合"有奇数个"也看成一种操作,\(0,1,2,3\cdots\)个元素的集合对应操作数是\(0,1,0,1\cdots\)。由于存在如下恒等式\(e^x=\sum_{i\ge 0}\frac{x^i}{i!}\),所以\(e^x\)的系数序列是\(1,1,1,1\cdots\)\(e^{-x}\)的系数序列是\(1,-1,1,-1\cdots\),那么\(0,1,0,1\cdots\)其实就是\(\frac{e^x-e^{-x}}2\)对应的系数序列。i个元素有偶数个的方案数"系数序列"也一样,可以用\(\frac{e^x+e^{-x}}2\)表示。

那么\(f_i\)就能表示成\(([n](\frac{e^x-e^{-x}}2)^i(\frac{e^x+e^{-x}}2)^{D-i})\cdot n!\cdot \binom di\)。其中\([n]\)表示这个函数的\(n\)次项系数,乘\(n!\)是因为这个函数是EGF。但是这样似乎不太能继续推了,考虑重新定义\(f\)。令\(g_i\)表示原来\(f_i\)的定义。对于\(f_i\):钦定任意i个颜色必须选奇数个,其他颜色随便,并把方案数贡献到\(f_i\)。这样,对于\(i<j\),每种实际奇数个数为i的方案都在\(f_j\)中被计算了\(\binom ji\)次,也就是\(g_i=\sum_{j\ge i}\binom ji f_j\)。如果求出了\(f\),可以用二项式反演在\(O(DlogD)\)的时间内求出\(g\)

现在的\(f_i=([n](\frac{e^x-e^{-x}}2)^i(e^x)^{D-i})\cdot n!\cdot \binom di\)。直接二项式定理暴力展开:

\[\begin{align} f_i&=(\frac 12)^in!\binom di \cdot\sum_{j=0}^i [n]e^{D+2j-2i}(-1)^{i-j}\binom ij\\ &=(\frac 12)^in!\binom di \cdot\sum_{j=0}^i \frac{(D+2j-2i)^n}{n!} (-1)^{i-j}\binom ij\\ &=(\frac 12)^i\binom di \cdot\sum_{j=0}^i (D+2j-2i)^n(-1)^{i-j}\binom ij\\ \end{align} \]

显然,这玩意儿可以写成一个以\(j\)为下标的序列和一个\(i-j\)为下标的序列卷积得到\(f_i\)的形式,所以可以\(O(DlogD)\)求出\(f\)

总复杂度\(O(DlogD)\)

点击查看代码
#include <bits/stdc++.h>

#define rep(i,n) for(int i=0;i<n;++i)
#define repn(i,n) for(int i=1;i<=n;++i)
#define LL long long
#define pii pair <int,int>
#define fi first
#define se second
#define mpr make_pair
#define pb push_back

void fileio()
{
  #ifdef LGS
  freopen("in.txt","r",stdin);
  freopen("out.txt","w",stdout);
  #endif
}
void termin()
{
  #ifdef LGS
  std::cout<<"\n\nEXECUTION TERMINATED";
  #endif
  exit(0);
}

using namespace std;

const LL MOD=998244353;

LL qpow(LL x,LL a)
{
	LL res=x,ret=1;
	while(a>0)
	{
		if((a&1)==1) ret=ret*res%MOD;
		a>>=1;
		res=res*res%MOD;
	}
	return ret;
}

namespace NTT
{
  vector <LL> rev;
  void ntt(vector <LL> &a,LL G)
  {
    LL nn=a.size(),gn,g,x,y;vector <LL> tmp=a;
    rep(i,nn) a[i]=tmp[rev[i]];
    for(int len=1;len<nn;len<<=1)
    {
      gn=qpow(G,(MOD-1)/(len<<1));
      for(int i=0;i<nn;i+=(len<<1))
      {
        g=1;
        for(int j=i;j<i+len;++j,(g*=gn)%=MOD)
        {
          x=a[j];y=a[j+len]*g%MOD;
          a[j]=(x+y)%MOD;a[j+len]=(x-y+MOD)%MOD;
        }
      }
    }
  }
  vector <LL> convolution(vector <LL> a,vector <LL> b,LL G)
  {
    if(a.size()<=23&&b.size()<=23)
    {
      vector <LL> ret(a.size()+b.size()-1,0);
      rep(i,a.size()) rep(j,b.size()) (ret[i+j]+=a[i]*b[j])%=MOD;
      return ret;
    }
    LL nn=1,bt=0,sv=a.size()+b.size()-1;while(nn<a.size()+b.size()-1) nn<<=1LL,++bt;
    while(a.size()<nn) a.pb(0);while(b.size()<nn) b.pb(0);
    rev.clear();
    rep(i,nn)
    {
      rev.pb(0);
      rev[i]=(rev[i>>1]>>1)|((i&1)<<(bt-1));
    }
    ntt(a,G);ntt(b,G);
    rep(i,nn) (a[i]*=b[i])%=MOD;
    ntt(a,qpow(G,MOD-2));
    while(a.size()>sv) a.pop_back();
    LL inv=qpow(nn,MOD-2);
    rep(i,a.size()) (a[i]*=inv)%=MOD;
    return a;
  }
}

LL d,n,m,fac[100010],inv[100010],f[100010],g[100010];

LL CC(LL nn,LL mm){return fac[nn]*inv[mm]%MOD*inv[nn-mm]%MOD;}

int main()
{
  fileio();

  fac[0]=1;repn(i,100005) fac[i]=fac[i-1]*i%MOD;
  rep(i,100005) inv[i]=qpow(fac[i],MOD-2);
  cin>>d>>n>>m;
  if(n==1)
  {
    puts("0");
    termin();
  }

  vector <LL> A,B,C;
  rep(i,d+2) A.pb(inv[i]);
  rep(i,d+2)
  {
    LL val=inv[i];
    if(i%2==1) val=(MOD-val)%MOD;
    (val*=qpow((MOD+d-i-i)%MOD,n))%=MOD;
    B.pb(val);
  }
  C=NTT::convolution(A,B,3);
  rep(i,d+1) f[i]=C[i]*CC(d,i)%MOD*qpow(qpow(2,i),MOD-2)%MOD*fac[i]%MOD;

  //cout<<f[0]<<' '<<f[1]<<' '<<f[2]<<endl;

  A.clear();B.clear();
  rep(i,d+3) A.pb(f[i]*fac[i]%MOD);
  rep(i,d+3) B.pb(inv[i]*(i%2==1 ? MOD-1:1LL)%MOD);
  reverse(B.begin(),B.end());
  C=NTT::convolution(A,B,3);
  rep(i,d+1) g[i]=C[i+d+2]*inv[i]%MOD;

  LL ans=0;
  rep(i,d+1) if(n-i>=m+m) (ans+=g[i])%=MOD;
  cout<<ans<<endl;

  termin();
}
posted @ 2022-12-21 20:55  LegendStane  阅读(60)  评论(0)    收藏  举报