[题解]AT_abc215_g [ABC215G] Colorful Candies 2

思路

定义 \(vis_i\) 表示数 \(i\) 在序列中出现的次数。如果我们选出 \(k\) 个数,答案就是(其中 \(m\) 表示 \(\max(c_i)\)):

\[ \sum_{i = 1}^m\frac{\binom{n}{x} - \binom{n - vis_i}{k}}{\binom{n}{x}} \]

显然,我们只枚举序列中存在的元素,时间复杂度 \(\Theta(n^2)\),过不了,考虑优化。

不难发现,对于答案的贡献与其权值无关,之和出现的次数有关。那么,对于所有满足 \(i \neq j \wedge vis_i = vis_j\) 的元素,对于答案的贡献都是一样的。因此将其看作一种元素考虑。

答案就转变为了(\(p\) 为压缩后的序列大小,\(a\) 为压缩后的序列):

\[ \sum_{i = 1}^p(cnt_i \times \frac{\binom{n}{k} - \binom{n - vis_{a_i}}{k}}{\binom{n}{k}}) \]

时间复杂度为 \(\Theta(np)\),因为在最坏情况下,出现次数分别是:\(1,2,3,\dots\)。所以 \(p\)\(\Theta(\sqrt n)\) 级别的。

因此,时间复杂度为 \(\Theta(n \sqrt n)\)

Code

#include <bits/stdc++.h>  
#define int long long  
#define re register  
  
using namespace std;  
  
const int N = 5e4 + 10,mod = 998244353;  
int n,m;  
int arr[N],brr[N],mul[N],inv[N];  
map<int,int> vis,mp;  
  
inline int read(){  
    int r = 0,w = 1;  
    char c = getchar();  
    while (c < '0' || c > '9'){  
        if (c == '-') w = -1;  
        c = getchar();  
    }  
    while (c >= '0' && c <= '9'){  
        r = (r << 3) + (r << 1) + (c ^ 48);  
        c = getchar();  
    }  
    return r * w;  
}  
  
inline int exgcd(int a,int b,int &x,int &y){  
    if (!b){  
        x = 1;  
        y = 0;  
        return a;  
    }  
    int d = exgcd(b,a % b,y,x);  
    y = y - a / b * x;  
    return d;  
}  
  
inline void init(){  
    mul[0] = 1;  
    for (re int i = 1;i <= n;i++) mul[i] = mul[i - 1] * i % mod;  
    for (re int i = 0;i <= n;i++){  
        int a = mul[i],p = mod,x,y;  
        exgcd(a,p,x,y);  
        inv[i] = (x % mod + mod) % mod;  
    }  
}  
  
inline int C(int n,int m){  
    if (n < m) return 0;  
    return mul[n] * inv[n - m] % mod * inv[m] % mod;  
}  
  
signed main(){  
    n = read();  
    init();  
    for (re int i = 1;i <= n;i++){  
        int x;  
        x = read();  
        vis[x]++;  
    }  
    for (auto it = vis.begin();it != vis.end();it++) mp[it -> second]++;  
    for (auto it = mp.begin();it != mp.end();it++){  
        m++;  
        arr[m] = (it -> first);  
        brr[m] = (it -> second);  
    }  
    for (re int i = 1;i <= n;i++){  
        int ans = 0;  
        for (re int j = 1;j <= m;j++) ans = (ans + ((C(n,i) - C(n - arr[j],i)) % mod + mod) % mod * brr[j] % mod) % mod;  
        int a = C(n,i),p = mod,x,y;  
        exgcd(a,p,x,y);  
        int iv = (x % mod + mod) % mod;  
        printf("%lld\n",ans * iv % mod);  
    }  
    return 0;  
}  
posted @ 2024-06-22 10:50  WBIKPS  阅读(17)  评论(0)    收藏  举报