序列合并

Problem

有一个序列,初始时为空,你会不断往序列末尾添加一个 \([1, m]\) 的随机整数。

任意时刻

  • 若序列末尾两个数相同(记为 \(x\)),且小于 \(t\),则这两个数会合并成 \(x+1\)

  • 若序列长度为 \(n\) 且无法合并,则操作结束。

求序列中所有元素和的期望,答案对 \(10^9+7\) 取模。

\(1\leq n,m\leq 10^3\)\(m\leq t\leq 10^9\)

Solution

\(L=\min\{t, n+m-1\}\)

\(p_{i,j}\):限制序列长度为 \(i\),第一个位置出现 \(j\) 的概率。

\[p_{i,j}=[j\leq m]\frac1m+p_{i,j-1}\times p_{i-1,j-1} \]

\(q_{i,j}\):限制序列长度为 \(i\),第一个位置为 \(j\) 下,之后不再改变的概率。

\[q_{i,j}=1-[j<L]p_{i-1,j} \]

\(g_{i,j}\):限制序列长度为 \(i\),第一个位置为 \(j\),且之后不再改变后,整个序列的期望。

\(ans_i\):序列长度为 \(i\) 时,权值和的期望。

\(f_{i,j}\):序列长度为 \(i\),第一个数字出现 \(j\) 时,序列元素权值和。

\(j<L\) 时,

\[\begin{aligned} g_{i,j}&=j+\sum_{S}S\times\Pr\{第2到i权值和为S|第1个为j,且j不变\}\\ &=j+\frac{\sum_{S}S\times\Pr\{第2到i权值和为S,j不改变|第1个为j\}}{\Pr\{j不改变|第1个为j\}}\\ &=j+\frac{\sum_{S}S\times\Pr\{第2到i权值和为S,j不改变|第1个为j\}}{q_{i,j}}\\ &=j+\frac{\sum_{S}S\times\Pr\{第2到i权值和为S,j任意|第1个为j\}-\sum_{S}S\times\Pr\{第2到i权值和为S,j改变|第1个为j\}}{q_{i,j}}\\ &=j+\frac{ans_{i-1}-\Pr\{第2个为j\}\times\sum_S S\times\Pr\{第2到i权值和为S|第2个为j\}}{q_{i,j}}\\ &=j+\frac{ans_{i-1}-p_{i-1,j}\times f_{i-1,j}}{q_{i,j}} \end{aligned} \]

\(j=L\) 时,

\[g_{i,j}=j+ans_{i-1} \]

对于 \(ans_i\)

\[ans_i=\sum_{j=1}^Lp_{i,j}\times q_{i,j}\times g_{i,j} \]

对于 \(f_{i,j}\)

\[f_{i,j}=q_{i,j}\times g_{i,j}+(1-q_{i,j})\times f_{i,j+1} \]

为了避免求逆元,将 \(g_{i,j}\times q_{i,j}\) 整体转移。复杂度 \(\mathcal O(nm)\)

Code

#include <bits/stdc++.h>
const int N = 2005, P = 1e9 + 7;
using std::min;
int n, m, t, L, f[N][N], qg[N][N], p[N][N], q[N][N], ans[N];
int qpow(int a, int b) {
	int t = 1;
	for (; b; b >>= 1, a = 1LL * a * a % P)
		if (b & 1) t = 1LL * t * a % P;
	return t;
}
int main() {
	scanf("%d%d%d", &n, &m, &t); L = min(n + m - 1, t);
	int inv = qpow(m, P-2);
	for (int i = 1; i <= n; i++) {
		for (int j = 1; j <= L; j++) {
			p[i][j] = ((j <= m ? inv : 0) + 1LL * p[i][j - 1] * p[i - 1][j - 1]) % P;
			q[i][j] = (1 - (j < L ? p[i - 1][j] : 0) + P) % P;
		}
		for (int j = L; j; j--) {
			qg[i][j] = (1LL * q[i][j] * j + ans[i - 1] - (j < L ? 1LL * p[i - 1][j] * f[i - 1][j] % P : 0) + P) % P;
			f[i][j] = (qg[i][j] + (j < L ? 1LL * (1 - q[i][j] + P) * f[i][j + 1] : 0)) % P;
			ans[i] = (ans[i] + 1LL * p[i][j] * qg[i][j]) % P;
		}
	}
	printf("%d\n", ans[n]);
	return 0;
}
posted @ 2021-02-22 21:40  AC-Evil  阅读(379)  评论(0编辑  收藏  举报