# UOJ#793 【UR #24】比特智慧

## 题解

$F(x,y) = 1 + y\sum_{i \ge 1} w_i x^i F(x,y)^i$

将 $xF$ 重新记为 $F$（即令 $F \leftarrow xF$），得：

$F(x,y) = x (1 + y\sum_{i \ge 1} w_i F(x,y)^i)$

设 $G$ 为 $F$ 的复合逆，即 $F(G) = G(F) = x$。将 $G$ 代入上式（即令 $x \leftarrow G$）得：

$x = G (1 + y W(x))$，其中 $W(x) = \sum_{i \ge 1} w_i x^i$。

$G = \frac{x}{1 + yW(x)}$

由拉格朗日反演：$[x^{n+1}][y^{n-m}]F = \frac{1}{n+1}[x^n][y^{n-m}](1+yW(x))^{n+1}$

用 $G$ 表示：

$F = G ( cF^2 + 2bF + a)$

$G = x + G^2 (c F + b)$

$x = u (cx^2+bx+a)$

$u = \frac{x}{cx^2+bx+a}$

$u = g + u^2 (cx+b)$

$g = u-u^2(cx+b)$

## 代码

#include<bits/stdc++.h>
// Loop shorthands: L iterates i = j .. k upward, R iterates i = j .. k downward
// (both bounds inclusive).
#define L(i, j, k) for(int i = (j); i <= (k); ++i)
#define R(i, j, k) for(int i = (j); i >= (k); --i)
#define ll long long
// Container size as a signed int (avoids signed/unsigned comparison noise).
#define sz(a) ((int) (a).size())
#define vi vector < int >
#define me(a, x) memset(a, x, sizeof(a))
#define ull unsigned long long
// NOTE: __float128 is a GCC/Clang extension; only valid if `ld` is actually used.
#define ld __float128
using namespace std;
// mod is the NTT-friendly prime 998244353; inv2 is the modular inverse of 2.
const int mod = 998244353, N = 1 << 20, inv2 = (mod + 1) / 2;
// Modular add / subtract for operands already reduced into [0, mod).
// 2 * mod < INT_MAX, so a + b cannot overflow int for reduced operands.
// Arguments are parenthesized so that expressions such as `f ? x : y` expand
// correctly; note each argument is still evaluated more than once.
#define add(a, b) ((a) + (b) >= mod ? (a) + (b) - mod : (a) + (b))
#define dec(a, b) ((a) < (b) ? (a) - (b) + mod : (a) - (b))
// Binary exponentiation: returns x^y mod `mod`. y must be >= 0, and x should
// already be reduced into [0, mod). The default exponent mod - 2 gives the
// modular inverse of x by Fermat's little theorem (mod is prime), so
// qpow(0) yields 0.
int qpow(int x, int y = mod - 2) {
    int result = 1;
    while (y != 0) {
        if (y & 1) {
            result = (ll) result * x % mod;
        }
        x = (ll) x * x % mod;
        y >>= 1;
    }
    return result;
}
int fac[N], ifac[N], inv[N];
void init(int x) {
fac[0] = ifac[0] = inv[1] = 1;
L(i, 2, x) inv[i] = (ll) inv[mod % i] * (mod - mod / i) % mod;
L(i, 1, x) fac[i] = (ll) fac[i - 1] * i % mod, ifac[i] = (ll) ifac[i - 1] * inv[i] % mod;
}
// Binomial coefficient "x choose y" mod `mod`; 0 whenever y < 0 or y > x
// (this also covers negative x). Requires init() to have filled the tables
// up to index x.
int C(int x, int y) {
    if (y < 0 || y > x) {
        return 0;
    }
    return (ll) fac[x] * ifac[y] % mod * ifac[x - y] % mod;
}
// Dense polynomial over Z/mod: a[i] is the coefficient of x^i (lowest degree
// first); coefficients are expected to lie in [0, mod). Products use schoolbook
// convolution, O(|A| * |B|).
//
// Copy and move are left compiler-generated (Rule of Zero): declaring a copy
// constructor would suppress the implicit move constructor and turn every
// by-value return into a deep copy of the coefficient vector.
struct poly {
    vector<int> a;

