# [2022牛客多校赛第四场] C-Easy Counting Problem

## 题目大意

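给定 $w$ $(2\leq w\leq 10)$ 以及 $c_i$ $(i\in[0,w))$，求长度为 $n$、每一位都是 $[0,w)$ 中的数字，且数字 $i$ 至少出现 $c_i$ 次的字符串个数，答案对 $998244353$ 取模。

数据范围：$1\leq c_i\leq 50000$，$\sum c_i\leq 50000$。

共有 $q$ $(1\leq q\leq 300)$ 次询问，每次询问给出一个 $n$ $(1\leq n\leq 10^7)$。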

## 题解

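若数字 $i$ 恰好出现 $c_i$ 次，且 $n=\sum c_i$，则方案数为 $\frac{n!}{\prod_i c_i !}$，因此考虑使用指数生成函数（EGF）。

记 $\exp(x)=\sum_{i=0}^\infty \frac{x^i}{i!}$。将数字重新编号为 $1\sim w$，则数字 $k$ 至少出现 $c_k$ 次对应的 EGF 为 $\exp(x)-\sum_{i=0}^{c_k-1}\frac{x^i}{i!}$，于是答案为

$$[x^n]\,n!\prod_{k=1}^{w}\left(\exp(x)-\sum_{i=0}^{c_k-1}\frac{x^i}{i!}\right)$$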

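记 $g_k(x)=\sum_{i=0}^{c_k-1}\frac{x^i}{i!}$，则需要计算

$$\prod_{k=1}^{w}\left(\exp(x)-g_k(x)\right)$$

以两项为例：

$$(\exp(x)-g_1(x))(\exp(x)-g_2(x))=\exp(2x)-(g_1(x)+g_2(x))\exp(x)+g_1(x)g_2(x)$$

一般地，利用 $\exp(ax)\exp(bx)=\exp((a+b)x)$，展开后可以写成 $\sum_{j=0}^{w}f_j(x)\exp(jx)$ 的形式，其中每个 $f_j(x)$ 都是多项式。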

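设 $dp(i,j)$ 表示只考虑前 $i$ 个因子 $(\exp(x)-g_k(x))$ 时，展开式中 $\exp(jx)$ 的系数多项式（于是 $f_j(x)=dp(w,j)$）。初值为 $dp(0,0)=1$，并约定 $dp(i,-1)=0$。

乘上第 $i$ 个因子时，$\exp(x)$ 项使 $j$ 加一，$-g_i(x)$ 项保持 $j$ 不变，则有

$$\begin{aligned}dp(i,j)&=dp(i-1,j-1)-dp(i-1,j)\,g_i(x)\\&=dp(i-1,j-1)-dp(i-1,j)\left(\sum_{t=0}^{c_i-1}\frac{x^t}{t!}\right)\end{aligned}$$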

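记 $s=\sum c_i$，于是只要进行 $w^2$ 次卷积就能计算出 $f_0(x),f_1(x),\cdots,f_w(x)$，时间复杂度 $O(w^2s\log s)$。

对每次询问，由 $[x^n]\,x^i\exp(jx)=\frac{j^{n-i}}{(n-i)!}$（约定 $0^0=1$）可得答案为

$$n!\sum_{j=0}^{w}\sum_{i}[x^i]f_j(x)\cdot\frac{j^{n-i}}{(n-i)!}$$

预处理阶乘及其逆元到 $10^7$ 后，每次询问可在 $O(ws+w\log n)$ 时间内算出。另外，当 $n<s$ 时答案显然为 $0$。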

## Code

#include <bits/stdc++.h>
using namespace std;

// 64-bit integer type used throughout for values reduced modulo MOD.
// A type alias (instead of a macro) also allows functional casts such as LL(x).
using LL = long long;

// Reads one integer from stdin into T.
// Every character before the first digit is skipped. If any of the skipped
// characters is '-', the value is negated. Digits are then consumed until the
// first non-digit character, which is also consumed.
// Stops cleanly at EOF: if no digit is found, T is set to 0.
template<typename elemType>
inline void Read(elemType& T) {
    elemType X = 0;
    bool neg = false;
    // int, not char: getchar() returns EOF (negative), and isdigit() is UB for
    // negative values other than EOF.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)) {
        neg |= ch == '-';
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch)) {
        X = X * 10 + (ch - '0');
        ch = getchar();
    }
    T = (neg ? -X : X);
}

// Prime modulus for every answer (998244353 = 119 * 2^23 + 1, NTT-friendly).
constexpr LL MOD = 998244353;
// Size of the inv/fact/finv tables: indices up to 10^7 (the largest query n) plus slack.
constexpr int maxn = 10000005;

// Computes b^n mod MOD by binary exponentiation.
//   b   : base; may be negative, it is reduced into [0, MOD) first.
//   n   : exponent; must be non-negative.
//   MOD : positive modulus; must satisfy (MOD - 1)^2 < 2^63 so that products fit in LL.
// Returns a value in [0, MOD). Returns 0 when MOD == 1, because every residue is 0 there.
LL qpow(LL b, LL n, LL MOD) {
    if (MOD == 1) return 0;
    LL x = 1;
    // In C++ the sign of b % MOD follows b, so shift it into [0, MOD) before squaring.
    LL Power = ((b % MOD) + MOD) % MOD;
    while (n) {
        if (n & 1) x = x * Power % MOD;
        Power = Power * Power % MOD;
        n >>= 1;
    }
    return x;
}

namespace Poly {
    // Capacity of the bit-reversal table r[]; a transform length above this is not supported.
    const int maxn = 2100000;
    // r[i] = i with its low L bits reversed; rebuilt by Conv for every transform size.
    int r[maxn];
    // limit = 2^L is the transform length chosen by the most recent Conv call.
    int L, limit;
    // NTT prime P = 119 * 2^23 + 1 with primitive root G = 3; Gi = G^{-1} mod P.
    const LL P = 998244353, G = 3, Gi = 332748118;

    // Modular inverse of x modulo the prime P (Fermat's little theorem).
    LL pinv(LL x) { return qpow(x, P - 2, P); }

    // In-place iterative NTT of A[0, limit), using the current r[] and limit.
    // type == 1: forward transform. Any other value: inverse transform,
    // including the final division by limit.
    // Entries of A are assumed to lie in [0, P).
    void NTT(LL* A, int type) {
        for (int i = 0; i < limit; i++)
            if (i < r[i]) swap(A[i], A[r[i]]);
        for (int mid = 1; mid < limit; mid <<= 1) {
            LL Wn = qpow(type == 1 ? G : Gi, (P - 1) / (mid << 1), P);
            for (int j = 0; j < limit; j += (mid << 1)) {
                LL w = 1;
                for (int k = 0; k < mid; k++, w = (w * Wn) % P) {
                    // Keep the butterfly operands in LL. The former int locals narrowed A[]
                    // and x + y would overflow int for any NTT prime above 2^30.
                    LL x = A[j + k], y = w * A[j + k + mid] % P;
                    A[j + k] = (x + y) % P;
                    A[j + k + mid] = (x - y + P) % P;
                }
            }
        }
        if (type == 1) return;
        LL inv_limit = pinv(limit);
        for (int i = 0; i < limit; ++i)
            A[i] = A[i] * inv_limit % P;
    }

    // Cyclic convolution used for polynomial multiplication.
    //   a[0, N), b[0, M) : input coefficients. Both buffers are zero-padded up to limit
    //                      and overwritten with their transforms.
    //   c                : receives the product coefficients in c[0, limit).
    //                      It may alias a or b, because the pointwise product is elementwise.
    // limit is the smallest power of two greater than N + M, so every buffer must hold
    // at least that many entries.
    void Conv(LL* a, int N, LL* b, LL M, LL* c) {
        L = 0; limit = 1;
        while (limit <= N + M) limit <<= 1, L++;
        for (int i = N; i < limit; ++i) a[i] = 0;
        for (int i = M; i < limit; ++i) b[i] = 0;
        // Guard L == 0 (limit == 1): shifting by L - 1 == -1 would be undefined behaviour.
        for (int i = 0; i < limit; i++)
            r[i] = L ? ((r[i >> 1] >> 1) | ((i & 1) << (L - 1))) : 0;
        NTT(a, 1); NTT(b, 1);
        for (int i = 0; i < limit; i++) c[i] = a[i] * b[i] % P;
        NTT(c, -1);
    }
}

// Scratch buffers handed to Poly::Conv by conv(). Conv zero-pads them up to its
// transform length, so that length must not exceed 200010.
LL f[200010];
LL g[200010];
// inv[i] = i^{-1}, fact[i] = i!, finv[i] = (i!)^{-1}, all modulo MOD; filled by init().
int inv[maxn];
int fact[maxn];
int finv[maxn];
// c[1..w]: required minimum count of each digit. Index 0 is not read by the visible code.
int c[11];
// Rolling DP table: dp[i & 1][j] is the polynomial multiplying exp(jx) after i factors.
vector<LL> dp[2][11];
// q is not referenced in the visible code; presumably the query count read by main (TODO confirm).
int q;
// Number of digits / factors in the product.
int w;

void init() {
inv[1] = fact[0] = fact[1] = finv[0] = finv[1] = 1;
for (int i = 2;i <= 1e7;++i) {
inv[i] = ((-1LL * (MOD / i) * inv[MOD % i]) % MOD + MOD) % MOD;
fact[i] = 1LL * fact[i - 1] * i % MOD;
finv[i] = 1LL * finv[i - 1] * inv[i] % MOD;
}
}

void add(const vector<LL>& a, const vector<LL>& b, vector<LL>& c) {
c.clear(); c.resize(max(a.size(), b.size()));
int p = 0;
while (p < a.size() && p < b.size()) { c[p] = (a[p] + b[p]) % MOD; ++p; }
while (p < a.size()) { c[p] = a[p]; ++p; }
while (p < b.size()) { c[p] = b[p]; ++p; }
}

// Multiplies polynomials a and b with Poly::Conv and stores the n + m - 1 product
// coefficients in c. Entries are assumed to be in [0, MOD).
// If either operand is empty (the zero polynomial), c is left empty; previously both
// being empty led to resize(n + m - 1) with a negative size.
// c may alias a or b: both inputs are copied into the scratch buffers f and g before
// c is written.
void conv(const vector<LL>& a, const vector<LL>& b, vector<LL>& c) {
    int n = a.size(), m = b.size();
    if (n == 0 || m == 0) {
        c.clear();
        return;
    }
    for (int i = 0; i < n; ++i) f[i] = a[i];
    for (int i = 0; i < m; ++i) g[i] = b[i];
    Poly::Conv(f, n, g, m, f);
    c.assign(f, f + (n + m - 1));
}

// Answers one query:
//   n! * [x^n] sum_{k=0}^{w} f_k(x) * exp(kx),   with f_k = dp[w & 1][k].
// Since [x^n] x^i exp(kx) = k^{n-i} / (n-i)!, this equals
//   n! * sum_k sum_{i <= n} [x^i] f_k * k^{n-i} * finv[n-i].
// Requires init() to have filled fact/finv up to n.
LL solve(int n) {
    LL ans = 0;
    // k = 0: 0^{n-i} vanishes except at i == n, where 0^0 = 1, so only [x^n] f_0 contributes.
    // The running k^n * k^{-i} update below cannot produce this term, because 0 has no inverse.
    // The old loop always multiplied by 0 here and dropped the term.
    const vector<LL>& f0 = dp[w & 1][0];
    if (n >= 0 && static_cast<size_t>(n) < f0.size())
        ans = f0[n] * finv[0] % MOD;
    for (int k = 1; k <= w; ++k) {
        LL k_inv = qpow(k, MOD - 2, MOD);
        LL kk = qpow(k, n, MOD);  // holds k^{n-i} for the current i
        for (int i = 0; i < (int)dp[w & 1][k].size() && i <= n; ++i) {
            ans = (ans + dp[w & 1][k][i] * finv[n - i] % MOD * kk % MOD) % MOD;
            kk = kk * k_inv % MOD;
        }
    }
    ans = ans * fact[n] % MOD;
    return ans;
}

vector<LL> vec;

int main() {
init();
for (int i = 1;i <= w;++i)
dp[0][0].push_back(1);
for (int i = 1;i <= w;++i) {
for (int j = 0;j <= i;++j) {
dp[i & 1][j].clear();
vec.clear(); vec.resize(c[i]);
for (int k = 0;k < c[i];++k)
vec[k] = MOD - finv[k];
conv(dp[(i & 1) ^ 1][j], vec, dp[i & 1][j]);
if (j == 0) continue;
add(dp[i & 1][j], dp[(i & 1) ^ 1][j - 1], vec);
dp[i & 1][j] = vec;
}
}
int s = 0;
for (int i = 1;i <= w;++i) s += c[i];