【UNR #2】梦中的题面 题解
前言
题目链接:UOJ。
一道好题。
题意简述
给定 \(n,m,b,c\),求满足下列条件的 \(m\) 元组 \((x_1,\ldots,x_m)\) 的个数模 \(998244353\)。
- \(x_i\in \mathbb{Z}\);
- \(0\le x_i\le b^i-c\);
- \(\sum x_i < n\);
\(-10^8\le c< b,\ 2\le b< 10^8,\ 1\le m\le 80,\ 1\le n\le b^{m+1}\)。
题目分析
在 \(c=1\) 时,有一个显然的数位 DP,能够做到 \(\mathcal{O}(m^3+m^2b)\) 的时间和 \(\mathcal{O}(m^2+mb)\) 的空间,此处从略(部分分代码见文末)。
直接数位 DP 不太可能了,瓶颈在于计算 \(c\) 个 \([0,b)\) 数恰好拼出 \(x\) 的方案数。因此考虑容斥。
先把 \(n\gets n-1\),转换为 \(\sum x_i\leq n\)。
考虑这样一个问题:\(x_i\ge0,\sum_{i=1}^m x_i\leq n\),方案数为 \(\binom{n+m}{m}\)。
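枚举 \(S\subseteq\{1,\ldots,m\}\),钦定 \(S\) 中的元素不满足上界 \(x_i\le b^i-c\),即 \(x_i\ge b^i-(c-1)\),容斥系数为 \((-1)^{|S|}\)。对 \(i\in S\) 令 \(x_i\gets x_i-(b^i-(c-1))\),并设 \(n'=n+(c-1)|S|-\sum_{i\in S}b^i\),剩下就是 \(x_i\ge 0,\ \sum x_i\le n'\) 的问题,方案数为 \(\binom{n'+m}{m}=\binom{n+(c-1)|S|-\sum_{i\in S}b^i+m}{m}\)(当 \(n'<0\) 时方案数为 \(0\))。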
考虑枚举 \(x=|S|\),设 \(X=n+(c-1)x+m\),它只与 \(x\) 有关。那么需要求 \((-1)^x\sum_{|S|=x}\binom{X-\sum_{i\in S}b^i}{m}\)。若 \(X<\sum_{i\in S}b^i\),则 \(n'<-m\),方案数为 \(0\),因此只需考虑 \(X\ge\sum_{i\in S}b^i\) 的 \(S\)。此时 \(\binom{X-\sum_{i\in S}b^i}{m}\) 是关于 \(s=\sum_{i\in S}b^i\) 的 \(m\) 次多项式 \(A_x(s)=\sum_{k=0}^{m}A_{x,k}s^k\),其系数只和 \(X,m\) 有关,可以直接展开求出。于是所求即为 \((-1)^x\sum_{|S|=x}A_x\left(\sum_{i\in S}b^i\right)=(-1)^x\sum_{k=0}^{m}A_{x,k}\sum_{|S|=x}\left(\sum_{i\in S}b^i\right)^k\),其中对 \(S\) 的求和只取 \(\sum_{i\in S}b^i\le X\) 的 \(S\)。
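问题转化为:对每个 \(0\le k\le m\),求所有满足 \(|S|=x\) 且 \(\sum_{i\in S}b^i\le X\) 的 \(S\) 的 \(\left(\sum_{i\in S}b^i\right)^k\) 之和。
用高精度把 \(X\) 算出来并转成 \(b\) 进制。由于 \(\sum_{i\in S}b^i\) 在 \(b\) 进制下第 \(i\) 位恰为 \([i\in S]\),可以从高位到低位做类似数位 DP。状态中同时维护 \(0\sim m\) 次幂和,加入 \(b^i\) 时用二项式定理 \((s+b^i)^k=\sum_{j}\binom{k}{j}b^{ij}s^{k-j}\) 转移即可。注意若 \(X<0\),则不存在合法的 \(S\),该 \(x\) 的贡献为 \(0\)。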
时间复杂度可以做到 \(\mathcal{O}(m^4)\)。
代码
我相信上面的分析已经足够详细了,如果有没有看懂的地方欢迎评论。
我本人非常讨厌题解给出一坨很丑的代码,因此虽然给出的代码比较长,但是结构十分清晰,易于阅读、理解。
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int M = 85;            // array capacity: indices 0..m+1 with m <= 80
const int SN = 650;          // text buffer for n <= b^(m+1) < 10^648
const int mod = 998244353;

// (a + b) mod `mod`; both operands are expected to lie in [0, mod).
inline int add(int a, int b) {
    const int s = a + b;
    return s >= mod ? s - mod : s;
}
// (a - b) mod `mod`; both operands are expected to lie in [0, mod).
inline int sub(int a, int b) {
    const int d = a - b;
    return d < 0 ? d + mod : d;
}
// (a * b) mod `mod`, computed through a 64-bit unsigned product.
inline int mul(int a, int b) {
    return static_cast<int>(static_cast<unsigned long long>(a) * b % mod);
}
int m, b, c;   // problem parameters, read in main()
int n[M];      // NOTE(review): appears unused in this program; BINT::n holds n
int pwb[M];    // pwb[i] = b^i mod `mod`, filled by init()
int Inv[M];    // Inv[i] = i^{-1} mod `mod`, filled by init()
int fac[M];    // fac[i] = i! mod `mod`
int ifac[M];   // ifac[i] = (i!)^{-1} mod `mod`
// Fills the modular tables for indices 0..m (reads the globals m and b):
//   Inv[k] = k^{-1} via Inv[k] = -(mod / k) * Inv[mod % k],
//   fac[k] = k!, ifac[k] = (k!)^{-1}, pwb[k] = b^k, all modulo `mod`.
// A failed self-check executes a bare `throw;`; with no exception in flight
// that calls std::terminate.
void init() {
    fac[0] = ifac[0] = pwb[0] = 1;
    Inv[1] = 1;
    for (int k = 1; k <= m; ++k) {
        // mod % k < k, so its inverse is already available here.
        if (k >= 2) Inv[k] = mul(mod - mod / k, Inv[mod % k]);
        fac[k] = mul(fac[k - 1], k);
        ifac[k] = mul(ifac[k - 1], Inv[k]);
        if (mul(fac[k], ifac[k]) != 1) throw;
        pwb[k] = mul(pwb[k - 1], b);
    }
}
// Binomial coefficient C(n, m) modulo `mod` from the factorial tables built by
// init(); requires 0 <= m <= n and n within the filled table range.
inline int CC(int n, int m) {
    const int inv_denominator = mul(ifac[m], ifac[n - m]);
    return mul(fac[n], inv_denominator);
}
namespace BINT {
// Arbitrary-precision non-negative integers stored as decimal strings with the
// LEAST significant digit first ("021" means 120). Every function that builds
// a result trims high-order zeros, so "0" is the only encoding of zero and
// cmp() can compare lengths first.

// Returns a * b for a non-negative multiplier b.
string mul(const string &a, long long b) {
    string c;
    int n = a.length();
    long long jin = 0;  // carry into the next digit
    for (int i = 0; i < n; ++i) {
        long long cur = (a[i] - '0') * b + jin;
        c += cur % 10 + '0';
        jin = cur / 10;
    }
    while (jin) c += jin % 10 + '0', jin /= 10;
    while (c.length() > 1u && c.back() == '0') c.pop_back();
    return c;
}
// Reverses s: converts between ordinary (most-significant-first) text and the
// least-significant-first storage form.
string O(string s) {
    reverse(s.begin(), s.end());
    return s;
}
// Three-way comparison of trimmed numbers: 1 if a > b, -1 if a < b, 0 if equal.
int cmp(const string &a, const string &b) {
    int n = a.length();
    int m = b.length();
    if (n > m) return 1;
    if (n < m) return -1;
    for (int i = n - 1; ~i; --i) {  // most significant digit is at the back
        if (a[i] > b[i]) return 1;
        if (a[i] < b[i]) return -1;
    }
    return 0;
}
// Returns a - b; requires a >= b.
string sub(const string &a, const string &b) {
    // assume a >= b
    int n = a.length();
    int m = b.length();
    int jie = 0;  // borrow
    string c;
    for (int i = 0; i < n; ++i) {
        int x = a[i] - '0' - jie;
        if (i < m) x -= b[i] - '0';
        if (x < 0) jie = 1, x += 10;
        else jie = 0;
        c += x + '0';
    }
    while (c.length() > 1u && c.back() == '0') c.pop_back();
    return c;
}
// Returns a + b.
string add(const string &a, const string &b) {
    int n = a.length();
    int m = b.length();
    int jin = 0;  // carry
    string c;
    for (int i = 0; i < max(n, m) || jin; ++i) {
        int x = jin;
        if (i < n) x += a[i] - '0';
        if (i < m) x += b[i] - '0';
        if (x >= 10) x -= 10, jin = 1;
        else jin = 0;
        c += x + '0';
    }
    while (c.length() > 1u && c.back() == '0') c.pop_back();
    return c;
}
// Converts a non-negative x to storage form.
string itos(long long x) {
    return O(to_string(x));
}
// Returns x modulo `mod` (Horner's rule from the most significant digit).
int tomodint(const string &x) {
    int res = 0;
    int n = x.length();
    for (int i = n - 1; ~i; --i) {
        res = ::add(::mul(res, 10), x[i] - '0');
    }
    return res;
}
// Returns a + b for a non-negative b.
string add(const string &a, long long b) {
    int n = a.length();
    string c;
    for (int i = 0; i < n || b; ++i) {
        if (i < n) b += a[i] - '0';  // b doubles as the running carry
        c += b % 10 + '0';
        b /= 10;
    }
    while (c.length() > 1u && c.back() == '0') c.pop_back();
    return c;
}
// Returns a - b for a non-negative b. Requires a >= b: otherwise the digits of
// b beyond a's length and the final borrow are silently dropped, so callers
// must check (e.g. cmp(a, itos(b)) >= 0) first.
string sub(const string &a, long long b) {
    int n = a.length();
    string c;
    int jie = 0;  // borrow
    for (int i = 0; i < n; ++i) {
        int cur = a[i] - '0' - b % 10 - jie;
        b /= 10;
        // Borrow only for a strictly negative digit; a zero digit is final.
        if (cur < 0) jie = 1, cur += 10;
        else jie = 0;
        c += cur + '0';
    }
    while (c.length() > 1u && c.back() == '0') c.pop_back();
    // b equals to 0 here when a >= b
    return c;
}
char sn[SN];     // raw input text of n (most significant digit first)
string y[M], n;  // y[i] = b^i for i = 0..m+1; n = (input n) - 1 after init()
// Writes the base-b digits of x into n[0..m] (n[i] is the digit of b^i). Each
// digit is the largest value in [0, b-1] with y[i] * digit <= remaining x,
// found by binary search; if x >= b^(m+1) the digits saturate at b - 1.
// The bare `throw;` lines are sanity checks (std::terminate if reached).
void chai(int n[], string x) {
    for (int i = m; ~i; --i) {
        int l = 0, r = b - 1, ans = -1;
        while (l <= r) {
            int mid = (l + r) >> 1;
            string ty = mul(y[i], mid);
            if (cmp(x, ty) != -1)
                ans = mid, l = mid + 1;
            else
                r = mid - 1;
        }
        if (ans == -1) throw;
        string ty = mul(y[i], ans);
        if (cmp(x, ty) == -1) throw;
        x = sub(x, ty);
        n[i] = ans;
    }
}
// Builds y[0..m+1] = b^i and parses sn into n. Prints 0 and exits when n is 0;
// otherwise replaces n by n - 1 so the constraint becomes sum x_i <= n.
void init() {
    y[0] = itos(1);
    for (int i = 1; i <= m + 1; ++i) {
        y[i] = mul(y[i - 1], b);
    }
    n = string(BINT::sn);
    n = O(n);
    if (n == "0") {
        puts("0");
        exit(0);
    }
    n = sub(n, itos(1));
}
}
namespace yzh {
int nX[M];
int A[M];
string X;
void getpoly() {
for (int i = 0; i <= m; ++i) {
A[i] = 0;
}
int X_mod = BINT::tomodint(X);
A[0] = 1;
for (int i = 1; i <= m; ++i) {
// * (X_mod-i+1 - x)
for (int j = i; j >= 0; --j) {
A[j] = mul(A[j], X_mod - i + 1);
if (j) A[j] = add(A[j], mul(A[j - 1], mod - 1));
}
}
for (int i = 0; i <= m; ++i) {
A[i] = mul(A[i], ifac[m]);
}
}
struct node {
int v[M];
void clear() {
memset(v, 0x00, sizeof(v));
}
int &operator[](int x) { return v[x]; }
int const &operator[](int x) const { return v[x]; }
friend inline node operator+(const node &a, const node &b) {
node c;
c.clear();
for (int i = 0; i <= m; ++i)
c[i] = ::add(a[i], b[i]);
return c;
}
inline node add(int x) const {
node b;
b.clear();
for (int i = 0; i <= m; ++i) {
int pw = 1;
for (int j = 0; j <= i; ++j) {
b[i] = ::add(b[i], mul(mul(CC(i, j), pw), v[i - j]));
pw = mul(pw, x);
}
}
return b;
}
};
node f[M][M];
bool vis[M][M];
node dfs(int i, int x, bool lim) {
node y;
y.clear();
if (i == 0) {
if (x == 0)
y[0] = 1;
return y;
}
if (!lim && vis[i][x]) return f[i][x];
y.clear();
for (int o = 0; o <= 1; ++o) {
if (lim && o == 1 && nX[i] == 0) continue;
if (x == 0 && o) continue;
node const &nxt = dfs(i - 1, x - o, lim && o == nX[i]);
if (o)
y = y + nxt.add(pwb[i]);
else
y = y + nxt;
}
if (!lim) vis[i][x] = true, f[i][x] = y;
return y;
}
int calc(int x) {
long long delta = 1ll * (c - 1) * x + m;
if (delta >= 0)
X = BINT::add(BINT::n, delta);
else {
X = BINT::sub(BINT::n, -delta);
}
BINT::chai(nX, X);
getpoly();
auto res = dfs(m, x, true);
int ans = 0;
for (int i = 0; i <= m; ++i) {
int t = 0;
t = res[i];
t = mul(t, A[i]);
ans = add(ans, t);
}
return ans;
}
void solve() {
int ans = 0;
for (int x = 0; x <= m; ++x) {
int res = calc(x);
if (x & 1)
ans = sub(ans, res);
else
ans = add(ans, res);
}
printf("%d", ans);
}
}
int main() {
    // Input: m, b, c, then n as a decimal string (it can have hundreds of digits).
    scanf("%d%d%d%s", &m, &b, &c, BINT::sn);
    init();        // modular factorial tables and powers of b
    BINT::init();  // big-integer powers of b; n <- n - 1
    yzh::solve();  // inclusion-exclusion over |S|; prints the answer
}
如果你对 \(c=1\) 的部分分感兴趣,可以参考下面的代码(注意 UOJ 上 \(b\) 的范围与上文题意简述中的不同):
#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int M = 85;            // array capacity: indices 0..m+1 with m <= 80
const int SN = 650;          // text buffer for the decimal input n
const int mod = 998244353;

// (a + b) mod `mod`; operands expected in [0, mod).
inline int add(int a, int b) {
    a += b;
    if (a >= mod) a -= mod;
    return a;
}
// (a - b) mod `mod`; operands expected in [0, mod).
inline int sub(int a, int b) {
    a -= b;
    if (a < 0) a += mod;
    return a;
}
// (a * b) mod `mod`, via a 64-bit unsigned product.
inline int mul(int a, int b) {
    const unsigned long long prod =
        static_cast<unsigned long long>(a) * static_cast<unsigned long long>(b);
    return static_cast<int>(prod % mod);
}
// In-place accumulation: a = (a + b) mod `mod`.
inline void O(int &a, int b) {
    a += b;
    if (a >= mod) a -= mod;
}
int m, b, c;  // problem parameters, read in main()
int n[M];     // base-b digits of (input n - 1), filled in main(); n[i] is the digit of b^i
namespace BINT {
// Non-negative big integers as decimal strings, least significant digit first
// ("021" is 120). Results are trimmed so "0" is the unique zero.

// Removes high-order zero digits while keeping at least one digit.
inline void strip_zeros(string &s) {
    while (s.size() > 1 && s.back() == '0') s.pop_back();
}
// Returns a * b for a non-negative int multiplier b.
string mul(const string &a, int b) {
    string out;
    long long carry = 0;
    for (char ch : a) {
        const long long cur = (ch - '0') * b + carry;
        out.push_back(static_cast<char>('0' + cur % 10));
        carry = cur / 10;
    }
    for (; carry != 0; carry /= 10)
        out.push_back(static_cast<char>('0' + carry % 10));
    strip_zeros(out);
    return out;
}
// Reverses the digit order (text form <-> storage form).
string O(string s) {
    return string(s.rbegin(), s.rend());
}
// Three-way comparison: 1 if a > b, -1 if a < b, 0 if equal.
int cmp(const string &a, const string &b) {
    if (a.size() != b.size()) return a.size() > b.size() ? 1 : -1;
    for (size_t i = a.size(); i-- > 0;) {
        if (a[i] != b[i]) return a[i] > b[i] ? 1 : -1;
    }
    return 0;
}
// Returns a - b; requires a >= b.
string sub(const string &a, const string &b) {
    string out;
    int borrow = 0;
    for (size_t i = 0; i < a.size(); ++i) {
        int digit = a[i] - '0' - borrow - (i < b.size() ? b[i] - '0' : 0);
        borrow = digit < 0 ? 1 : 0;
        if (borrow) digit += 10;
        out.push_back(static_cast<char>('0' + digit));
    }
    strip_zeros(out);
    return out;
}
// Returns a + b.
string add(const string &a, const string &b) {
    string out;
    int carry = 0;
    for (size_t i = 0; i < a.size() || i < b.size() || carry; ++i) {
        int digit = carry;
        if (i < a.size()) digit += a[i] - '0';
        if (i < b.size()) digit += b[i] - '0';
        carry = digit / 10;
        out.push_back(static_cast<char>('0' + digit % 10));
    }
    strip_zeros(out);
    return out;
}
// Converts a non-negative int to storage form.
string itos(int x) {
    const string text = to_string(x);
    return string(text.rbegin(), text.rend());
}
char sn[SN];   // raw input text of n
string y[M];   // y[i] = b^i in storage form, filled in main()
}
namespace sub1 {
// Fixed table sizes for this partial solution (they shadow ::M inside sub1).
// Declared before check() so that check() validates against these values.
const int M = 18, B = 55 + 10;
// Whether this DP may run. It only models c == 1, and its fixed-size tables
// are indexed up to f[m + 1][m] and dp[*][m * b]; without the m + 1 < M and
// m * b < M * B bounds, larger inputs would write past the arrays.
bool check() {
    return c == 1 && b <= 3000000 && m + 1 < M && 1ll * m * b < 1ll * M * B;
}
// dp[ct & 1][d]: number of ways to choose ct values in [0, b - 1] with total
//   <= d (prefix sums; exact counts are dp[d] - dp[d - 1]).
// f[i][r]: DP over positions from m down to i; r appears to be the remaining
//   budget in units of b^i, capped at m (all larger budgets go to f[i][m]).
int dp[2][M * B];
int f[M][M];
void solve() {
    dp[0][0] = 1;
    f[m + 1][0] = 1;
    for (int i = m; ~i; --i) {
        // ct = m - i values contribute a digit at position i (presumably the
        // x_k with k > i, since x_k <= b^k - 1 when c == 1).
        int ct = m - i;
        for (int j = 0; j <= m * b; ++j) {
            if (ct > 0) {
                // add one more value in [0, b - 1]: window of width b over the
                // previous prefix sums
                dp[ct & 1][j] = sub(dp[~ct & 1][j], j >= b ? dp[~ct & 1][j - b] : 0);
            }
            if (j) {
                O(dp[ct & 1][j], dp[ct & 1][j - 1]);  // turn counts into prefix sums
            }
        }
        for (int j = 0; j <= m; ++j) {
            // j * b + n[i] - d <= m
            // j * b + n[i] - m <= d
            int x = j * b + n[i] - m;
            // d = [max(0, x), j * b + n[i]]: new budget stays below the cap
            for (int d = max(0, x); d <= j * b + n[i] && d <= ct * (b - 1); ++d)
                O(f[i][j * b + n[i] - d], mul(f[i + 1][j], sub(dp[ct & 1][d], d ? dp[ct & 1][d - 1] : 0)));
            // d = [0, x - 1]: new budget exceeds m, collapse into f[i][m]
            if (x >= 1) {
                O(f[i][m], mul(f[i + 1][j], dp[ct & 1][min(x - 1, ct * (b - 1))]));
            }
        }
    }
    int ans = 0;
    for (int j = 0; j <= m; ++j)
        ans = add(ans, f[0][j]);
    printf("%d", ans);
}
}
int main() {
scanf("%d%d%d%s", &m, &b, &c, BINT::sn);
{
using namespace BINT;
y[0] = string("1");
for (int i = 1; i <= m + 1; ++i) {
y[i] = mul(y[i - 1], b);
}
string x(sn);
x = O(x);
x = sub(x, itos(1));
for (int i = m; ~i; --i) {
int l = 0, r = b - 1, ans = -1;
while (l <= r) {
int mid = (l + r) >> 1;
string ty = mul(y[i], mid);
if (cmp(x, ty) != -1)
ans = mid, l = mid + 1;
else
r = mid - 1;
}
if (ans == -1) throw;
string ty = mul(y[i], ans);
if (cmp(x, ty) == -1) throw;
x = sub(x, ty);
n[i] = ans;
}
}
if (sub1::check()) return sub1::solve(), 0;
return 0;
}
本文作者:XuYueming,转载请注明原文链接:https://www.cnblogs.com/XuYueming/p/19043656。
若未作特殊说明,本作品采用 知识共享署名-非商业性使用 4.0 国际许可协议 进行许可。

浙公网安备 33010602011771号