[题解] Atcoder ABC 225 H Social Distance 2 生成函数,分治FFT

题目
首先还没有安排座位的\(m-k\)个人之间是有顺序的,所以先把答案乘上\((m-k)!\),就可以把这些人看作不可区分的。

已经确定的k个人把所有座位分成了k+1段。对于第i段,如果我们能求出这一段恰好额外坐j人时的代价总和\(f_{i,j}\),并令\(f_{i,j}\)的普通生成函数为\(F_i(x)\),答案就是\(\prod F_i(x)\)\(m-k\)次项系数。


先考虑k+1段中两边都有已经确定的人的k-1段。对于每一段i,枚举其中额外坐的人数j,现在考虑求出\(f_{i,j}\)。令\(g_i\)表示只考虑两个相邻的人,他们之间的距离为i时的代价,显然\(g_i=i\)。令\(G(x)\)为g的普通生成函数:

\[\begin{align} G(x)&=\sum_{n>0}n\cdot x^n\\ &=x\sum_{n\geq0}(n+1)x^n\\ &=x \cdot \frac1{(1-x)^2}\\ &=\frac x {(1-x)^2}\\ \end{align} \]

令第i段空座位两边端点之间的距离为len,发现\(f_{i,j}=G(x)^{j+1}\)的len次项系数(每j个人有j+1段空隙)。由于\(n\leq2e5\),所以可以对每一个j(\(j\geq0\))用这个公式暴力算:\(\frac{1}{(1-x)^m}=\sum_{n\geq0}\binom{n+m-1}{m-1}x^n\)


考虑序列两头只有一边有已经确定的人的段。这里\(f_{i,j}=G(x)^j\)的0~len次项系数之和。根据上面的公式,我们实际要求的是一个组合数前缀和的形式。\(C(n,n)+C(n+1,n)+C(n+2,n)+\cdots+C(m,n)=C(m+1,n+1)\),可以根据这个直接\(O(1)\)算。


对于k=0的情况特殊处理,方法和上面处理序列两头的类似。

所以现在已经算出了每一段的\(F(x)\),项数之和是\(O(n)\)的,用分治FFT把所有\(F(x)\)卷起来即可。卷之前把所有\(F(x)\)顺序打乱,防止分治的时候被卡。时间复杂度\(O(nlog^2n)\)

点击查看代码
#include <bits/stdc++.h>
#include <atcoder/all>

#define rep(i,n) for(int i=0;i<n;++i)
#define repn(i,n) for(int i=1;i<=n;++i)
#define LL long long
#define pii pair <int,int>
#define fi first
#define se second
#define mpr make_pair
#define pb push_back

using namespace std;
using mint=atcoder::modint998244353;

const LL MOD=998244353;

LL qpow(LL x,LL a)
{
	LL res=x,ret=1;
	while(a>0)
	{
		if((a&1)==1) ret=ret*res%MOD;
		a>>=1;
		res=res*res%MOD;
	}
	return ret;
}

LL n,m,k,a[200010],fac[400010],inv[400010];
vector <vector <mint> > v;

LL C(LL nn,LL mm){return fac[nn]*inv[mm]%MOD*inv[nn-mm]%MOD;}

void deal(LL emp)
{
  if(emp<=0) return;
  vector <mint> tmp;tmp.pb(1);
  repn(seg,emp) tmp.pb(C(emp+seg,seg+seg));
  v.pb(tmp);
}

vector <mint> solve(LL lb,LL ub)
{
  if(lb==ub) return v[lb];
  return atcoder::convolution(solve(lb,(lb+ub)/2),solve((lb+ub)/2+1,ub));
}

int main()
{
  fac[0]=1;repn(i,400005) fac[i]=fac[i-1]*(LL)i%MOD;
  rep(i,400003) inv[i]=qpow(fac[i],MOD-2);
  cin>>n>>m>>k;
  rep(i,k) scanf("%lld",&a[i]);
  if(k==0)
  {
    LL seg=m-1,res=0;
    for(LL len=m-1;len<n;++len) (res+=C(len+seg-1,seg+seg-1)*(n-len))%=MOD;
    cout<<res*fac[m]%MOD<<endl;
    return 0;
  }
  rep(i,k-1)
  {
    LL len=a[i+1]-a[i];
    vector <mint> tmp;
    repn(seg,len)
    {
      LL val=C(len+seg-1,seg+seg-1);
      tmp.pb(val);
    }
    v.pb(tmp);
  }
  deal(a[0]-1);deal(n-a[k-1]);
  random_shuffle(v.begin(),v.end());
  vector <mint> ans=solve(0,v.size()-1);
  cout<<(LL)ans[m-k].val()*fac[m-k]%MOD<<endl;
	return 0;
}
posted @ 2022-05-16 11:07  LegendStane  阅读(138)  评论(1)    收藏  举报