P11420 [清华集训 2024] 乘积的期望
首先把 \(\prod a\) 拆开,对于一个操作,它会被若干个位置所钦定。考虑对于所有有效的操作 dp。
如果我们知道了一个操作覆盖的集合,那么操作范围一定包含集合的最小和最大值。因此我们只关心集合的左右端点。
考虑对这些区间扫描线。\(f_{i,S,j}\) 表示现在扫到 \(i\),\(i-m+1\sim i\) 的未匹配状态是 \(S\),当前 \(j\) 个有效操作的方案数。转移考虑 \(i\) 不为端点、为左或右、单独开一个区间即可。容易做到 \(O(2^mn^2m)\)。可以通过 \(m\le 16\)。
对于 \(m\) 较大的部分,此时有 \(3m>n\),不妨令 \(n=3m-1\)。再对于序列 \(a\) 考虑。此时 \(a\) 可以划分为 \(3\) 段:\(1\sim m,m+1\sim 2m,2m+1\sim 3m\)。
如果我们知道了 \(a_{1\sim m}\) 的值,那么可以推出 \(1\sim m\) 的操作次数 \(c_{1\sim m}\),具体地,\(c_i=a_i-a_{i-1}\)。同理,如果我们知道 \(a_{2m+1\sim 3m}\),对于 \(1\le i\le m\) 可以推出 \(c_{i+m+1}=a_{i+2m}-a_{i+2m+1}\)。显然此时有 \(a_{1\sim m}\) 不减,\(a_{2m+1\sim 3m}\) 不增。
此时未确定的就只有 \(c_{m+1}\),而操作总数为 \(C\),因此 \(c_{m+1}=C-a_m-a_{2m+1}\)。
对于 \(a_{m+1\sim 2m}\),有 \(a_{m+i}=C-a_i-a_{i+2m}\),即 \(a_i+a_{i+m}+a_{i+2m}=C\)。同时因为 \(c_{m+1}\ge 0\),有 \(a_m+a_{2m+1}\le C\)。
考虑枚举 \(a_m\),然后按列 dp,每次加入 \(a_i,a_{i+m},a_{i+2m}\)。设 \(f_{i,j,k}\) 表示加入到 \(i\),\(a_i=j\),\(a_{i+2m}=k\),转移暴力枚举可以做到 \(O(n^2C^4)\)。后面两维分步转移可以做到 \(O(n^2C^3)\)。
对于 \(C\) 较大时,可以发现答案是关于 \(C\) 的 \(n\) 次多项式,取 \(C=0\sim n\) 插值即可,复杂度 \(O(n^6)\)。
#include <bits/stdc++.h>
using namespace std;
using ll = long long;
const int mod = 998244353;
void Add(int &x, ll y) {
x = (x + y) % mod;
}
int Pow(int x, int y) {
int b = x, r = 1;
for(; y; b = (ll)b * b % mod, y /= 2) {
if(y & 1) r = (ll)r * b % mod;
}
return r;
}
const int kN = 55;
int n, m, c;
int b[kN];
namespace Task1 {
const int kS = (1 << 16) + 5;
int s[kN];
int f[kS][kN], old[kS][kN];
void Solve() {
for(int i = 1; i <= n; i++) {
s[i] = s[i - 1] + b[i];
}
f[0][0] = 1;
for(int i = 0; i < n; i++) {
memcpy(old, f, sizeof(f));
memset(f, 0, sizeof(f));
for(int msk = 0; msk < (1 << m - 1); msk++) {
int pc = __builtin_popcount(msk);
for(int c = 0; c <= i; c++) {
int val = old[msk][c];
if(!val) continue;
Add(f[msk * 2][c], (ll)val * pc);
Add(f[msk * 2][c + 1], (ll)val * (s[i + 1] - s[max(i + 1 - m, 0)]));
Add(f[msk * 2 + 1][c + 1], val);
for(int j = 0; j < m - 1; j++) {
if((msk >> j) & 1) {
Add(f[(msk ^ (1 << j)) << 1][c], (ll)val * (s[i - j] - s[max(i + 1 - m, 0)]));
}
}
}
}
}
int res = 0;
int sum = accumulate(b + 1, b + n + 1, 0);
int inv = Pow(sum, mod - 2);
for(int c = 0, dn = 1; c <= n; c++) {
Add(res, (ll)f[0][c] * dn);
dn = (ll)dn * inv % mod * (::c - c) % mod;
}
cout << res << "\n";
}
}
namespace Task2 {
const int kN = 75, kM = 28;
int tn, tm;
int mul[kN], imul[kN];
int pw[kN][kN];
int f[kM][kN][kN], tmp[kN][kN];
void Init(int N = kN - 2) {
mul[0] = 1;
for(int i = 1; i <= N; i++) {
mul[i] = (ll)mul[i - 1] * i % mod;
}
imul[N] = Pow(mul[N], mod - 2);
for(int i = N - 1; ~i; i--) {
imul[i] = (ll)imul[i + 1] * (i + 1) % mod;
}
}
int Work(int c) {
int coef = 1;
if(n < m + m) {
int len = m + m - n;
coef = Pow(c, len);
n -= len;
m -= len;
}
if(!n) return coef;
coef = (ll)coef * mul[c] % mod;
for(int i = 1; i <= n; i++) {
pw[i][0] = 1;
for(int j = 1; j <= c; j++) {
pw[i][j] = (ll)pw[i][j - 1] * b[i] % mod;
}
}
int ans = 0;
for(int v = 0; v <= c; v++) { // a_m
memset(f, 0, sizeof(f));
for(int i = 0; i <= v; i++) {
for(int j = 0; j <= c - v; j++) {
int tmp = (ll)pw[1][i] * imul[i] % mod;
tmp = (ll)tmp * pw[1 + m][c - v - j] % mod;
tmp = (ll)tmp * imul[c - v - j] % mod;
f[1][i][j] = tmp;
}
}
f[0][0][0] = 1;
for(int i = 1; i < m; i++) {
memset(tmp, 0, sizeof(tmp));
for(int x = 0; x <= v; x++) {
for(int y = 0; y + x <= c; y++) {
if((i + 2 * m > n) && y) break;
int val = f[i][x][y];
val = (ll)val * x % mod;
val = (ll)val * (c - x - y) % mod;
if(i + 2 * m <= n) val = (ll)val * y % mod;
if(!val) continue;
for(int nx = x; nx <= c; nx++) {
int coef = (ll)imul[nx - x] * pw[i + 1][nx - x] % mod;
Add(tmp[nx][y], (ll)val * coef);
}
}
}
for(int x = 0; x <= v; x++) {
for(int y = 0; y <= c; y++) {
int val = tmp[x][y];
if(!val) continue;
for(int ny = 0; ny <= min(y, c - x); ny++) {
int coef = (ll)imul[y - ny] * pw[i + m + 1][y - ny] % mod;
Add(f[i + 1][x][ny], (ll)val * coef);
}
}
}
}
Add(ans, (ll)f[m][v][0] * v * (c - v));
}
ans = (ll)ans * coef % mod;
int sum = accumulate(b + 1, b + n + 1, 0);
int inv = Pow(sum, mod - 2);
return (ll)Pow(inv, c) * ans % mod;
}
void Solve() {
Init();
tn = n, tm = m;
if(c <= n + 1) {
cout << Work(c) << "\n";
}else {
vector<int> y (n + 1);
for(int i = 0; i <= n; i++) {
y[i] = Work(i);
n = tn;
m = tm;
}
int res = 0;
for(int i = 0; i <= n; i++) {
int coef = y[i], tmp = 1;
for(int j = 0; j <= n; j++) {
if(i == j) continue;
tmp = (ll)tmp * (i + mod - j) % mod;
coef = (ll)coef * (c + mod - j) % mod;
}
Add(res, (ll)coef * Pow(tmp, mod - 2));
}
cout << res << "\n";
}
}
}
int main() {
// freopen("1.in", "r", stdin);
// freopen("1.out", "w", stdout);
ios::sync_with_stdio(0), cin.tie(0);
cin >> n >> m >> c;
for(int i = 1; i <= n - m + 1; i++) {
cin >> b[i];
}
if(m <= 16) {
return Task1::Solve(), 0;
}else {
return Task2::Solve(), 0;
}
return 0;
}
浙公网安备 33010602011771号