题解:P12392 「RiOI-6」Re:帝国少女

题意

给定 \(m\) 和一个长度为 \(n\) 的序列 \(a\),对于 \(1\leq i\leq n,a_i\in[1,m]\cap\mathbb{Z}\)。现在可以进行若干次修改,每次修改可以选定一个 \(a_i\),将其修改为 \([1,m]\) 中的一个整数。求在 \(a\) 中任意相邻两数不同的前提下,最小的修改次数。

对于一个序列 \(a\),定义困难程度 \(f(a)\) 为上面的问题的答案,\(g(a)\) 为使得修改次数最小的修改方案数。

在所有长度为 \(n\) 值域为 \([1,m]\) 的整数序列中,对于每个 \(k\in\left[0,\lfloor\frac{n}{2}\rfloor\right]\),求出所有 \(f(a)=k\) 的序列 \(a\)\(g(a)\) 之和。\(1\leq n\leq 5\times 10^3\)\(2\leq m\leq 10^9\)

题解

前情回顾:Luogu P12391 | 我的题解

很好的细节 DP 题,使我大脑飞速旋转。

先做 \(m=2\) 的情况。若 \(2k\neq n\),则只能修改为 \(0,1,\cdots\)\(1,0,\cdots\) 中的一种,各从中选择 \(k\) 位翻转就是对应的原序列,方案数为 \(2\binom{n}{k}\);若 \(2k=n\),有 \(2\) 种修改方案,方案数为 \(\binom{n}{k}\),答案为 \(2\binom{n}{k}\)。因此,对于 \(m=2\) 的情况,输出 \(2\binom{n}{k}\) 即可。

再做 \(m>2\) 的情况。还是按极长同色段划分序列,我们先不考虑颜色,思考一个段内的修改方案是怎么样的(我们用 \(0/1\) 表示对应位置是否修改):

  • 若长度为奇数,显然只有 \(010\cdots 010\) 这种修改方案。
  • 若长度为偶数:比较显然有 \(0101\cdots 01\)\(1010\cdots 10\)\(2\) 种方案,还有一种容易忽略的方案是 \(0101\cdots11\cdots1010\),也就是修改中间的连续 \(2\) 个位置 \((i,i+1)\),其中 \(i\) 是偶数。计算可得我们能选出 \(\frac{len}{2}-1\) 个这样的 \(i\)

考虑颜色的限制,我们发现当前段的染色方案数。只会受前面段原来的颜色和前面段最后一个位置的颜色限制,由此设计 DP 状态:令 \(f_{i,j,0/1}\) 表示考虑 \(a[1,i]\),所有满足 \(f(a[1,i])=j\)\(g(a[1,i])\) 之和。转移时我们枚举 \(k\) 表示 \(i\) 所在颜色段的长度,根据 \(k\) 的奇偶性分类讨论,需要特判 \(j=\left\lfloor\frac{i}{2}\right\rfloor\) 的情况。

下文中我们令 \(t=\left\lfloor\frac{k}{2}\right\rfloor\)

\(f\) 的计算包含两部分:当前段的初始方案数和染色方案数。

对于当前段的初始方案数,讨论 \(f_{i-k,j-t,0/1}\)\(0/1\) 的取值:

  • 若取值为 \(0\),则当前段的颜色与前面的段的颜色不同即可,有 \(m-1\) 种。
  • 若取值为 \(1\),则当前段的颜色既需要与前面的段不同,也需要和前面的段的最后一个位置修改后的颜色不同,有 \(m-2\) 种。这是大部分情况,但是有一个 corner case:若我们在当前段修改了首位的颜色,则段的初始颜色不必与最后一个位置修改后的颜色不同,此时有 \(m-1\) 种方案。

对于染色方案数,分别考虑首位和其他位的染色方案数即可,如果是连续修改中间的 \(2\) 位也需要考虑进去。读者可以自行推导。

这样我们就得到了 \(\mathcal{O}(n^3)\) 的朴素 DP。后面也会放这部分的代码。

考虑优化。我们发现转移式中大部分都是乘上 \(m-1\) 或者 \(m-2\) 的形式,可以提出来,于是需要快速计算包含 \(k\)\(t\) 的那部分式子,更具体地,我们要维护出

\[\sum_{k=1}^{\min(i-1,2j+1)}f_{i-k,j-t,0/1}(m-1)^{t-1} \]

以及

\[\sum_{k=1}^{\min(i-1,2j+1)}f_{i-k,j-t,0/1}(m-1)^{t-1}(t-1) \]

容易想到前缀和优化。我们分奇偶维护,令

\[\begin{align*} g_{i,j,0/1}&=\sum_{k=0}^{\min(\frac{i}{2}, j)}f_{i-2k,j-k,0/1}(m-1)^k\\ h_{i,j,0/1}&=\sum_{k=0}^{\min(\frac{i}{2}, j)}f_{i-2k,j-k,0/1}(m-1)^kk \end{align*} \]

