AT_abc225_h 题解
[ABC225H] Social Distance 2
题目大意
有 $N$ 个椅子排列在一行,一个椅子只能坐一个人,$M$ 个人每个人会坐一把椅子,假设 $B_1,\cdots,B_m$ 是他们坐的椅子排序后的序列,那么这样的贡献是 $\prod_{i=1}^{m-1} (b_{i+1}-b_i)$。
现在有 $k$ 个人已经确定了座位,求对于剩下的人的每种可能坐的位置的排列的贡献之和。
- $2\leq N\leq 2\times 10^5,2\leq M\leq N,0\leq K\leq M,1\leq A_1<A_2<\cdots<A_K\leq N$
记得对 $998244353$ 取模(为什么翻译没有把这个写出来……)
思路
一看到连乘就可以想到生成函数,可以看做是每一段中间放一个点,每一段的贡献则为这一段的点选出一个点的方案数。把每个空隙的多项式处理好然后分治乘起来即可,至于具体实现可以借助 NTT 来做。
显然如果 $k=0$,答案就是 $C^{2m+1}_{n+m+1}$。
对于两边的段,选 $i$ 个贡献为 $C^{2i}_{len+i-1}$。
对于中间的段,选 $i$ 个贡献为 $C^{2i+1}_{len+i}$。
总的时间复杂度为 $O(n\log^2n)$,可以通过此题。
代码
#include <iostream>
#include <vector>
#define int long long
#define MAXN 2000005
#define mod 998244353
#define G 3
#define IG 332748118
using namespace std;
int n, m, k, a[MAXN];
int fac[MAXN], inv[MAXN], ifc[MAXN];
vector <int> tr[MAXN << 2];
int tree[MAXN], aa[MAXN], bb[MAXN];
int read(){
int t = 1, x = 0;char ch = getchar();
while(!isdigit(ch)){if(ch == '-')t = -1;ch = getchar();}
while(isdigit(ch)){x = (x << 1) + (x << 3) + (ch ^ 48);ch = getchar();}
return x * t;
}
void write(int x){
if(x < 0){putchar('-');x = -x;}
if(x >= 10)write(x / 10);
putchar(x % 10 ^ 48);
}
int c(int n, int m){
if(n < 0 || m < 0 || n < m)return 0;
int res = fac[n] * ifc[m] % mod * ifc[n - m] % mod;
return res;
}
int qpow(int a, int b){
int res = 1;
while(b > 0){
if(b & 1){res *= a;res %= mod;}
a *= a;a %= mod;b >>= 1;
}
return res;
}
void NTT(int *f, int len, int flag){
for(int i = 0 ; i < len ; i ++)
if(i < tree[i])swap(f[i], f[tree[i]]);
for(int i = 2 ; i <= len ; i <<= 1){
int left = i >> 1, w = qpow(flag ? G : IG, (mod - 1) / i);
for(int j = 0 ; j < len ; j += i){
int wi = 1;
for(int k = j ; k < j + left ; k ++){
int t = f[k + left] * wi % mod;
f[k + left] = (f[k] - t + mod) % mod;
f[k] = (f[k] + t) % mod;
wi = (wi * w) % mod;
}
}
}
if(flag == 0){
int tmp = qpow(len, mod - 2);
for(int i = 0 ; i < len ; i ++)f[i] = f[i] * tmp % mod;
}
}
void build(int node, int left, int right){
if(left == right){
tr[node].resize(a[left + 1] - a[left]);
for(int i = 0 ; i < a[left + 1] - a[left] ; i ++)
tr[node][i] = c(a[left + 1] - a[left] + i - (left == 0 || left == k), (i << 1) + 1 - (left == 0 || left == k));
return;
}
int mid = left + right >> 1;
build(node << 1, left, mid);build(node << 1 | 1, mid + 1, right);
int len = 1, lena = a[right + 1] - a[left];tr[node].resize(lena);
while(len < (lena << 1))len <<= 1;
for(int i = 1 ; i < len ; i ++)tree[i] = (tree[i >> 1] >> 1) | ((i & 1) ? len > 1 : 0);
for(int i = 0 ; i < len ; i ++)aa[i] = 0;
for(int i = 0 ; i < len ; i ++)bb[i] = 0;
for(int i = 0 ; i < tr[node << 1].size() ; i ++)aa[i] = tr[node << 1][i];
for(int i = 0 ; i < tr[node << 1 | 1].size() ; i ++)bb[i] = tr[node << 1 | 1][i];
NTT(aa, len, 1);NTT(bb, len, 1);
for(int i = 0 ; i < len ; i ++)aa[i] = aa[i] * bb[i] % mod;
NTT(aa, len, 0);
for(int i = 0 ; i < lena ; i ++)tr[node][i] = aa[i];
}
signed main(){
n = read();m = read();k = read();
for(int i = 1 ; i <= k ; i ++)a[i] = read();
fac[0] = 1;inv[1] = 1;ifc[0] = 1;m -= k;
for(int i = 1 ; i <= MAXN ; i ++)fac[i] = fac[i - 1] * i % mod;
for(int i = 2 ; i <= MAXN ; i ++)inv[i] = (mod - mod / i) * inv[mod % i] % mod;
for(int i = 1 ; i <= MAXN ; i ++)ifc[i] = ifc[i - 1] * inv[i] % mod;
if(k == 0)write(fac[m] * c(n + m - 1, (m << 1) + 1) % mod);
else{a[k + 1] = n + 1;build(1, 0, k);write(tr[1][m] * fac[m] % mod);}
putchar('\n');return 0;
}

浙公网安备 33010602011771号