AT_abc225_h 题解

[ABC225H] Social Distance 2

题目大意

有 $N$ 个椅子排列在一行,一个椅子只能坐一个人,$M$ 个人每个人会坐一把椅子,假设 $B_1,\cdots,B_m$ 是他们坐的椅子排序后的序列,那么这样的贡献是 $\prod_{i=1}^{m-1} (b_{i+1}-b_i)$。

现在有 $k$ 个人已经确定了座位,求对于剩下的人的每种可能坐的位置的排列的贡献之和。

  • $2\leq N\leq 2\times 10^5,2\leq M\leq N,0\leq K\leq M,1\leq A_1<A_2<\cdots<A_K\leq N$

记得对 $998244353$ 取模(为什么翻译没有把这个写出来……)

思路

一看到连乘就可以想到生成函数,可以看做是每一段中间放一个点,每一段的贡献则为这一段的点选出一个点的方案数。把每个空隙的多项式处理好然后分治乘起来即可,至于具体实现可以借助 NTT 来做。

显然如果 $k=0$,答案就是 $C^{2m+1}_{n+m+1}$。

对于两边的段,选 $i$ 个贡献为 $C^{2i}_{len+i-1}$。

对于中间的段,选 $i$ 个贡献为 $C^{2i+1}_{len+i}$。

总的时间复杂度为 $O(n\log^2n)$,可以通过此题。

代码

#include <iostream>
#include <vector>
#define int long long
#define MAXN 2000005
#define mod 998244353
#define G 3
#define IG 332748118
using namespace std;
int n, m, k, a[MAXN];
int fac[MAXN], inv[MAXN], ifc[MAXN];
vector <int> tr[MAXN << 2];
int tree[MAXN], aa[MAXN], bb[MAXN];
int read(){
    int t = 1, x = 0;char ch = getchar();
    while(!isdigit(ch)){if(ch == '-')t = -1;ch = getchar();}
    while(isdigit(ch)){x = (x << 1) + (x << 3) + (ch ^ 48);ch = getchar();}
    return x * t;
}
void write(int x){
    if(x < 0){putchar('-');x = -x;}
    if(x >= 10)write(x / 10);
    putchar(x % 10 ^ 48);
}
int c(int n, int m){
    if(n < 0 || m < 0 || n < m)return 0;
    int res = fac[n] * ifc[m] % mod * ifc[n - m] % mod;
    return res;
}
int qpow(int a, int b){
    int res = 1;
    while(b > 0){
        if(b & 1){res *= a;res %= mod;}
        a *= a;a %= mod;b >>= 1;
    }
    return res;
}
void NTT(int *f, int len, int flag){
    for(int i = 0 ; i < len ; i ++)
        if(i < tree[i])swap(f[i], f[tree[i]]);
    for(int i = 2 ; i <= len ; i <<= 1){
        int left = i >> 1, w = qpow(flag ? G : IG, (mod - 1) / i);
        for(int j = 0 ; j < len ; j += i){
            int wi = 1;
            for(int k = j ; k < j + left ; k ++){
                int t = f[k + left] * wi % mod;
                f[k + left] = (f[k] - t + mod) % mod;
                f[k] = (f[k] + t) % mod;
                wi = (wi * w) % mod;
            }
        }
    }
    if(flag == 0){
        int tmp = qpow(len, mod - 2);
        for(int i = 0 ; i < len ; i ++)f[i] = f[i] * tmp % mod;
    }
}
void build(int node, int left, int right){
    if(left == right){
        tr[node].resize(a[left + 1] - a[left]);
        for(int i = 0 ; i < a[left + 1] - a[left] ; i ++)
            tr[node][i] = c(a[left + 1] - a[left] + i - (left == 0 || left == k), (i << 1) + 1 - (left == 0 || left == k));
        return;
    }
    int mid = left + right >> 1;
    build(node << 1, left, mid);build(node << 1 | 1, mid + 1, right);
    int len = 1, lena = a[right + 1] - a[left];tr[node].resize(lena);
    while(len < (lena << 1))len <<= 1;
    for(int i = 1 ; i < len ; i ++)tree[i] = (tree[i >> 1] >> 1) | ((i & 1) ? len > 1 : 0);
    for(int i = 0 ; i < len ; i ++)aa[i] = 0;
    for(int i = 0 ; i < len ; i ++)bb[i] = 0;
    for(int i = 0 ; i < tr[node << 1].size() ; i ++)aa[i] = tr[node << 1][i];
    for(int i = 0 ; i < tr[node << 1 | 1].size() ; i ++)bb[i] = tr[node << 1 | 1][i];
    NTT(aa, len, 1);NTT(bb, len, 1);
    for(int i = 0 ; i < len ; i ++)aa[i] = aa[i] * bb[i] % mod;
    NTT(aa, len, 0);
    for(int i = 0 ; i < lena ; i ++)tr[node][i] = aa[i];

}
signed main(){
    n = read();m = read();k = read();
    for(int i = 1 ; i <= k ; i ++)a[i] = read();
    fac[0] = 1;inv[1] = 1;ifc[0] = 1;m -= k;
    for(int i = 1 ; i <= MAXN ; i ++)fac[i] = fac[i - 1] * i % mod;
    for(int i = 2 ; i <= MAXN ; i ++)inv[i] = (mod - mod / i) * inv[mod % i] % mod;
    for(int i = 1 ; i <= MAXN ; i ++)ifc[i] = ifc[i - 1] * inv[i] % mod;
    if(k == 0)write(fac[m] * c(n + m - 1, (m << 1) + 1) % mod);
    else{a[k + 1] = n + 1;build(1, 0, k);write(tr[1][m] * fac[m] % mod);}
    putchar('\n');return 0;
}
posted @ 2023-08-29 15:32  tsqtsqtsq  阅读(25)  评论(0)    收藏  举报  来源