UOJ#793 【UR #24】比特智慧
前言
似乎有人觉得这题很好拿分,可是我觉得这题在一场总时长三个小时的比赛中很难拿分啊。感觉 \(m=n-1\) 就需要不少的观察了,还要在 \(n=3\) 的样例下调出正确的式子,很难吧 /可怜。
题解
首先假设形态固定,考虑怎么算答案?
考虑先按照题目描述,按照顺序依次加边。如果不考虑这条边产生的相邻关系,那么这个问题就递归成了两个子问题,两个子问题是独立的。
这个时候考虑上这条边带来的相邻关系。假设左边和右边的边都没有被"分割"开,那么问题的答案乘上 \(a=\frac{k-1}{k}\) 即可;有一边被分割就是 \(b=\frac{k(k-1)(k-2)}{k^2(k-1)}\),两边都被割就是 \(c=\frac{(k^2-3k+3)k(k-1)}{k^2(k-1)}\)。
接下来我们分析问题的结构。观察题面里给的图,很容易想到,把所有边看成点,所有加的边形成森林。
首先先算出对于树的答案。假设边数为 \(i\) 的答案为 \(w_i\)。\(F(x,y)\) 中 \(x\) 表示多边形的边数,\(y\) 表示连通块的个数,值为方案数。
考虑用第一个点所在的树把整张图分开(假设树的大小为 \(i\)),这样就会把整个问题分割成 \(i\) 个子问题。
为了方便给 \(F\) 乘个 \(x\)。
设 \(G\) 为 \(F\) 的复合逆,即:\(F(G) = G(F) = x\)。将式子带入 \(G\) 得:
又有 \([x^{n+1}] F = \frac{1}{n+1}[x^{n}] (1+yW(x))^{n+1}\)
所以只要算出 \([x^{n}]W^{n-m}\) 即可。
接下来算 \(W\)。为了方便我定义变量我还是设成 \(F\) 和 \(G\)。
其中 \(F\) 表示:
\(G\) 表示:
之前要要算的 \(W\) 就是 \(x(1+F)\)。最后要求的就是 \([x^{m}](1+F)^{n-m}\)。因此再次使用拉格朗日反演,我们只需要算出 \(F\) 的复合逆 \(g\) 即可。考虑在这两个式子内都带入 \(g\)。设 \(u = G(g)\)。
最终要求的大概就是 \(g^{-m}\)。\(g\) 是一个有理分式所以可以线性甚至根号log。
代码
#include<bits/stdc++.h>
#define L(i, j, k) for(int i = (j); i <= (k); ++i)
#define R(i, j, k) for(int i = (j); i >= (k); --i)
#define ll long long
#define sz(a) ((int) (a).size())
#define vi vector < int >
#define me(a, x) memset(a, x, sizeof(a))
#define ull unsigned long long
#define ld __float128
using namespace std;
const int mod = 998244353, N = 1 << 20, inv2 = (mod + 1) / 2;
#define add(a, b) (a + b >= mod ? a + b - mod : a + b)
#define dec(a, b) (a < b ? a - b + mod : a - b)
int qpow(int x, int y = mod - 2) {
int res = 1;
for(; y; x = (ll) x * x % mod, y >>= 1) if(y & 1) res = (ll) res * x % mod;
return res;
}
int fac[N], ifac[N], inv[N];
void init(int x) {
fac[0] = ifac[0] = inv[1] = 1;
L(i, 2, x) inv[i] = (ll) inv[mod % i] * (mod - mod / i) % mod;
L(i, 1, x) fac[i] = (ll) fac[i - 1] * i % mod, ifac[i] = (ll) ifac[i - 1] * inv[i] % mod;
}
int C(int x, int y) {
return y < 0 || x < y ? 0 : (ll) fac[x] * ifac[y] % mod * ifac[x - y] % mod;
}
struct poly {
vector<int> a;
int size() { return sz(a); }
int & operator [] (int x) { return a[x]; }
int v(int x) { return x < 0 || x >= sz(a) ? 0 : a[x]; }
void clear() { vector<int> ().swap(a); }
void rs(int x = 0) { a.resize(x); }
poly (int n = 0) { rs(n); }
poly (vector<int> o) { a = o; }
poly (const poly &o) { a = o.a; }
poly Rs(int x = 0) { vi res = a; res.resize(x); return res; }
friend poly operator * (poly aa, poly bb) {
if(!sz(aa) || !sz(bb)) return {};
int lim, all = sz(aa) + sz(bb) - 1;
vi ns(all);
L(i, 0, sz(aa) - 1) L(j, 0, sz(bb) - 1)
(ns[i + j] += (ll) aa[i] * bb[j] % mod) %= mod;
return ns;
}
poly Shift (int x) {
poly zm (sz(a) + x);
L(i, 0, sz(a) - 1) zm[i + x] = a[i];
return zm;
}
friend poly operator * (poly aa, int bb) {
poly res(sz(aa));
L(i, 0, sz(aa) - 1) res[i] = (ll) aa[i] * bb % mod;
return res;
}
friend poly operator * (int bb, poly aa) {
poly res(sz(aa));
L(i, 0, sz(aa) - 1) res[i] = (ll) aa[i] * bb % mod;
return res;
}
friend poly operator + (poly aa, poly bb) {
vector<int> res(max(sz(aa), sz(bb)));
L(i, 0, sz(res) - 1) res[i] = add(aa.v(i), bb.v(i));
return poly(res);
}
friend poly operator - (poly aa, poly bb) {
vector<int> res(max(sz(aa), sz(bb)));
L(i, 0, sz(res) - 1) res[i] = dec(aa.v(i), bb.v(i));
return poly(res);
}
poly Deriv() {
if(!sz(a)) return poly();
poly res(sz(a) - 1);
L(i, 1, sz(a) - 1) res[i - 1] = (ll) a[i] * i % mod;
return res;
}
} ;
int n, m, k;
inline int Inv(int x) {
return x == 1 ? 1 : (ll) Inv(mod % x) * (mod - mod / x) % mod;
}
int ns[N];
void Main() {
cin >> n >> m >> k;
int a = (ll) (k - 1) * Inv(k) % mod;
int b = (ll) (k - 2) * Inv(k) % mod;
int c = ((ll) k * k % mod + mod - (ll) 3 * k % mod + 3) % mod *
Inv(k) % mod * Inv(k - 1) % mod;
poly F = vi{a, b};
poly G = (poly)vi{a, 2 * b % mod, c} * vi{a, 2 * b % mod, c};
poly A = F * G, B = (G.Deriv() * F - F.Deriv() * G) * m;
ns[0] = qpow(a, m);
int iv = qpow(A[0]);
L(i, 1, n) {
// [x^{n-1}] sum_{i} A[i] H'[n-i-1] = [x^n] H * B
ns[i] = 0;
L(j, 1, min(sz(A) - 1, i)) {
(ns[i] += mod - (ll) A[j] * ns[i - j] % mod * (i - j) % mod) %= mod;
}
L(j, 0, min(sz(B) - 1, i - 1)) {
(ns[i] += (ll) B[j] * ns[i - j - 1] % mod) %= mod;
}
ns[i] = (ll) ns[i] * inv[i] % mod * iv % mod;
}
int ret = 0;
L(i, 0, m - 1)
(ret += (ll) C(n - m, i + 1) * (i + 1) % mod * ns[m - i - 1] % mod) %= mod;
cout << (ll) ret * inv[m] % mod * C(n + 1, n - m) % mod *
inv[n + 1] % mod * qpow(k, m + 1) % mod << '\n';
}
int main() {
init(N >> 1);
ios :: sync_with_stdio (false);
cin.tie(0); cout.tie(0);
int t; cin >> t; while(t--) Main();
return 0;
}