CF1093F Vasya and Array 更优秀的做法

CF1093F Vasya and Array 更优秀的做法

摘要

本题有一个经典的 DP + 容斥做法,时间复杂度是 \(O(nk)\)

本文作者在此基础上,创新性地发掘题目性质,简化 DP 状态,利用数据结构优化 DP,提出了一个时间复杂度 \(O(n)\) 的做法。显然相比经典做法,是更加优秀的。

本文将分别介绍这两种做法。

题目大意

题目链接

给出一段长度为 \(n\) 的整数序列,一个正整数 \(k\) ,一个正整数 \(\text{len}\)。序列中的所有数要么在 \([1,k]\) 之间,要么等于 \(-1\)

我们称一个序列是好的,当且仅当不存在 \(\text{len}\) 个连续的相同的数字。

你可以将每个 \(-1\),替换成任意一个 \([1,k]\) 之间的整数。求有多少种方案,使得最终序列是好的。答案对 \(998244353\) 取模。

数据范围:\(1\leq n\leq 10^5\)\(1\leq k\leq 100\)\(1\leq \text{len}\leq n\)

经典做法

\(\text{dp}_0(i,j)\) 表示考虑了前 \(i\) 个位置,最后一个位置上填的数是 \(j\),前 \(i\) 个位置组成的序列合法的方案数。设 \(\text{sdp}(i)=\sum_{j=1}^{k}\text{dp}_0(i,j)\)

首先,当 \(a_i\neq -1\)\(a_i\neq j\) 时,\(\text{dp}_0(i,j)=0\)

否则,我们暂时令 \(\text{dp}_0(i,j)=\text{sdp}(i - 1)\)。但是这会把一些不合法的方案算进去。具体来说,这样有可能出现 \([a_{i-\text{len}+1}\dots a_i]\) 全部相同的情况。这种情况会出现当且仅当如下两个条件都满足:

  • \(i\geq \text{len}\)
  • \([a_{i-\text{len}+1}\dots a_i]\) 中每个数都等于 \(-1\)\(j\)

所以,我们还要减去这种情况的数量:\(\text{sdp}(i-\text{len})-\text{dp}_0(i-\text{len},j)\)。其中 \(\text{sdp}(i - \text{len})\) 表示 \([a_{i-\text{len}+1}\dots a_i]\) 全部等于 \(j\) 时,前 \(i - \text{len}\) 位的填写方案。不过这些方案中,有一些方案可能在前 \(i-1\) 位就已经导致不合法了。这些提前不合法的方案本来就没有被算在 \(\text{sdp}(i-1)\) 中,所以不需要被减去,它们的数量是 \(\text{dp}_0(i - \text{len},j)\)

综上所述,可以得到转移式:

\[\text{dp}_0(i,j)=\begin{cases} 0&& a_i\neq0\text{ 且 }a_i\neq j\\ \text{sdp}(i-1)-(\text{sdp}(i-\text{len})-\text{dp}_0(i-\text{len},j))\cdot [\text{上文中两个条件}]&&\text{otherwise} \end{cases} \]

时间复杂度\(O(nk)\)

更优秀的做法

首先,当 \(\text{len} = 1\) 时,答案一定是 \(0\)。以下只讨论 \(\text{len} > 1\) 的情况。

先考虑一种朴素的 DP。设 \(\text{dp}_1(i,j,l)\) 表示考虑了前 \(i\) 个位置,第 \(i\) 位上填的数是 \(j\),最后一个 \(\neq j\) 的位置是 \(l\),此时使得前 \(i\) 位组成的序列合法的方案数。

转移时,考虑当前位填了什么:

\[\begin{cases} \text{dp}_1(i-1,j,l)\to \text{dp}_1(i,j, l) && \text{if }l\geq i - \text{len}+1\\ \text{dp}_1(i-1,j,l)\to \text{dp}_1(i,x, i) && \text{if }x\neq j \end{cases} \]

初始状态为 \(\text{dp}_1(0,0,0) = 1\)。答案是 \(\sum_{j = 1}^{k}\sum_{l = n - \text{len}+1}^{n - 1}\text{dp}_1(n,j,l)\)

这个朴素 DP 的时间复杂度是 \(O(n^2k)\)


优化它!设上一位填的数为 \(j\),当前位填的数为 \(x\)。我们发现,当 \(x\)\(j\) 不同时,\(x\) 的值具体是什么其实不重要:对所有 \(x\neq j\),它们的转移是一模一样的。这就给了我们简化的空间。