    // Number of stored coefficients.
    int size() const { return sz(a); }
    // Unchecked coefficient access.
    int &operator[](int x) { return a[x]; }
    int operator[](int x) const { return a[x]; }
    // Checked read: coefficients outside [0, size()) are treated as 0.
    int v(int x) const { return x < 0 || x >= sz(a) ? 0 : a[x]; }
    // Drops all coefficients and releases the storage (vector::clear would
    // keep the capacity).
    void clear() { vector<int>().swap(a); }
    // Resizes to x coefficients; new coefficients are 0.
    void rs(int x = 0) { a.resize(x); }

    // n zero coefficients (the empty polynomial when n == 0).
    poly(int n = 0) { rs(n); }
    // Implicit on purpose: lets vi{...} literals be used where a poly is expected.
    poly(vector<int> o) : a(std::move(o)) {}

    // Copy truncated or zero-padded to exactly x coefficients.
    poly Rs(int x = 0) const {
        vi res = a;
        res.resize(x);
        return res;
    }

    // Polynomial product; an empty operand yields the empty polynomial.
    friend poly operator*(const poly &aa, const poly &bb) {
        if (!sz(aa.a) || !sz(bb.a)) return {};
        vi ns(sz(aa.a) + sz(bb.a) - 1);
        // Each accumulated term is < mod, so the running sum stays < 2 * mod < 2^31.
        L(i, 0, sz(aa.a) - 1) L(j, 0, sz(bb.a) - 1)
            (ns[i + j] += (ll) aa.a[i] * bb.a[j] % mod) %= mod;
        return ns;
    }

    // Multiplication by X^x: prepends x zero coefficients (x must be >= 0).
    poly Shift(int x) const {
        poly zm(sz(a) + x);
        L(i, 0, sz(a) - 1) zm[i + x] = a[i];
        return zm;
    }

    // Scalar multiplication, in either operand order.
    friend poly operator*(const poly &aa, int bb) {
        poly res(sz(aa.a));
        L(i, 0, sz(aa.a) - 1) res[i] = (ll) aa.a[i] * bb % mod;
        return res;
    }
    friend poly operator*(int bb, const poly &aa) { return aa * bb; }

    // Coefficient-wise sum / difference; the shorter operand is zero-extended.
    friend poly operator+(const poly &aa, const poly &bb) {
        vi res(max(sz(aa.a), sz(bb.a)));
        L(i, 0, sz(res) - 1) res[i] = add(aa.v(i), bb.v(i));
        return poly(res);
    }
    friend poly operator-(const poly &aa, const poly &bb) {
        vi res(max(sz(aa.a), sz(bb.a)));
        L(i, 0, sz(res) - 1) res[i] = dec(aa.v(i), bb.v(i));
        return poly(res);
    }

    // Formal derivative; a polynomial with n coefficients yields n - 1.
    poly Deriv() const {
        if (!sz(a)) return poly();
        poly res(sz(a) - 1);
        L(i, 1, sz(a) - 1) res[i - 1] = (ll) a[i] * i % mod;
        return res;
    }
};

