【UNR #2】梦中的题面 题解
前言
题目链接:UOJ。
一道好题。
题意简述
给定 \(n,m,b,c\),求满足下列条件的 \(m\) 元组 \((x_1,\ldots,x_m)\) 的个数模 \(998244353\)。
- \(x_i\in \mathbb{Z}\);
- \(0\le x_i\le b^i-c\);
- \(\sum x_i < n\)。
 
数据范围:\(-10^8\le c< b,\ 2\le b< 10^8,\ 1\le m\le 80,\ 1\le n\le b^{m+1}\)。
题目分析
在 \(c=1\) 时,有一个显然的数位 DP 能够做到 \(\mathcal{O}(m^3+m^2b)\) 的时间和 \(\mathcal{O}(m^2+mb)\) 的空间,此略。
对于一般的 \(c\),直接数位 DP 不太可行:瓶颈在于计算 \(c\) 个 \([0,b)\) 中的数恰好拼出 \(x\) 的方案数。因此考虑容斥。
先把 \(n\gets n-1\),转换为 \(\sum x_i\leq n\)。
考虑这样一个问题:\(x_i\ge0,\sum_{i=1}^m x_i\leq n\),方案数为 \(\binom{n+m}{m}\)。
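枚举 \(S\subseteq\{1,\ldots,m\}\),钦定 \(S\) 中的下标都不满足上界限制,即 \(x_i\ge b^i-(c-1)\)(\(i\in S\)),容斥系数为 \((-1)^{|S|}\)。令 \(x_i\gets x_i-\left(b^i-(c-1)\right)\)(\(i\in S\)),设 \(n'=n+(c-1)|S|-\sum_{i\in S}b^i\),剩下就是 \(x_i\ge 0,\ \sum x_i\le n'\) 的问题,方案数为 \(\binom{n'+m}{m}=\binom{n+(c-1)|S|-\sum_{i\in S}b^i+m}{m}\)(当 \(n'<0\) 时方案数为 \(0\))。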
考虑枚举 \(x=|S|\),此时 \(X=n+(c-1)x+m\) 为常数,需要求 \((-1)^x\sum_{|S|=x}\binom{X-\sum_{i\in S}b^i}{m}\),其中只需计入 \(\sum_{i\in S}b^i\le X\) 的 \(S\)(否则 \(n'<0\),方案数为 \(0\))。记 \(s=\sum_{i\in S}b^i\),当 \(X\ge s\) 时,\(\binom{X-s}{m}=\frac{1}{m!}\prod_{j=0}^{m-1}(X-s-j)\) 是关于 \(s\) 的 \(m\) 次多项式(当 \(0\le X-s<m\) 时其值恰为 \(0\),与实际方案数一致),记为 \(A_x(s)=\sum_{k=0}^{m}A_{x,k}s^k\),它只和 \(X,m\) 有关,可以直接展开求出。于是所求即为 \((-1)^x\sum_{|S|=x}A_x\left(\sum_{i\in S}b^i\right)=(-1)^x\sum_{k=0}^{m}A_{x,k}\sum_{|S|=x}\left(\sum_{i\in S}b^i\right)^k\)。
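问题即为:对每个 \(k\in[0,m]\),求所有满足 \(|S|=x\) 且 \(\sum_{i\in S}b^i\le X\) 的 \(S\) 的 \(\left(\sum_{i\in S}b^i\right)^k\) 之和。
把 \(X\) 写成 \(b\) 进制,从高位到低位做类似的数位 DP 即可;在第 \(i\) 位选入 \(b^i\) 时,用二项式定理把各次幂之和从 \(s\) 转移到 \(s+b^i\)。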
时间复杂度可以做到 \(\mathcal{O}(m^4)\)。
代码
我相信上面的分析已经足够详细了,如果有没有看懂的地方欢迎评论。
我本人非常讨厌题解给出一坨很丑的代码,因此虽然给出的代码比较长,但是结构十分清晰,易于阅读、理解。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
// Capacity of every per-index table (indices 0..m+1 with m <= 80).
constexpr int M = 85;
// Size of the buffer holding the decimal digits of n plus the terminating NUL.
constexpr int SN = 650;
// Prime modulus of the answer.
constexpr int mod = 998244353;
// Modular addition: returns (a + b) mod p for a, b already in [0, mod).
inline int add(int a, int b) {
    const int s = a + b;
    return s < mod ? s : s - mod;
}
// Modular subtraction: returns (a - b) mod p for a, b already in [0, mod).
inline int sub(int a, int b) {
    const int d = a - b;
    return d >= 0 ? d : d + mod;
}
// Modular multiplication: returns a * b mod p.
// Both arguments are expected in [0, mod): a negative int is converted to a
// huge unsigned value by the cast and would give a wrong residue.
inline int mul(int a, int b) {
    const unsigned long long prod = static_cast<unsigned long long>(a) * b;
    return static_cast<int>(prod % mod);
}
// Problem parameters, read in main().
int m;
int b;
int c;
// NOTE(review): this global digit array appears unused in this program
// (BINT keeps n as a string and yzh uses nX); kept for compatibility.
int n[M];
// pwb[i] = b^i mod p, filled by init().
int pwb[M];
// Inv[i] = i^{-1}, fac[i] = i!, ifac[i] = (i!)^{-1}, all mod p, for i <= m.
int Inv[M];
int fac[M];
int ifac[M];
// Fills, for k = 0..m: Inv[k] = k^{-1}, fac[k] = k!, ifac[k] = (k!)^{-1}
// (all mod p), and pwb[k] = b^k mod p.
void init() {
    fac[0] = ifac[0] = pwb[0] = 1;
    Inv[1] = 1;
    for (int k = 1; k <= m; ++k) {
        if (k >= 2)
            Inv[k] = mul(mod - mod / k, Inv[mod % k]);  // inv(k) = -(p / k) * inv(p mod k)
        fac[k] = mul(fac[k - 1], k);
        ifac[k] = mul(ifac[k - 1], Inv[k]);
        // Self-check of the tables; a bare `throw;` with no active exception
        // calls std::terminate.
        if (mul(fac[k], ifac[k]) != 1) throw;
        pwb[k] = mul(pwb[k - 1], b);
    }
}
// Binomial coefficient C(n, m) mod p for 0 <= m <= n; n must not exceed the
// size of the factorial tables built by init().
inline int CC(int n, int m) {
    const int denom_inv = mul(ifac[m], ifac[n - m]);
    return mul(fac[n], denom_inv);
}
// Minimal arbitrary-precision helpers for non-negative integers.
// A number is a decimal string stored least-significant digit first
// (e.g. 120 is "021"); every routine strips high-order zeros from its result.
namespace BINT {
// Returns a * b for a non-negative scalar b.
string mul(const string &a, long long b) {
    string c;
    int n = a.length();
    long long jin = 0;  // carry into the next digit
    for (int i = 0; i < n; ++i) {
        long long cur = (a[i] - '0') * b + jin;
        c += cur % 10 + '0';
        jin = cur / 10;
    }
    while (jin) c += jin % 10 + '0', jin /= 10;
    while (c.length() > 1u && c.back() == '0') c.pop_back();
    return c;
}
// Reverses the digit order (converts between normal decimal text and BINT order).
string O(string s) {
    reverse(s.begin(), s.end());
    return s;
}
// Three-way comparison of two normalized values: 1 if a > b, -1 if a < b, 0 if equal.
int cmp(const string &a, const string &b) {
    int n = a.length();
    int m = b.length();
    if (n > m) return 1;
    if (n < m) return -1;
    for (int i = n - 1; ~i; --i) {
        if (a[i] > b[i]) return 1;
        if (a[i] < b[i]) return -1;
    }
    return 0;
}
// Returns a - b. Precondition: a >= b.
string sub(const string &a, const string &b) {
    int n = a.length();
    int m = b.length();
    int jie = 0;  // borrow from the next digit
    string c;
    for (int i = 0; i < n; ++i) {
        int x = a[i] - '0' - jie;
        if (i < m) x -= b[i] - '0';
        if (x < 0) jie = 1, x += 10;
        else jie = 0;
        c += x + '0';
    }
    while (c.length() > 1u && c.back() == '0') c.pop_back();
    return c;
}
// Returns a + b.
string add(const string &a, const string &b) {
    int n = a.length();
    int m = b.length();
    int jin = 0;  // carry into the next digit
    string c;
    for (int i = 0; i < max(n, m) || jin; ++i) {
        int x = jin;
        if (i < n) x += a[i] - '0';
        if (i < m) x += b[i] - '0';
        if (x >= 10) x -= 10, jin = 1;
        else jin = 0;
        c += x + '0';
    }
    while (c.length() > 1u && c.back() == '0') c.pop_back();
    return c;
}
// Converts a non-negative integer to BINT order.
string itos(long long x) {
    return O(to_string(x));
}
// Returns x mod p (Horner's rule starting from the most significant digit).
int tomodint(const string &x) {
    int res = 0;
    int n = x.length();
    for (int i = n - 1; ~i; --i) {
        res = ::add(::mul(res, 10), x[i] - '0');
    }
    return res;
}
// Returns a + b for a non-negative scalar b.
string add(const string &a, long long b) {
    int n = a.length();
    string c;
    for (int i = 0; i < n || b; ++i) {
        if (i < n) b += a[i] - '0';
        c += b % 10 + '0';
        b /= 10;
    }
    while (c.length() > 1u && c.back() == '0') c.pop_back();
    return c;
}
// Returns a - b for a scalar b with 0 <= b <= a (callers must check a >= b).
string sub(const string &a, long long b) {
    int n = a.length();
    string c;
    int jie = 0;  // borrow from the next digit
    for (int i = 0; i < n; ++i) {
        int cur = a[i] - '0' - b % 10 - jie;
        b /= 10;
        // Borrow only when the digit went negative; a zero digit is valid.
        // (Testing `cur <= 0` here would turn every 0 digit into '0' + 10.)
        if (cur < 0) jie = 1, cur += 10;
        else jie = 0;
        c += cur + '0';
    }
    while (c.length() > 1u && c.back() == '0') c.pop_back();
    // Because a >= b, every digit of b has been consumed at this point (b == 0).
    return c;
}
char sn[SN];      // decimal text of n as read by scanf in main()
string y[M], n;   // y[i] = b^i for i = 0..m+1; n = (input n) - 1, both in BINT order
// Splits x into base-b digits n[0..m]: n[i] is the largest d in [0, b) with
// d * b^i <= (what remains of x), found by binary search from the top digit down.
// If x >= b^{m+1}, the digits saturate at b - 1.
void chai(int n[], string x) {
    for (int i = m; ~i; --i) {
        int l = 0, r = b - 1, ans = -1;
        while (l <= r) {
            int mid = (l + r) >> 1;
            string ty = mul(y[i], mid);
            if (cmp(x, ty) != -1)
                ans = mid, l = mid + 1;
            else
                r = mid - 1;
        }
        // Sanity checks; a bare `throw;` with no active exception calls std::terminate.
        if (ans == -1) throw;
        string ty = mul(y[i], ans);
        if (cmp(x, ty) == -1) throw;
        x = sub(x, ty);
        n[i] = ans;
    }
}
// Builds y[] and loads n from sn, then replaces n by n - 1 so that the
// condition "sum < n" becomes "sum <= n". Prints 0 and exits if n == 0.
void init() {
    y[0] = itos(1);
    for (int i = 1; i <= m + 1; ++i) {
        y[i] = mul(y[i - 1], b);
    }
    n = string(BINT::sn);
    n = O(n);
    if (n == "0") {
        puts("0");
        exit(0);
    }
    n = sub(n, itos(1));
}
}
// Inclusion-exclusion driver. For each x = |S| it sums C(X - s, m) over the
// subsets S of {1..m} with s = sum_{i in S} b^i <= X, where X = n + (c - 1) x + m.
namespace yzh {
int nX[M];  // base-b digits of X at positions 0..m (filled by BINT::chai)
int A[M];   // C(X - s, m) = sum_{k=0}^{m} A[k] * s^k  (mod p)
string X;   // current X in BINT order (least-significant digit first)

// Expands C(X - s, m) = (1 / m!) * prod_{i=1}^{m} ((X - i + 1) - s) into A[0..m].
void getpoly() {
    for (int i = 0; i <= m; ++i) {
        A[i] = 0;
    }
    int X_mod = BINT::tomodint(X);
    A[0] = 1;
    for (int i = 1; i <= m; ++i) {
        // Multiply the current polynomial by (t - s) with t = X - i + 1 (mod p).
        // X_mod - i + 1 is negative whenever X mod p < i - 1, and mul() casts its
        // arguments to unsigned, so reduce t into [0, mod) first.
        const int t = sub(X_mod, i - 1);
        for (int j = i; j >= 0; --j) {
            A[j] = mul(A[j], t);
            if (j) A[j] = add(A[j], mul(A[j - 1], mod - 1));
        }
    }
    for (int i = 0; i <= m; ++i) {
        A[i] = mul(A[i], ifac[m]);
    }
}
// Vector of power sums: v[k] = sum over a family of values s of s^k (k = 0..m).
struct node {
    int v[M];
    void clear() {
        memset(v, 0x00, sizeof(v));
    }
    int &operator[](int x) { return v[x]; }
    int const &operator[](int x) const { return v[x]; }
    // Component-wise sum (union of two disjoint families).
    friend inline node operator+(const node &a, const node &b) {
        node c;
        c.clear();
        for (int i = 0; i <= m; ++i)
            c[i] = ::add(a[i], b[i]);
        return c;
    }
    // Shifts every value s in the family by x:
    // result[i] = sum (s + x)^i = sum_j C(i, j) x^j v[i - j]  (binomial theorem).
    inline node add(int x) const {
        node b;
        b.clear();
        for (int i = 0; i <= m; ++i) {
            int pw = 1;
            for (int j = 0; j <= i; ++j) {
                b[i] = ::add(b[i], mul(mul(CC(i, j), pw), v[i - j]));
                pw = mul(pw, x);
            }
        }
        return b;
    }
};
node f[M][M];    // memo for dfs(i, x, false); does not depend on X
bool vis[M][M];  // vis[i][x]: f[i][x] has been computed

// Digit DP over positions i..1, deciding whether b^i belongs to S.
// Returns the power sums of s = sum of chosen b^j over all ways to choose exactly
// x of the positions 1..i. When lim is true, the digits chosen above position i
// equal those of X, and the choices are restricted so the total stays <= X.
// Position 0 never receives a chosen digit, so it imposes no constraint.
node dfs(int i, int x, bool lim) {
    node y;
    y.clear();
    if (i == 0) {
        if (x == 0)
            y[0] = 1;
        return y;
    }
    if (!lim && vis[i][x]) return f[i][x];
    for (int o = 0; o <= 1; ++o) {
        if (lim && o == 1 && nX[i] == 0) continue;  // would exceed X
        if (x == 0 && o) continue;                  // no picks left
        node const &nxt = dfs(i - 1, x - o, lim && o == nX[i]);
        if (o)
            y = y + nxt.add(pwb[i]);
        else
            y = y + nxt;
    }
    if (!lim) vis[i][x] = true, f[i][x] = y;
    return y;
}
// Returns sum over |S| = x with sum_{i in S} b^i <= X of C(X - sum, m)  (mod p).
int calc(int x) {
    long long delta = 1ll * (c - 1) * x + m;
    if (delta >= 0) {
        X = BINT::add(BINT::n, delta);
    } else {
        // delta < 0 is only possible for x >= 1 (x = 0 gives delta = m).
        // If n + delta < 0 then X < 0; every such S has a positive sum, so no S
        // fits and the term is 0. BINT::sub(string, long long) also requires n >= -delta.
        if (BINT::cmp(BINT::n, BINT::itos(-delta)) < 0) return 0;
        X = BINT::sub(BINT::n, -delta);
    }
    BINT::chai(nX, X);
    getpoly();
    auto res = dfs(m, x, true);
    int ans = 0;
    for (int i = 0; i <= m; ++i) {
        ans = add(ans, mul(res[i], A[i]));
    }
    return ans;
}
// Alternating sum over x = |S| = 0..m; prints the final answer.
void solve() {
    int ans = 0;
    for (int x = 0; x <= m; ++x) {
        int res = calc(x);
        if (x & 1)
            ans = sub(ans, res);
        else
            ans = add(ans, res);
    }
    printf("%d", ans);
}
}
int main() {
    // Input order: m, b, c, then n as a decimal string.
    scanf("%d%d%d%s", &m, &b, &c, BINT::sn);
    init();        // modular factorial tables and b^i mod p
    BINT::init();  // big-integer powers of b; n <- n - 1 (prints 0 and exits if n == 0)
    yzh::solve();  // inclusion-exclusion over |S|; prints the answer
}
如果你对 \(c=1\) 的部分分有兴趣的话(注意 UOJ 上 \(b\) 的范围的不同):
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
// Capacity of the global per-index tables.
constexpr int M = 85;
// Size of the buffer holding the decimal digits of n plus the terminating NUL.
constexpr int SN = 650;
// Prime modulus of the answer.
constexpr int mod = 998244353;
// Returns (a + b) mod p for a, b already in [0, mod).
inline int add(int a, int b) {
    int s = a + b;
    if (s >= mod) s -= mod;
    return s;
}
// Returns (a - b) mod p for a, b already in [0, mod).
inline int sub(int a, int b) {
    int d = a - b;
    if (d < 0) d += mod;
    return d;
}
// Returns a * b mod p; both arguments are expected in [0, mod).
inline int mul(int a, int b) {
    const unsigned long long prod = static_cast<unsigned long long>(a) * b;
    return static_cast<int>(prod % mod);
}
// In-place modular accumulate: a <- (a + b) mod p, for a, b in [0, mod).
inline void O(int &a, int b) {
    a += b;
    if (a >= mod) a -= mod;
}
// Problem parameters, read in main().
int m;
int b;
int c;
// n[i]: base-b digit i of (input n) - 1, for i = 0..m, filled in main().
int n[M];
// Small non-negative big-integer helpers. Values are decimal strings stored
// least-significant digit first; results have high-order zeros stripped.
namespace BINT {
// Returns a * b for a non-negative int factor b.
string mul(const string &a, int b) {
    string out;
    long long carry = 0;
    for (char d : a) {
        long long cur = (d - '0') * b + carry;
        out.push_back(char('0' + cur % 10));
        carry = cur / 10;
    }
    for (; carry; carry /= 10) out.push_back(char('0' + carry % 10));
    while (out.size() > 1 && out.back() == '0') out.pop_back();
    return out;
}
// Reverses digit order (decimal text <-> BINT order).
string O(string s) {
    return string(s.rbegin(), s.rend());
}
// Three-way comparison of normalized values: 1 if a > b, -1 if a < b, else 0.
int cmp(const string &a, const string &b) {
    if (a.length() != b.length()) return a.length() > b.length() ? 1 : -1;
    for (int i = (int)a.length() - 1; i >= 0; --i) {
        if (a[i] != b[i]) return a[i] > b[i] ? 1 : -1;
    }
    return 0;
}
// Returns a - b; requires a >= b.
string sub(const string &a, const string &b) {
    string out;
    int borrow = 0;
    for (size_t i = 0; i < a.size(); ++i) {
        int digit = a[i] - '0' - borrow - (i < b.size() ? b[i] - '0' : 0);
        borrow = digit < 0 ? 1 : 0;
        if (borrow) digit += 10;
        out.push_back(char('0' + digit));
    }
    while (out.size() > 1 && out.back() == '0') out.pop_back();
    return out;
}
// Returns a + b.
string add(const string &a, const string &b) {
    string out;
    int carry = 0;
    const size_t len = max(a.size(), b.size());
    for (size_t i = 0; i < len || carry; ++i) {
        int digit = carry;
        if (i < a.size()) digit += a[i] - '0';
        if (i < b.size()) digit += b[i] - '0';
        carry = digit >= 10 ? 1 : 0;
        out.push_back(char('0' + digit % 10));
    }
    while (out.size() > 1 && out.back() == '0') out.pop_back();
    return out;
}
// Converts an int to BINT digit order.
string itos(int x) {
    string s = to_string(x);
    reverse(s.begin(), s.end());
    return s;
}
char sn[SN];   // decimal text of n as read by scanf in main()
string y[M];   // y[i] = b^i, filled in main()
}
// Partial solution for c = 1 using a direct digit DP with fixed-size tables.
namespace sub1 {
// Table capacities: f needs rows 0..m+1 (so m + 1 < M); dp needs indices 0..m*b.
const int M = 18, B = 55 + 10;
// Returns whether this partial solution applies: c == 1 and the tables below are
// large enough for the given m and b. The former test `b <= 3000000` alone let
// solve() index dp up to m * b and f at m + 1, far past their fixed sizes.
// (These constants are declared above so that M here names sub1::M, not ::M.)
bool check() {
    return c == 1 && m + 1 < M && 1LL * m * b < 1LL * M * B;
}
// dp[ct & 1][j]: number of ways ct digits, each in [0, b), sum to at most j
// (kept as prefix sums; two rows rolled by parity of ct).
int dp[2][M * B];
// f[i][j]: number of ways to fix the digits at positions >= i so that the budget
// left for lower positions is j units of b^i; budgets >= m are merged into j = m,
// since the lower positions can never use that much.
int f[M][M];
void solve() {
    dp[0][0] = 1;
    f[m + 1][0] = 1;  // nothing above position m: n - 1 < b^{m+1}
    for (int i = m; ~i; --i) {
        // Variables x_{i+1..m} each have a digit at position i.
        int ct = m - i;
        for (int j = 0; j <= m * b; ++j) {
            if (ct > 0) {
                // Counts for ct digits = window of width b over the (ct-1)-digit prefix sums.
                dp[ct & 1][j] = sub(dp[~ct & 1][j], j >= b ? dp[~ct & 1][j - b] : 0);
            }
            if (j) {
                O(dp[ct & 1][j], dp[ct & 1][j - 1]);  // turn counts into prefix sums
            }
        }
        for (int j = 0; j <= m; ++j) {
            // j * b + n[i] - d <= m
            // j * b + n[i] - m <= d
            int x = j * b + n[i] - m;
            // d = [max(0, x), j * b + n[i]]: the new budget stays <= m
            for (int d = max(0, x); d <= j * b + n[i] && d <= ct * (b - 1); ++d)
                O(f[i][j * b + n[i] - d], mul(f[i + 1][j], sub(dp[ct & 1][d], d ? dp[ct & 1][d - 1] : 0)));
            // d = [0, x - 1]: the new budget exceeds m and is merged into j = m
            if (x >= 1) {
                O(f[i][m], mul(f[i + 1][j], dp[ct & 1][min(x - 1, ct * (b - 1))]));
            }
        }
    }
    int ans = 0;
    for (int j = 0; j <= m; ++j)
        ans = add(ans, f[0][j]);
    printf("%d", ans);
}
}
int main() {
    scanf("%d%d%d%s", &m, &b, &c, BINT::sn);
    {
        using namespace BINT;
        y[0] = string("1");
        for (int i = 1; i <= m + 1; ++i) {
            y[i] = mul(y[i - 1], b);
        }
        string x(sn);
        x = O(x);
        x = sub(x, itos(1));
        for (int i = m; ~i; --i) {
            int l = 0, r = b - 1, ans = -1;
            while (l <= r) {
                int mid = (l + r) >> 1;
                string ty = mul(y[i], mid);
                if (cmp(x, ty) != -1)
                    ans = mid, l = mid + 1;
                else
                    r = mid - 1;
            }
            if (ans == -1) throw;
            string ty = mul(y[i], ans);
            if (cmp(x, ty) == -1) throw;
            x = sub(x, ty);
            n[i] = ans;
        }
    }
    
    if (sub1::check()) return sub1::solve(), 0;
    return 0;
}
本文作者:XuYueming,转载请注明原文链接:https://www.cnblogs.com/XuYueming/p/19043656。
若未作特殊说明,本作品采用 知识共享署名-非商业性使用 4.0 国际许可协议 进行许可。

                
            
        
浙公网安备 33010602011771号