定义 \(a_{n+1} = k+1\)。定义 \(\text{nxt}_i\) 表示位置 \(i\) 后面第一个 \(a_{i'}\neq -1\) 的位置 \(i'\)。设 \(\text{dp}_2(i,j\in\{0,1\},l)\) 表示考虑了前 \(i\) 个位置,第 \(i\) 位上填的数是 / 否等于 \(a_{\text{nxt}_i}\),前 \(i\) 位里最后一个填的数与第 \(i\) 位上不同的位置是 \(l\),此时使得前 \(i\) 位组成的序列合法的方案数。

转移分 \(a_i\) 是否为 \(-1\) 两种情况。

\(a_i\neq -1\) 时,枚举 \(l\)。则有如下转移:

\[\begin{cases} \text{dp}_2(i-1,0,l)\to \text{dp}_2(i,[a_i=a_{\text{nxt}_i}],i - 1)\\ \text{dp}_2(i-1,1,l)\to \text{dp}_2(i,[a_i=a_{\text{nxt}_i}],l) && \text{if }l\geq i - \text{len} + 1 \end{cases} \]

\(a_i = -1\) 时,显然 \(\text{nxt}_{i} =\text{nxt}_{i-1}\)。枚举 \(l\)。我们分别考虑如下情况:

  • \(i-1\) 位填的数与 \(a_{\text{nxt}_{i}}\) 不同:
    • \(i\) 位上填的数与第 \(i-1\) 位上填的数相同。
    • \(i\) 位上填的数与第 \(a_{\text{nxt}_i}\) 相同。
    • \(i\) 位上填的数,既不等于第 \(i-1\) 位上填的数,也不等于 \(a_{\text{nxt}_{i}}\)
  • \(i-1\) 位填的数与 \(a_{\text{nxt}_{i}}\) 相同:
    • \(i\) 位填的数与 \(a_{\text{nxt}_{i}}\) 不同。
    • \(i\) 位填的数与 \(a_{\text{nxt}_{i}}\) 相同。

这五种情况分别对应如下转移:

\[\begin{cases} \text{dp}_2(i - 1, 0, l)&\to \text{dp}_2(i, 0, l)&&\text{if }l\geq i - \text{len} + 1\\ \text{dp}_2(i - 1, 0, l)&\to\text{dp}_2(i, 1, i - 1)\\ \text{dp}_2(i - 1, 0, l)\times(k - 2) & \to\text{dp}_2(i,0,i-1)\\ \text{dp}_2(i - 1, 1, l)\times(k - 1) &\to \text{dp}_2(i, 0, i - 1)\\ \text{dp}_2(i - 1, 1, l)&\to\text{dp}_2(i,1,l) && \text{if }l\geq i - \text{len} + 1 \end{cases} \]

上述式子里默认 \(1 < i < \text{nxt}_i \leq n\)。当 \(i = 1\)\(\text{nxt}_i = n+1\) 时,有一些特殊情况要考虑。为了表述简洁,这里就不细写了。

现在,这个 DP 的时间复杂度是 \(O(n^2)\)。虽然无法 AC,但这是迈向 \(O(n)\) 做法的关键一步。我将这一 DP 的代码附在了本文末尾:点击跳转


在这个 DP 里,第 \(1\) 维的枚举不可避免,第 \(2\) 维的状态数已经被我们优化到 \(O(1)\)。考虑优化第 \(3\) 维。

观察转移式,发现从 \(i-1\) 变成 \(i\) 时,第 \(3\) 维的转移,相当于做如下操作:

  1. 区间求和。对所有 \(l\in[i - \text{len},i - 2]\) 求和。
  2. 单点加。
  3. 把一段区间的值覆盖为 \(0\)
  4. \(a_i\neq -1\)\(a_i\neq a_{\text{nxt}_i}\) 时,需要把 \(\text{dp}_2(i,1)\) 中的一段 \(l\),复制到 \(\text{dp}_2(i,0)\) 对应的位置上。

按第二维的 \(0,1\),维护两棵线段树,通过打懒标记即可实现这四种操作。

时间复杂度 \(O(n\log n)\)


继续优化,发现区间操作都是假的。

  • 操作 \(3\) 要么是单点清空,要么是全局清空。
  • 操作 \(1\) 和操作 \(4\),因为其他位置都清空了,所以区间求和、区间复制,其实就是全局求和、全局交换。

所以我们只需要用两个数组来维护。

时间复杂度 \(O(n)\)

参考代码

最终代码

友情提醒:使用读入、输出优化可以使代码更快,详见本博客公告。

// problem: CF1093F
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;

template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }

const int MAXN = 1e5;
const int MOD = 998244353;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }


int n, K, len, a[MAXN + 5];

struct FantasticDataStructure {
	int sum;
	int arr[MAXN + 5];
	int TIM;
	int tim[MAXN + 5];
	void upd(int p) {
		if (tim[p] < TIM) {
			tim[p] = TIM;
			arr[p] = 0;
		}
	}
	void point_add(int p, int v) {
		upd(p);
		add(arr[p], v);
		add(sum, v);
	}
	void point_set0(int p) {
		upd(p);
		sub(sum, arr[p]);
		arr[p] = 0;
	}
	void global_set0() {
		TIM++;
		sum = 0;
	}
	int query() {
		return sum;
	}
	FantasticDataStructure() {}
};
FantasticDataStructure S[2];
int id[2];

