[2022牛客多校赛第四场] C-Easy Counting Problem

题目大意

统计长度为\(n\)且数位\(i\)出现至少\(c_i\)次的数字串数量。

\(i\in[0,w)\) \((2\leq w\leq 10)\)

\(1\leq c_i\leq 50000,\sum c_i\leq 50000\)

\(q (1\leq q\leq 300)\) 次询问,每次询问 \(n (1\leq n\leq 10^7)\)

题解

\(i\) 恰好出现 \(c_i\) 次,且 \(n=\sum c_i\),则方案数为 \(\frac{n!}{\prod_i c_i !}\)

考虑指数型生成函数,

\[\exp(x)=\sum_{i=0}^\infty \frac{x^i}{i!} \]

数位 \(i\) 至少出现 \(c_i\) 次,所以需要减去次数小于 \(c_i\) 的项,即 \(\exp(x)-\sum_{i=0}^{c_i-1}\frac{x^i}{i!}\)

则最终答案为

\[[x^n]n!\prod_{k=1}^{w}\left(\exp(x)-\sum_{i=0}^{c_k-1}\frac{x^i}{i!}\right) \]

但是 \(n\) 太大了,不好直接卷积出来。

\(g_k(x)=\sum_{i=0}^{c_k-1}\frac{x^i}{i!}\)

考虑一下这个式子

\[\prod_{k=1}^{w}\left(\exp(x)-g_k(x)\right) \]

\[(\exp(x)-g_1(x))(\exp(x)-g_2(x))=\\ \exp(2x)-(g_1(x)+g_2(x))\exp(x)+g_1(x)g_2(x) \]

发现 \(\prod_{k=1}^{w}\left(\exp(x)-g_k(x)\right)\)\(\sum_{k=0}^{w} \exp(kx)f_k(x)\) 的形式。

于是我们想要求出 \(f_0(x),f_1(x),\cdots,f_w(x)\)

\(dp(i,j)\) 只考虑前 \(i\)\((\exp(x)-g(x))\)\(f_j(x)\) 的值,则有

\[dp(i,j)=dp(i-1,j-1)-dp(i-1,j)g_i(x)\\ =dp(i-1,j-1)-dp(i-1,j)\left(\sum_{j=0}^{c_i-1}\frac{x^j}{j!}\right) \]

\(s=\sum c_i\),于是只要进行 \(w^2\) 次卷积就能计算出 \(f_0(x),f_1(x),\cdots,f_w(x)\),时间复杂度 \(O(w^2s\log s)\)

对于单次询问,给定 \(n\),我们需要计算出 \(n!\sum_{k=0}^{w} \exp(kx)f_k(x)\)\(n\) 次项系数即为答案。对于每个 \(\exp(kx)f_k(x)\),直接暴力卷积 \(n\) 次项,因为 \(f_k(x)\) 最高不超过 \(s\) 次项,所以时间复杂度为 \(O(s)\),每次询问要暴力求 \(w\) 个卷积,复杂度 \(O(ws)\)\(q\) 次询问,复杂度 \(O(qws)\)。还需要 \(O(n)\) 预处理阶乘及其逆元,综上,本题的时间复杂度为 \(O(n+w^2s\log s+qws)\)

Code

#include <bits/stdc++.h>
using namespace std;

#define LL long long

template<typename elemType>
inline void Read(elemType& T) {
    elemType X = 0, w = 0; char ch = 0;
    while (!isdigit(ch)) { w |= ch == '-';ch = getchar(); }
    while (isdigit(ch)) X = (X << 3) + (X << 1) + (ch ^ 48), ch = getchar();
    T = (w ? -X : X);
}

const LL MOD = 998244353;
const int maxn = 1e7 + 5;

LL qpow(LL b, LL n, LL MOD) {
    if (MOD == 1) return 0;
    LL x = 1, Power = b % MOD;
    while (n) {
        if (n & 1) x = x * Power % MOD;
        Power = Power * Power % MOD;
        n >>= 1;
    }
    return x;
}

namespace Poly {
    const int maxn = 2100000;
    int r[maxn];
    int L, limit;
    const LL P = 998244353, G = 3, Gi = 332748118;

    LL pinv(LL x) { return qpow(x, P - 2, P); }

