题解:UOJ823 【UR #26】铁轨回收
题意
给定长度为 \(n\) 的序列 \(a,b\),按照 \(i=1\sim n-1\) 的顺序执行以下操作:
- 从 \([i+1,n]\) 中随机选择一个整数 \(j\),令 \(a_j\gets \min(a_j+a_i,b_j)\)。
对于每个 \(i\in [0,b_n]\),求最终 \(a_n=i\) 的概率,答案对 \(998244353\) 取模。\(1\leq n\leq 50\),\(0\leq a_i\leq b_i\leq 30\)。
题解
神仙题啊,太神仙了。
设最终得到的 \(a\) 序列为 \(c_{1\sim n}\)。
概率显然可以转计数,最后除以总方案数 \((n-1)!\) 即可。
考虑每次操作连一条 \(i\to j\) 的有向边,得到一棵以 \(n\) 为根的内向树。倒序考虑每个点 \(i\),考察其会带来什么限制。容易看出,若 \(c_i<b_i\) 则一定有 \(\sum\limits_{j\in son_i}c_j=c_i\)。而若 \(c_i=b_i\),此时只能说明 \(\sum\limits_{j\in son_i}c_j\geq b_i\),如果我们直接把 \(\sum\limits_{j\in son_i}c_j\) 放进 DP 状态里会比较炸。人类智慧地,我们考虑容斥,用无限制的总方案数,减去实际上 \(\sum\limits_{j\in son_i}c_j<b_i\) 的方案数。
对于有限制的点,我们只关心当前所有总和的限制构成的多重集 \(S\)(这里不记录 \(0\))。对于无限制的点,我们只关心这些点的个数 \(k\)。据此设计 DP 状态 \(f_{i,S,k}\),表示考虑了 \([i+1,n]\) 中的点,\(S,k\) 的意义如前文所述。
枚举点 \(i\) 的状态,可以得到以下 \(4\) 种转移:
- \(fa_i\) 为无限制的点:\(k\cdot f_{i,S,k}\to f_{i-1,S,k+1}\ (0\leq k\leq n-i)\)。
- \(fa_i\) 为有限制的点,且 \(c_i<b_i\)。我们枚举 \(c_i\) 和 \(x\in S\):\(f_{i,S,k}\to f_{i-1,(S-\{x\})\cup\{x-c_i,c_i-a_i\},k}\ (a_i\leq c_i<b_i)\)。
- \(fa_i\) 为有限制的点,且 \(c_i=b_i\),此处统计总方案数。我们枚举 \(x\in S\):\(f_{i,S,k}\to f_{i-1,(S-\{x\})\cup\{x-b_i\},k+1}\)。
- \(fa_i\) 为有限制的点,且 \(c_i=b_i\),此处容斥减去不合法的方案数。我们枚举实际的 \(c_i<b_i\) 和 \(x\in S\):\(-f_{i,S,k}\to f_{i-1,(S-\{x\})\cup\{x-b_i,c_i-a_i\}}\ (a_i\leq c_i<b_i)\)。
对于每个 \(a_n\leq c_n<b_n\),我们都跑一次 DP。初值为 \(f_{n-1,\{c_n-a_n\},0}=1\)(若 \(c_n=a_n\) 则 \(S=\varnothing\)),答案为 \(\sum\limits_{i=0}^{n-1}f_{0,\varnothing,i}\)。统计答案时:
- 若 \(c_n<a_n\),则答案为 \(0\)。
- 若 \(a_n\leq c_n<b_n\),则按照前文中的方式统计答案。
- 若 \(c_n=b_n\),则我们同样容斥,用 \((n-1)!\) 减去 \(ans_{a_n\sim b_n-1}\) 的和。
注意到 DP 过程中 \(S\) 的总和单调不增,因此 \(S\) 实际上是 \(\sum\limits_{i=1}^{b_n}\pi(i)=\mathcal{O}(\pi(b_n))\) 的量级,其中 \(\pi(i)\) 表示 \(i\) 的分拆数。
由于我们对于每个可能的 \(c_n\) 都要跑一次 DP,时间复杂度是 \(\mathcal{O}(b_n^3n^2\pi(b_n))\),无法承受。
根据经典结论,\(n\) 的任意一种正整数分拆中,本质不同的数只有 \(\mathcal{O}(\sqrt{n})\) 种,枚举 \(x\in S\) 时,相同的 \(x\) 的转移都是一样的,所以我们可以改为枚举 \(S\) 中所有本质不同的数 \(x\),转移时带上 \(cnt_S(x)\) 的权值。这样时间复杂度可以降至 \(\mathcal{O}(b_n^{2.5}n^2\pi(b_n))\),还是无法承受。
感受一下,对于每一个 \(c_n\) 都跑一次 DP 有点蠢,因为仅仅是初始状态变化了,转移过程是完全一样的。
考虑转置原理,将 DP 倒过来做。可以立即为把转移的方向反过来,系数不变。初值为 \(f_{0,\varnothing,1}=1\),考虑转移:
- \(fa_i\) 为无限制的点:\((k-1)f_{i-1,S,k}\to f_{i,S,k-1}\ (1\leq k\leq n-i+1)\)。
- \(fa_i\) 为有限制的点,且 \(c_i<b_i\)。我们枚举 \(x,y\in S\),相当于原来的 \(x-c_i,c_i-a_i\):\(f_{i-1,S,k}\to f_{i,(S-\{x,y\})\cup\{x+y+a_i\}}\ (0\leq k\leq n-i,\ 0\leq y<b_i-a_i)\)。
- \(fa_i\) 为有限制的点,且 \(c_i=b_i\),此处统计总方案数。我们枚举 \(x\in S\):\(f_{i-1,S,k}\to f_{i,(S-\{x\})\cup\{x+b_i\},k-1}\ (1\leq k\leq n-i+1)\)。
- \(fa_i\) 为有限制的点,且 \(c_i=b_i\),此处容斥减去不合法的方案数。我们枚举 \(x,y\in S\),相当于原来的 \(x-b_i,c_i-a_i\):\(-f_{i-1,S,k}\to f_{i,(S-\{x,y\})\cup \{x+b_i\}}\ (0\leq k\leq n-i,\ 0\leq y<b_i-a_i)\)。
需要特判 \(b_i=0\) 的情况,此时应该乘上 \(|S|\) 作为系数转移。
这样 \(f_{n-1,\{c_n-a_n\},0}\) 就直接是答案了!时间复杂度降至 \(\mathcal{O}(b_n^{1.5}n^2\pi(b_n))\),可以通过本题。
实现上,可以搜索求出所有可能的 \(S\) 编号,然后预处理出从 \(S\) 删除某个数 \(x\in S\) 后的转移点和转移系数,同理预处理出向 \(S\) 插入一个数 \(x\) 的转移点。需要注意的是,虽然 \(S\) 中不含 \(0\),但依然要考虑到插入 \(0\) 的情况,因为在原来倒着做的时候,我们完全有可能拆出来一个 \(0\)。还有需要保证转移过去的 \(S\) 的总和 \(\leq b_n\)。
代码
#include <bits/stdc++.h>
using namespace std;
#define lowbit(x) ((x) & -(x))
typedef long long ll;
typedef unsigned long long ull;
typedef long double ld;
typedef pair<int, int> pii;
const int N = 55, M = 3e4 + 5, MOD = 998244353;
template<typename T> inline void chk_min(T &x, T y) { x = min(x, y); }
template<typename T> inline void chk_max(T &x, T y) { x = max(x, y); }
template<typename T> inline T add(T x, T y) { return x += y, x >= MOD ? x - MOD : x; }
template<typename T> inline T sub(T x, T y) { return x -= y, x < 0 ? x + MOD : x; }
template<typename T> inline void cadd(T &x, T y) { x += y, x < MOD || (x -= MOD); }
template<typename T> inline void csub(T &x, T y) { x -= y, x < 0 && (x += MOD); }
int n, tot, a[N], b[N], f[2][M][N], ans[N];
map<vector<int>, int> id;
vector<int> st[M];
int sum[M], ins[M][N];
vector<pii> del[M];
queue<int> q;
int qpow(int a, int b) {
int res = 1;
for (; b; b >>= 1) {
if (b & 1) res = (ll)res * a % MOD;
a = (ll)a * a % MOD;
}
return res;
}
void get_id(const vector<int> &vec) {
if (id.find(vec) != id.end()) return;
st[id[vec] = ++tot] = vec;
for (int x : vec) sum[tot] += x;
q.push(tot);
}
void prework(int n) {
get_id({});
while (!q.empty()) {
int x = q.front(); q.pop();
auto vec = st[x];
for (int i = vec.empty() ? 1 : vec.back(); i <= n - sum[x]; ++i)
vec.push_back(i), get_id(vec), vec.pop_back();
}
for (int s = 1; s <= tot; ++s) {
auto &vec = st[s];
int cnt = 0;
for (int i = 0; i < vec.size(); ++i) {
++cnt;
if (i == vec.size() - 1 || vec[i] != vec[i + 1]) {
vector<int> tmp;
for (int j = 0; j < vec.size(); ++j) if (i != j) tmp.push_back(vec[j]);
del[s].push_back({id[tmp], cnt}), cnt = 0;
}
}
del[s].push_back({s, 1});
for (auto [v, _] : del[s]) ins[v][sum[s] - sum[v]] = s;
}
}
int main() {
ios::sync_with_stdio(0), cin.tie(0);
cin >> n;
for (int i = 1; i <= n; ++i) cin >> a[i] >> b[i];
prework(b[n]);
for (int k = 0; k < n; ++k) f[0][id[{}]][k] = 1;
for (int i = 1; i <= n - 1; ++i) {
int cur = i & 1, prv = cur ^ 1;
for (int s = 1; s <= tot; ++s) {
// k
for (int k = 1; k <= n - i + 1; ++k)
cadd<int>(f[cur][s][k - 1], (ll)f[prv][s][k] * (k - 1) % MOD);
// c_i = b_i, total
if (sum[s] + b[i] <= b[n]) {
if (!b[i]) {
for (int k = 1; k <= n - i + 1; ++k)
cadd<int>(f[cur][s][k - 1], (ll)f[prv][s][k] * (n - i - k + 1) % MOD);
} else {
for (auto [v, c] : del[s]) {
int to = ins[v][sum[s] - sum[v] + b[i]];
for (int k = 1; k <= n - i + 1; ++k)
cadd<int>(f[cur][to][k - 1], (ll)f[prv][s][k] * c % MOD);
}
}
}
// c_i = b_i, inclusion-exclusion || c_i < b_i
if (sum[s] + a[i] <= b[n]) {
for (auto [v1, cy] : del[s]) {
int y = sum[s] - sum[v1];
if (y >= b[i] - a[i]) continue;
for (auto d : {b[i], y + a[i]}) if (sum[v1] + d <= b[n]) {
int w = d == b[i] ? sub(0, cy) : cy;
if (!d) {
for (int k = 0; k <= n - i; ++k)
cadd<int>(f[cur][v1][k], (ll)f[prv][s][k] * (n - i - k) % MOD);
continue;
}
for (auto [v2, cx] : del[v1]) {
int x = sum[v1] - sum[v2], to = ins[v2][x + d];
int c = (ll)w * cx % MOD;
for (int k = 0; k <= n - i; ++k)
cadd<int>(f[cur][to][k], (ll)f[prv][s][k] * c % MOD);
}
}
}
}
}
for (int s = 1; s <= tot; ++s) fill(f[prv][s], f[prv][s] + n + 1, 0);
}
int fc = 1;
for (int i = 1; i < n; ++i) fc = (ll)fc * i % MOD;
int ifc = qpow(fc, MOD - 2);
ans[b[n]] = fc;
for (int i = a[n]; i < b[n]; ++i) {
int v = i == a[n] ? id[{}] : id[{i - a[n]}], x = f[n - 1 & 1][v][0];
cadd(ans[i], x), csub(ans[b[n]], x);
}
for (int i = 0; i <= b[n]; ++i) cout << (ll)ans[i] * ifc % MOD << ' ';
return 0;
}

浙公网安备 33010602011771号