// DP 转移: a[i] != -1 / a[i] == -1
void trans1(int p, int nxtval) {
	int v = S[id[0]].query();
	if (a[p] == nxtval) {
		S[id[0]].global_set0();
		if (p - len >= 0) {
			S[id[1]].point_set0(p - len);
		}
		S[id[1]].point_add(p - 1, v);
	} else {
		swap(id[0], id[1]);
		S[id[1]].global_set0();
		if (p - len >= 0) {
			S[id[0]].point_set0(p - len);
		}
		S[id[0]].point_add(p - 1, v);
	}
}
void trans2(int p, int flag) {
	int v0 = S[id[0]].query();
	int v1 = S[id[1]].query();
	
	if (p == 1) {
		S[id[0]].global_set0();
	} else {
		if (p - len >= 0) {
			S[id[0]].point_set0(p - len);
		}
	}
	
	int toadd = 0;
	if (flag + (p != 1) + 1 <= K) {
		toadd = (ll)v0 * (K - flag - (p != 1)) % MOD;
	}
	add(toadd, (ll)v1 * (K - 1) % MOD);
	S[id[0]].point_add(p - 1, toadd);
	
	if (p - len >= 0) {
		S[id[1]].point_set0(p - len);
	}
	if (flag) {
		S[id[1]].point_add(p - 1, v0);
	}
}
int main() {
	cin >> n >> K >> len;
	for (int i = 1; i <= n; ++i) {
		cin >> a[i];
	}
	if (len == 1) {
		cout << 0 << endl;
		return 0;
	}
	
	id[0] = 0;
	id[1] = 1;
	S[id[1]].point_add(0, 1);
	
	a[0] = K + 1;
	a[n + 1] = K + 2;
	for (int i = 1, j = 2; i <= n; ++i) {
		ckmax(j, i + 1);
		while (a[j] == -1)
			++j;
		if (a[i] != -1) {
			trans1(i, a[j]);
		} else {
			trans2(i, j != n + 1);
		}
	}
	cout << S[id[0]].query() << endl;
	return 0;
}

n^2 DP

为了帮助读者更好地理解题解,这里附上 \(O(n^2)\) 朴素 DP 的代码。

#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;

template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }

const int MAXN = 1000;
const int MOD = 998244353;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }


int n, K, len, a[MAXN + 5];
int dp[MAXN + 5][2][MAXN + 5];

void trans1(int p, int nxtval) {
	for (int i = max(0, p - len); i <= max(0, p - 2); ++i) {
		add(dp[p][a[p] == nxtval][p - 1], dp[p - 1][0][i]);
		if (i >= p - len + 1)
			add(dp[p][a[p] == nxtval][i], dp[p - 1][1][i]);
	}
}
void trans2(int p, int flag) {
	for (int i = max(0, p - len); i <= max(0, p - 2); ++i) {
		// a[p - 1] 和 nxtval 不同
		if (p != 1 && i >= p - len + 1) {
			add(dp[p][0][i], dp[p - 1][0][i]); // a[p] 和 a[p - 1] 相同
		}
		if (flag) {
			add(dp[p][1][p - 1], dp[p - 1][0][i]); // a[p] 和 nxtval 相同
		}
		if (flag + (p != 1) + 1 <= K) {
			// a[p] 和 nxtval, a[p - 1] 都不同
			add(dp[p][0][p - 1], (ll)dp[p - 1][0][i] * (K - flag - (p != 1)) % MOD);
		}
		
		// a[p - 1] 和 nxtval 相同
		add(dp[p][0][p - 1], (ll)dp[p - 1][1][i] * (K - 1) % MOD);
		if (i >= p - len + 1) {
			add(dp[p][1][i], dp[p - 1][1][i]); // a[p] 和 a[p - 1], nxtval 都相同
		}
	}
}
int main() {
	cin >> n >> K >> len;
	for (int i = 1; i <= n; ++i) {
		cin >> a[i];
	}
	
	dp[0][0][0] = 1;
	a[0] = K + 1;
	a[n + 1] = K + 2;
	
	for (int i = 1, j = 2; i <= n; ++i) {
		ckmax(j, i + 1);
		while (a[j] == -1)
			++j;
		if (a[i] != -1) {
			trans1(i, a[j]);
		} else {
			trans2(i, j != n + 1);
		}
	}
	int ans = 0;
	for (int i = max(0, n - len + 1); i <= n - 1; ++i) {
		add(ans, dp[n][0][i]);
	}
	cout << ans << endl;
	return 0;
}
posted @ 2020-11-22 23:45  duyiblue  阅读(314)  评论(2编辑  收藏  举报