不难列出转移方程:

\[\begin{align*} g_{i,j,0/1}&=(m-1)g_{i-2,j-1,0/1}+f_{i,j,0/1}\\ h_{i,j,0/1}&=(m-1)(g_{i-2,j-1,0/1}+h_{i-2,j-1,0/1}) \end{align*} \]

容易用 \(g\)\(h\) 对应计算出我们所需的式子做到 \(\mathcal{O}(1)\) 转移。时间复杂度 \(\mathcal{O}(n^2)\)

代码

$\mathcal{O}(n^3)$ DP
#include <iostream>

using namespace std;

#define lowbit(x) ((x) & -(x))
#define chk_min(x, v) (x) = min((x), (v))
#define chk_max(x, v) (x) = max((x), (v))
typedef long long ll;
typedef pair<int, int> pii;
const int N = 5e3 + 5, MOD = 1e9 + 7;

inline int add(int x, int y) { return x += y, x >= MOD ? x - MOD : x; }
inline int sub(int x, int y) { return x -= y, x < 0 ? x + MOD : x; }
inline void cadd(int &x, int y) { x += y, x < MOD || (x -= MOD); }
inline void csub(int &x, int y) { x -= y, x < 0 && (x += MOD); }

int n, m, fac[N], ifac[N], f[N][N][2], pw[N];

int qpow(int a, ll b) {
	int res = 1;
	for (; b; b >>= 1) {
		if (b & 1) res = 1ll * res * a % MOD;
		a = 1ll * a * a % MOD;
	}
	return res;
}
void pre() {
	fac[0] = 1;
	for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % MOD;
	ifac[n] = qpow(fac[n], MOD - 2);
	for (int i = n - 1; ~i; --i) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % MOD;
}
int C(int n, int m) { return 1ll * fac[n] * ifac[m] % MOD * ifac[n - m] % MOD; }

int main() {
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n >> m;
    if (m == 2) {
    	pre();
    	for (int i = 0; i <= n >> 1; ++i) cout << C(n, i) * 2 % MOD << ' ';
    } else {
        pw[0] = 1;
    	for (int i = 1; i <= n >> 1; ++i) pw[i] = 1ll * pw[i - 1] * (m - 1) % MOD;
    	f[1][0][0] = m;
    	for (int i = 2; i <= n; ++i) for (int j = 0; j <= i >> 1; ++j) {
    		if (j == (i >> 1)) {
    			if (i & 1) f[i][j][0] = 1ll * m * pw[j] % MOD;
    			else {
    				f[i][j][0] = 1ll * m * pw[j] % MOD;
    				cadd(f[i][j][0], 1ll * m * (j - 1) % MOD * pw[j - 1] % MOD * (m - 2) % MOD);
    				f[i][j][1] = 1ll * m * pw[j] % MOD;
    			}
    		}
    		for (int k = 1; k < i && (k >> 1 <= j); ++k) {
    			int t = k >> 1;
    			int v1 = 1ll * f[i - k][j - t][0] * (m - 1) % MOD;
    			int v2 = 1ll * f[i - k][j - t][1] * (m - 2) % MOD;
    			if (k & 1) {
    				cadd(f[i][j][0], 1ll * v1 * pw[t] % MOD);
    				cadd(f[i][j][0], 1ll * v2 * pw[t] % MOD);
    			} else {
    				// 101010...
    				cadd(f[i][j][0], 1ll * v1 % MOD * (m - 2) % MOD * pw[t - 1] % MOD);
    				cadd(f[i][j][0], 1ll * v2 % MOD * (m - 2) % MOD * pw[t - 1] % MOD);
    				cadd(f[i][j][0], 1ll * f[i - k][j - t][1] * qpow(m - 1, t) % MOD);
    				// 010101...
    				cadd(f[i][j][1], 1ll * v1 % MOD * qpow(m - 1, t) % MOD);
    				cadd(f[i][j][1], 1ll * v2 % MOD * qpow(m - 1, t) % MOD);
    				// 010...11...010
    				cadd(f[i][j][0], 1ll * v1 % MOD * (t - 1) % MOD * pw[t - 1] % MOD * (m - 2) % MOD);
    				cadd(f[i][j][0], 1ll * v2 % MOD * (t - 1) % MOD * pw[t - 1] % MOD * (m - 2) % MOD);
    			}
    		}
    	}
    	for (int i = 0; i <= n >> 1; ++i) cout << add(f[n][i][0], f[n][i][1]) << ' ';
    }
    return 0;
}
$\mathcal{O}(n^2)$ DP
#include <iostream>

using namespace std;

#define lowbit(x) ((x) & -(x))
#define chk_min(x, v) (x) = min((x), (v))
#define chk_max(x, v) (x) = max((x), (v))
typedef long long ll;
typedef pair<int, int> pii;
const int N = 5e3 + 5, MOD = 1e9 + 7;

