[题解] Atcoder ABC 225 H Social Distance 2 生成函数,分治FFT
题目
首先还没有安排座位的\(m-k\)个人之间是有顺序的,所以先把答案乘上\((m-k)!\),就可以把这些人看作不可区分的。
已经确定的k个人把所有座位分成了k+1段。对于第i段,如果我们能求出这一段恰好额外坐j人时的代价总和\(f_{i,j}\),并令\(f_{i,j}\)的普通生成函数为\(F_i(x)\),答案就是\(\prod F_i(x)\)的\(m-k\)次项系数。
先考虑k+1段中两边都有已经确定的人的k-1段。对于每一段i,枚举其中额外坐的人数j,现在考虑求出\(f_{i,j}\)。令\(g_i\)表示只考虑两个相邻的人,他们之间的距离为i时的代价,显然\(g_i=i\)。令\(G(x)\)为g的普通生成函数:
令第i段空座位两边端点之间的距离为len,发现\(f_{i,j}=G(x)^{j+1}\)的len次项系数(每j个人有j+1段空隙)。由于\(n\leq2e5\),所以可以对每一个j(\(j\geq0\))用这个公式暴力算:\(\frac{1}{(1-x)^m}=\sum_{n\geq0}\binom{n+m-1}{m-1}x^n\)。
考虑序列两头只有一边有已经确定的人的段。这里\(f_{i,j}=G(x)^j\)的0~len次项系数之和。根据上面的公式,我们实际要求的是一个组合数前缀和的形式。\(C(n,n)+C(n+1,n)+C(n+2,n)+\cdots+C(m,n)=C(m+1,n+1)\),可以根据这个直接\(O(1)\)算。
对于k=0的情况特殊处理,方法和上面处理序列两头的类似。
所以现在已经算出了每一段的\(F(x)\),项数之和是\(O(n)\)的,用分治FFT把所有\(F(x)\)卷起来即可。卷之前把所有\(F(x)\)顺序打乱,防止分治的时候被卡。时间复杂度\(O(nlog^2n)\)。
点击查看代码
#include <bits/stdc++.h>
#include <atcoder/all>
#define rep(i,n) for(int i=0;i<n;++i)
#define repn(i,n) for(int i=1;i<=n;++i)
#define LL long long
#define pii pair <int,int>
#define fi first
#define se second
#define mpr make_pair
#define pb push_back
using namespace std;
using mint=atcoder::modint998244353;
const LL MOD=998244353;
LL qpow(LL x,LL a)
{
LL res=x,ret=1;
while(a>0)
{
if((a&1)==1) ret=ret*res%MOD;
a>>=1;
res=res*res%MOD;
}
return ret;
}
LL n,m,k,a[200010],fac[400010],inv[400010];
vector <vector <mint> > v;
LL C(LL nn,LL mm){return fac[nn]*inv[mm]%MOD*inv[nn-mm]%MOD;}
void deal(LL emp)
{
if(emp<=0) return;
vector <mint> tmp;tmp.pb(1);
repn(seg,emp) tmp.pb(C(emp+seg,seg+seg));
v.pb(tmp);
}
vector <mint> solve(LL lb,LL ub)
{
if(lb==ub) return v[lb];
return atcoder::convolution(solve(lb,(lb+ub)/2),solve((lb+ub)/2+1,ub));
}
int main()
{
fac[0]=1;repn(i,400005) fac[i]=fac[i-1]*(LL)i%MOD;
rep(i,400003) inv[i]=qpow(fac[i],MOD-2);
cin>>n>>m>>k;
rep(i,k) scanf("%lld",&a[i]);
if(k==0)
{
LL seg=m-1,res=0;
for(LL len=m-1;len<n;++len) (res+=C(len+seg-1,seg+seg-1)*(n-len))%=MOD;
cout<<res*fac[m]%MOD<<endl;
return 0;
}
rep(i,k-1)
{
LL len=a[i+1]-a[i];
vector <mint> tmp;
repn(seg,len)
{
LL val=C(len+seg-1,seg+seg-1);
tmp.pb(val);
}
v.pb(tmp);
}
deal(a[0]-1);deal(n-a[k-1]);
random_shuffle(v.begin(),v.end());
vector <mint> ans=solve(0,v.size()-1);
cout<<(LL)ans[m-k].val()*fac[m-k]%MOD<<endl;
return 0;
}