    void NTT(LL* A, int type) {
        for (int i = 0; i < limit; i++)
            if (i < r[i]) swap(A[i], A[r[i]]);
        for (int mid = 1; mid < limit; mid <<= 1) {
            LL Wn = qpow(type == 1 ? G : Gi, (P - 1) / (mid << 1), P);
            for (int j = 0; j < limit; j += (mid << 1)) {
                LL w = 1;
                for (int k = 0; k < mid; k++, w = (w * Wn) % P) {
                    int x = A[j + k], y = w * A[j + k + mid] % P;
                    A[j + k] = (x + y) % P;
                    A[j + k + mid] = (x - y + P) % P;
                }
            }
        }
        if (type == 1) return;
        LL inv_limit = pinv(limit);
        for (int i = 0; i < limit; ++i)
            A[i] = A[i] * inv_limit % P;
    }

    void Conv(LL* a, int N, LL* b, LL M, LL* c) {
        L = 0; limit = 1;
        while (limit <= N + M) limit <<= 1, L++;
        for (int i = N;i < limit;++i) a[i] = 0;
        for (int i = M;i < limit;++i) b[i] = 0;
        for (int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
        NTT(a, 1); NTT(b, 1);
        for (int i = 0; i < limit; i++) c[i] = a[i] * b[i] % P;
        NTT(c, -1);
    }
}

LL f[200010], g[200010];
int inv[maxn], fact[maxn], finv[maxn], c[11];
vector<LL> dp[2][11];
int q, w;

void init() {
    inv[1] = fact[0] = fact[1] = finv[0] = finv[1] = 1;
    for (int i = 2;i <= 1e7;++i) {
        inv[i] = ((-1LL * (MOD / i) * inv[MOD % i]) % MOD + MOD) % MOD;
        fact[i] = 1LL * fact[i - 1] * i % MOD;
        finv[i] = 1LL * finv[i - 1] * inv[i] % MOD;
    }
}

void add(const vector<LL>& a, const vector<LL>& b, vector<LL>& c) {
    c.clear(); c.resize(max(a.size(), b.size()));
    int p = 0;
    while (p < a.size() && p < b.size()) { c[p] = (a[p] + b[p]) % MOD; ++p; }
    while (p < a.size()) { c[p] = a[p]; ++p; }
    while (p < b.size()) { c[p] = b[p]; ++p; }
}

void conv(const vector<LL>& a, const vector<LL>& b, vector<LL>& c) {
    int n = a.size(), m = b.size();
    for (int i = 0;i < n;++i) f[i] = a[i];
    for (int i = 0;i < m;++i) g[i] = b[i];
    Poly::Conv(f, n, g, m, f);
    c.clear(); c.resize(n + m - 1);
    for (int i = 0;i < n + m - 1;++i) c[i] = f[i];
}

LL solve(int n) {
    LL ans = 0;
    for (int k = 0;k <= w;++k) {
        LL k_inv = qpow(k, MOD - 2, MOD);
        LL kk = qpow(k, n, MOD);
        for (int i = 0;i < dp[w & 1][k].size() && i <= n;++i) {
            ans = (ans + dp[w & 1][k][i] * finv[n - i] % MOD * kk % MOD) % MOD;
            kk = kk * k_inv % MOD;
        }
    }
    ans = ans * fact[n] % MOD;
    return ans;
}

vector<LL> vec;

int main() {
    init();
    Read(w);
    for (int i = 1;i <= w;++i)
        Read(c[i]);
    dp[0][0].push_back(1);
    for (int i = 1;i <= w;++i) {
        for (int j = 0;j <= i;++j) {
            dp[i & 1][j].clear();
            vec.clear(); vec.resize(c[i]);
            for (int k = 0;k < c[i];++k)
                vec[k] = MOD - finv[k];
            conv(dp[(i & 1) ^ 1][j], vec, dp[i & 1][j]);
            if (j == 0) continue;
            add(dp[i & 1][j], dp[(i & 1) ^ 1][j - 1], vec);
            dp[i & 1][j] = vec;
        }
    }
    int s = 0;
    for (int i = 1;i <= w;++i) s += c[i];
    Read(q);
    while (q--) {
        int n; Read(n);
        if (n < s) printf("0\n");
        else printf("%lld\n", solve(n));
    }

    return 0;
}
posted @ 2022-07-31 17:54  AE酱  阅读(48)  评论(0编辑  收藏  举报