CF1093F Vasya and Array 更优秀的做法
CF1093F Vasya and Array 更优秀的做法
摘要
本题有一个经典的 DP + 容斥做法,时间复杂度是 \(O(nk)\)。
本文作者在此基础上,创新性地发掘题目性质,简化 DP 状态,利用数据结构优化 DP,提出了一个时间复杂度 \(O(n)\) 的做法。显然相比经典做法,是更加优秀的。
本文将分别介绍这两种做法。
题目大意
给出一段长度为 \(n\) 的整数序列,一个正整数 \(k\) ,一个正整数 \(\text{len}\)。序列中的所有数要么在 \([1,k]\) 之间,要么等于 \(-1\)。
我们称一个序列是好的,当且仅当不存在 \(\text{len}\) 个连续的相同的数字。
你可以将每个 \(-1\),替换成任意一个 \([1,k]\) 之间的整数。求有多少种方案,使得最终序列是好的。答案对 \(998244353\) 取模。
数据范围:\(1\leq n\leq 10^5\),\(1\leq k\leq 100\),\(1\leq \text{len}\leq n\)。
经典做法
设 \(\text{dp}_0(i,j)\) 表示考虑了前 \(i\) 个位置,最后一个位置上填的数是 \(j\),前 \(i\) 个位置组成的序列合法的方案数。设 \(\text{sdp}(i)=\sum_{j=1}^{k}\text{dp}_0(i,j)\)。
首先,当 \(a_i\neq -1\) 且 \(a_i\neq j\) 时,\(\text{dp}_0(i,j)=0\)。
否则,我们暂时令 \(\text{dp}_0(i,j)=\text{sdp}(i - 1)\)。但是这会把一些不合法的方案算进去。具体来说,这样有可能出现 \([a_{i-\text{len}+1}\dots a_i]\) 全部相同的情况。这种情况会出现当且仅当如下两个条件都满足:
- \(i\geq \text{len}\)。
- \([a_{i-\text{len}+1}\dots a_i]\) 中每个数都等于 \(-1\) 或 \(j\)。
所以,我们还要减去这种情况的数量:\(\text{sdp}(i-\text{len})-\text{dp}_0(i-\text{len},j)\)。其中 \(\text{sdp}(i - \text{len})\) 表示 \([a_{i-\text{len}+1}\dots a_i]\) 全部等于 \(j\) 时,前 \(i - \text{len}\) 位的填写方案。不过这些方案中,有一些方案可能在前 \(i-1\) 位就已经导致不合法了。这些提前不合法的方案本来就没有被算在 \(\text{sdp}(i-1)\) 中,所以不需要被减去,它们的数量是 \(\text{dp}_0(i - \text{len},j)\)。
综上所述,可以得到转移式:
时间复杂度\(O(nk)\)。
更优秀的做法
首先,当 \(\text{len} = 1\) 时,答案一定是 \(0\)。以下只讨论 \(\text{len} > 1\) 的情况。
先考虑一种朴素的 DP。设 \(\text{dp}_1(i,j,l)\) 表示考虑了前 \(i\) 个位置,第 \(i\) 位上填的数是 \(j\),最后一个 \(\neq j\) 的位置是 \(l\),此时使得前 \(i\) 位组成的序列合法的方案数。
转移时,考虑当前位填了什么:
初始状态为 \(\text{dp}_1(0,0,0) = 1\)。答案是 \(\sum_{j = 1}^{k}\sum_{l = n - \text{len}+1}^{n - 1}\text{dp}_1(n,j,l)\)。
这个朴素 DP 的时间复杂度是 \(O(n^2k)\)。
优化它!设上一位填的数为 \(j\),当前位填的数为 \(x\)。我们发现,当 \(x\) 和 \(j\) 不同时,\(x\) 的值具体是什么其实不重要:对所有 \(x\neq j\),它们的转移是一模一样的。这就给了我们简化的空间。
定义 \(a_{n+1} = k+1\)。定义 \(\text{nxt}_i\) 表示位置 \(i\) 后面第一个 \(a_{i'}\neq -1\) 的位置 \(i'\)。设 \(\text{dp}_2(i,j\in\{0,1\},l)\) 表示考虑了前 \(i\) 个位置,第 \(i\) 位上填的数是 / 否等于 \(a_{\text{nxt}_i}\),前 \(i\) 位里最后一个填的数与第 \(i\) 位上不同的位置是 \(l\),此时使得前 \(i\) 位组成的序列合法的方案数。
转移分 \(a_i\) 是否为 \(-1\) 两种情况。
当 \(a_i\neq -1\) 时,枚举 \(l\)。则有如下转移:
当 \(a_i = -1\) 时,显然 \(\text{nxt}_{i} =\text{nxt}_{i-1}\)。枚举 \(l\)。我们分别考虑如下情况:
- 第 \(i-1\) 位填的数与 \(a_{\text{nxt}_{i}}\) 不同:
- 第 \(i\) 位上填的数与第 \(i-1\) 位上填的数相同。
- 第 \(i\) 位上填的数与第 \(a_{\text{nxt}_i}\) 相同。
- 第 \(i\) 位上填的数,既不等于第 \(i-1\) 位上填的数,也不等于 \(a_{\text{nxt}_{i}}\)。
- 第 \(i-1\) 位填的数与 \(a_{\text{nxt}_{i}}\) 相同:
- 第 \(i\) 位填的数与 \(a_{\text{nxt}_{i}}\) 不同。
- 第 \(i\) 位填的数与 \(a_{\text{nxt}_{i}}\) 相同。
这五种情况分别对应如下转移:
上述式子里默认 \(1 < i < \text{nxt}_i \leq n\)。当 \(i = 1\) 或 \(\text{nxt}_i = n+1\) 时,有一些特殊情况要考虑。为了表述简洁,这里就不细写了。
现在,这个 DP 的时间复杂度是 \(O(n^2)\)。虽然无法 AC,但这是迈向 \(O(n)\) 做法的关键一步。我将这一 DP 的代码附在了本文末尾:点击跳转。
在这个 DP 里,第 \(1\) 维的枚举不可避免,第 \(2\) 维的状态数已经被我们优化到 \(O(1)\)。考虑优化第 \(3\) 维。
观察转移式,发现从 \(i-1\) 变成 \(i\) 时,第 \(3\) 维的转移,相当于做如下操作:
- 区间求和。对所有 \(l\in[i - \text{len},i - 2]\) 求和。
- 单点加。
- 把一段区间的值覆盖为 \(0\)。
- 在 \(a_i\neq -1\) 且 \(a_i\neq a_{\text{nxt}_i}\) 时,需要把 \(\text{dp}_2(i,1)\) 中的一段 \(l\),复制到 \(\text{dp}_2(i,0)\) 对应的位置上。
按第二维的 \(0,1\),维护两棵线段树,通过打懒标记即可实现这四种操作。
时间复杂度 \(O(n\log n)\)。
继续优化,发现区间操作都是假的。
- 操作 \(3\) 要么是单点清空,要么是全局清空。
- 操作 \(1\) 和操作 \(4\),因为其他位置都清空了,所以区间求和、区间复制,其实就是全局求和、全局交换。
所以我们只需要用两个数组来维护。
时间复杂度 \(O(n)\)。
参考代码
最终代码
友情提醒:使用读入、输出优化可以使代码更快,详见本博客公告。
// problem: CF1093F
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXN = 1e5;
const int MOD = 998244353;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }
int n, K, len, a[MAXN + 5];
struct FantasticDataStructure {
int sum;
int arr[MAXN + 5];
int TIM;
int tim[MAXN + 5];
void upd(int p) {
if (tim[p] < TIM) {
tim[p] = TIM;
arr[p] = 0;
}
}
void point_add(int p, int v) {
upd(p);
add(arr[p], v);
add(sum, v);
}
void point_set0(int p) {
upd(p);
sub(sum, arr[p]);
arr[p] = 0;
}
void global_set0() {
TIM++;
sum = 0;
}
int query() {
return sum;
}
FantasticDataStructure() {}
};
FantasticDataStructure S[2];
int id[2];
// DP 转移: a[i] != -1 / a[i] == -1
void trans1(int p, int nxtval) {
int v = S[id[0]].query();
if (a[p] == nxtval) {
S[id[0]].global_set0();
if (p - len >= 0) {
S[id[1]].point_set0(p - len);
}
S[id[1]].point_add(p - 1, v);
} else {
swap(id[0], id[1]);
S[id[1]].global_set0();
if (p - len >= 0) {
S[id[0]].point_set0(p - len);
}
S[id[0]].point_add(p - 1, v);
}
}
void trans2(int p, int flag) {
int v0 = S[id[0]].query();
int v1 = S[id[1]].query();
if (p == 1) {
S[id[0]].global_set0();
} else {
if (p - len >= 0) {
S[id[0]].point_set0(p - len);
}
}
int toadd = 0;
if (flag + (p != 1) + 1 <= K) {
toadd = (ll)v0 * (K - flag - (p != 1)) % MOD;
}
add(toadd, (ll)v1 * (K - 1) % MOD);
S[id[0]].point_add(p - 1, toadd);
if (p - len >= 0) {
S[id[1]].point_set0(p - len);
}
if (flag) {
S[id[1]].point_add(p - 1, v0);
}
}
int main() {
cin >> n >> K >> len;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
if (len == 1) {
cout << 0 << endl;
return 0;
}
id[0] = 0;
id[1] = 1;
S[id[1]].point_add(0, 1);
a[0] = K + 1;
a[n + 1] = K + 2;
for (int i = 1, j = 2; i <= n; ++i) {
ckmax(j, i + 1);
while (a[j] == -1)
++j;
if (a[i] != -1) {
trans1(i, a[j]);
} else {
trans2(i, j != n + 1);
}
}
cout << S[id[0]].query() << endl;
return 0;
}
n^2 DP
为了帮助读者更好地理解题解,这里附上 \(O(n^2)\) 朴素 DP 的代码。
#include <bits/stdc++.h>
using namespace std;
#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())
typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;
template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }
const int MAXN = 1000;
const int MOD = 998244353;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }
int n, K, len, a[MAXN + 5];
int dp[MAXN + 5][2][MAXN + 5];
void trans1(int p, int nxtval) {
for (int i = max(0, p - len); i <= max(0, p - 2); ++i) {
add(dp[p][a[p] == nxtval][p - 1], dp[p - 1][0][i]);
if (i >= p - len + 1)
add(dp[p][a[p] == nxtval][i], dp[p - 1][1][i]);
}
}
void trans2(int p, int flag) {
for (int i = max(0, p - len); i <= max(0, p - 2); ++i) {
// a[p - 1] 和 nxtval 不同
if (p != 1 && i >= p - len + 1) {
add(dp[p][0][i], dp[p - 1][0][i]); // a[p] 和 a[p - 1] 相同
}
if (flag) {
add(dp[p][1][p - 1], dp[p - 1][0][i]); // a[p] 和 nxtval 相同
}
if (flag + (p != 1) + 1 <= K) {
// a[p] 和 nxtval, a[p - 1] 都不同
add(dp[p][0][p - 1], (ll)dp[p - 1][0][i] * (K - flag - (p != 1)) % MOD);
}
// a[p - 1] 和 nxtval 相同
add(dp[p][0][p - 1], (ll)dp[p - 1][1][i] * (K - 1) % MOD);
if (i >= p - len + 1) {
add(dp[p][1][i], dp[p - 1][1][i]); // a[p] 和 a[p - 1], nxtval 都相同
}
}
}
int main() {
cin >> n >> K >> len;
for (int i = 1; i <= n; ++i) {
cin >> a[i];
}
dp[0][0][0] = 1;
a[0] = K + 1;
a[n + 1] = K + 2;
for (int i = 1, j = 2; i <= n; ++i) {
ckmax(j, i + 1);
while (a[j] == -1)
++j;
if (a[i] != -1) {
trans1(i, a[j]);
} else {
trans2(i, j != n + 1);
}
}
int ans = 0;
for (int i = max(0, n - len + 1); i <= n - 1; ++i) {
add(ans, dp[n][0][i]);
}
cout << ans << endl;
return 0;
}