inline int add(int x, int y) { return x += y, x >= MOD ? x - MOD : x; }
inline int sub(int x, int y) { return x -= y, x < 0 ? x + MOD : x; }
inline void cadd(int &x, int y) { x += y, x < MOD || (x -= MOD); }
inline void csub(int &x, int y) { x -= y, x < 0 && (x += MOD); }

int n, m, f[N][N][2], g[N][N][2], h[N][N][2];
int iv, pw[N], fac[N], ifac[N];

int qpow(int a, ll b) {
	int res = 1;
	for (; b; b >>= 1) {
		if (b & 1) res = 1ll * res * a % MOD;
		a = 1ll * a * a % MOD;
	}
	return res;
}
void pre() {
	fac[0] = 1;
	for (int i = 1; i <= n; ++i) fac[i] = 1ll * fac[i - 1] * i % MOD;
	ifac[n] = qpow(fac[n], MOD - 2);
	for (int i = n - 1; ~i; --i) ifac[i] = 1ll * ifac[i + 1] * (i + 1) % MOD;
}
int C(int n, int m) { return 1ll * fac[n] * ifac[m] % MOD * ifac[n - m] % MOD; }

int main() {
    ios::sync_with_stdio(0), cin.tie(0);
    cin >> n >> m;
    if (m == 2) {
    	pre();
    	for (int i = 0; i <= n >> 1; ++i) cout << C(n, i) * 2 % MOD << ' ';
    } else {
    	pw[0] = 1;
    	for (int i = 1; i <= n >> 1; ++i) pw[i] = 1ll * pw[i - 1] * (m - 1) % MOD;
    	iv = qpow(m - 1, MOD - 2);
    	g[1][0][0] = f[1][0][0] = m;
    	for (int i = 2; i <= n; ++i) for (int j = 0; j <= i >> 1; ++j) {
    		if (j == (i >> 1)) {
    			if (i & 1) f[i][j][0] = 1ll * m * pw[j] % MOD;
    			else {
    				f[i][j][0] = 1ll * m * pw[j] % MOD;
    				cadd(f[i][j][0], 1ll * m * (j - 1) % MOD * pw[j - 1] % MOD * (m - 2) % MOD);
    				f[i][j][1] = 1ll * m * pw[j] % MOD;
    			}
    		}
            int v1 = 1ll * g[i - 1][j][0] * iv % MOD;
            int v2 = 1ll * g[i - 1][j][1] * iv % MOD;
    		cadd(f[i][j][0], 1ll * v1 * (m - 1) % MOD * (m - 1) % MOD);
    		cadd(f[i][j][0], 1ll * v2 * (m - 2) % MOD * (m - 1) % MOD);
    		if (j >= 1) {
	    		v1 = g[i - 2][j - 1][0], v2 = g[i - 2][j - 1][1];
                int x = 1ll * v1 * (m - 1) % MOD;
                int y = 1ll * v2 * (m - 1) % MOD;
	    		// 101010...
	    		cadd(f[i][j][0], 1ll * x * (m - 2) % MOD);
	    		cadd(f[i][j][0], 1ll * v2 * (m - 2) % MOD * (m - 2) % MOD);
	    		cadd(f[i][j][0], 1ll * y);
	    		// 010101...
	    		cadd(f[i][j][1], 1ll * x * (m - 1) % MOD);
	    		cadd(f[i][j][1], 1ll * y * (m - 2) % MOD);
	    		// 010...11...010
	    		v1 = h[i - 2][j - 1][0], v2 = h[i - 2][j - 1][1];
	    		cadd(f[i][j][0], 1ll * v1 % MOD * (m - 2) % MOD * (m - 1) % MOD);
	    		cadd(f[i][j][0], 1ll * v2 % MOD * (m - 2) % MOD * (m - 2) % MOD);
	    	}
	    	if (!j) {
	    		g[i][j][0] = f[i][j][0], g[i][j][1] = f[i][j][1];
	    		h[i][j][0] = h[i][j][1] = 0;
			} else {
		    	g[i][j][0] = add(1ll * g[i - 2][j - 1][0] * (m - 1) % MOD, f[i][j][0]);
		    	g[i][j][1] = add(1ll * g[i - 2][j - 1][1] * (m - 1) % MOD, f[i][j][1]);
		    	h[i][j][0] = add(1ll * h[i - 2][j - 1][0] * (m - 1) % MOD, 1ll * g[i - 2][j - 1][0] * (m - 1) % MOD);
		    	h[i][j][1] = add(1ll * h[i - 2][j - 1][1] * (m - 1) % MOD, 1ll * g[i - 2][j - 1][1] * (m - 1) % MOD);
		    }
		}
    	for (int i = 0; i <= n >> 1; ++i) cout << add(f[n][i][0], f[n][i][1]) << ' ';
    }
    return 0;
}
posted @ 2025-05-04 18:04  P2441M  阅读(14)  评论(0)    收藏  举报