int n, m, k;
// Modular inverse of x modulo the prime `mod`. It unrolls the identity
//   x^{-1} = (mod % x)^{-1} * (mod - mod / x)   (mod `mod`)
// into a loop; the chain x -> mod % x strictly decreases and reaches 1.
// x is first reduced into [0, mod), so negative arguments and arguments
// >= mod (e.g. a large k read from input) are valid. Without this reduction,
// mod % x would eventually become mod % 0.
// Returns 0 when x is a multiple of mod (no inverse exists), matching qpow(0).
inline int Inv(int x) {
    x %= mod;
    if (x < 0) x += mod;
    if (x == 0) return 0;
    int res = 1;
    while (x != 1) {
        res = (ll) res * (mod - mod / x) % mod;
        x = mod % x;
    }
    return res;
}
// Per-test-case scratch: ns[i] holds the i-th power-series coefficient of H
// (see Main). Main writes ns[0..n] for each test case.
int ns[N];
// Solves one test case: reads n, m, k from stdin and prints the answer
// (mod `mod`) on its own line. Uses the globals n, m, k and ns, and the
// fac/ifac/inv tables filled by init().
// NOTE(review): main() calls init(N >> 1), so the tables cover indices up to
// 2^19 only, while this function reads inv[i] for i <= n + 1 and fac[n + 1];
// assumes n + 1 <= 2^19 -- confirm against the problem constraints.
// NOTE(review): k == 1 makes Inv(k - 1) = Inv(0), which has no modular inverse
// (and gives a == 0, hence A[0] == 0 below); presumably excluded by the
// constraints -- confirm.
void Main() {
cin >> n >> m >> k;
// a = (k-1)/k, b = (k-2)/k, c = (k^2 - 3k + 3) / (k (k-1)), all mod `mod`.
// Presumably the a, b, c of the quadratic c F^2 + 2b F + a in the write-up
// above -- confirm.
int a = (ll) (k - 1) * Inv(k) % mod;
int b = (ll) (k - 2) * Inv(k) % mod;
int c = ((ll) k * k % mod + mod - (ll) 3 * k % mod + 3) % mod *
Inv(k) % mod * Inv(k - 1) % mod;
// F = a + b x,  G = (a + 2b x + c x^2)^2.
poly F = vi{a, b};
poly G = (poly)vi{a, 2 * b % mod, c} * vi{a, 2 * b % mod, c};
// H = sum ns[i] x^i is defined by the linear ODE A * H' = B * H with
// A = F * G, B = m * (G' F - F' G) and H(0) = a^m. Since
// B / A = m (G'/G - F'/F) and G(0) / F(0) = a, this means H = (G / F)^m.
poly A = F * G, B = (G.Deriv() * F - F.Deriv() * G) * m;
ns[0] = qpow(a, m);
// A[0] = a^3; its inverse divides out the leading term in each step.
int iv = qpow(A[0]);
L(i, 1, n) {
// Compare [x^{i-1}] on both sides of A * H' = B * H:
//   i A[0] ns[i] + sum_{j>=1} A[j] (i-j) ns[i-j] = sum_{j>=0} B[j] ns[i-1-j]
// and solve for ns[i]. A and B have constant size, so each step is O(1).
ns[i] = 0;
L(j, 1, min(sz(A) - 1, i)) {
(ns[i] += mod - (ll) A[j] * ns[i - j] % mod * (i - j) % mod) %= mod;
}
L(j, 0, min(sz(B) - 1, i - 1)) {
(ns[i] += (ll) B[j] * ns[i - j - 1] % mod) %= mod;
}
// Divide by i * A[0].
ns[i] = (ll) ns[i] * inv[i] % mod * iv % mod;
}
// ret = sum_{i=0}^{m-1} (i+1) * C(n-m, i+1) * ns[m-1-i]. Every term is 0 when
// n < m, because C returns 0 for a negative top argument.
// NOTE(review): if m - 1 > n, ns[m-1-i] reads entries not written for this
// test case (or out of bounds when m > N); they are multiplied by 0, but an
// out-of-bounds read would still be UB -- confirm m's range.
int ret = 0;
L(i, 0, m - 1)
(ret += (ll) C(n - m, i + 1) * (i + 1) % mod * ns[m - i - 1] % mod) %= mod;
// Answer = ret * m^{-1} * C(n+1, n-m) * (n+1)^{-1} * k^{m+1}.
cout << (ll) ret * inv[m] % mod * C(n + 1, n - m) % mod *
inv[n + 1] % mod * qpow(k, m + 1) % mod << '\n';
}

// Entry point: builds the factorial/inverse tables once, then reads the number
// of test cases t and runs Main() once per case.
int main() {
    init(N >> 1);  // tables for indices 0 .. 2^19
    // Fast iostreams: decouple from C stdio and untie cin/cout.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);
    int t;
    cin >> t;
    for (; t; --t) {
        Main();
    }
    return 0;
}

posted @ 2023-02-20 09:04  zhoukangyang  阅读(1753)  评论(1)  编辑  收藏  举报