0601 NOI2021 模拟

A

题目链接:咕咕咕

我考试的时候从正着走的角度出发,发现选完每个数之后,会将前面一个类似于二进制数的状态 \(+1\),然后触发一系列的收益。然后注意到收益仅仅与最后 \(\log\) 个选的状态相关,就设计了一个 \(O(n^3)\) 的 dp. 后来由于发现一长串 1 的情况我可能无法处理,就又加了一维变成了 \(O(n^4)\),剪了剪枝跑的飞快,能过 \(n = 200\).

结束之后一问,发现 zjx 和 lyz 都轻松爆切了。那这个东西的难点,或者说做法的本质区别究竟是什么呢?

区别在,我做的事情是从前往后看“进位”,都知道进位可以牵连到很多位,所以这个状态比较难优化。

人家做的事情是从后往前看“进位”,虽然进位可以牵连到很多位,但是我可以把后面的东西堆起来,之后一次性往前进。

所以人家的做法是这样的:

倒着做,变成严格不降的。设 \(dp_{i, j}\) 表示当前我考虑到了第 \(i\) 个数,已经选了 \(j\) 个和我一样的数的最大收益。

再设一个辅助转移的数组,\(memd_{i, j}\) 表示当前某个子序列值为 \(i\),已经给我进了 \(j\) 位的最大收益。

这样就可以转移了(用 \(dp\) 更新 \(memd\),用 \(memd\) 更新 \(dp\)

初始状态 \(memd_{i, 0}\)\(0\).

注意复杂度正确的原因:当 \(memd_i\) 中的 \(i\) 每增加 \(1\) 时,\(j\)\(/2\),所以只会更新 \(\log\) 轮,综合复杂度为 \(O(n^2 \log n)\).

#include <bits/stdc++.h>
#define rep(i,l,r) for(int i = (l); i <= (r); ++i)
#define per(i,r,l) for(int i = (r); i >= (l); --i)
using namespace std;
typedef long long ll;
inline int gi() {
	int f = 1, x = 0; char ch = getchar();
	while (ch < '0' || ch > '9') {if (ch == '-') f = -f;ch = getchar();}
	while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
	return f * x;
}
const int N = 2005;
int n, m, l[N], s[N], c[N << 1]; 
int dp[N][N << 1];
int rec[N][N << 1];
void solve() {
	reverse (l + 1, l + 1 + n);
	reverse (s + 1, s + 1 + n);
	memset (dp, 0xcf, sizeof (dp));
	memset (rec, 0xcf, sizeof (rec));
	for (int i = 1; i <= m; ++i)
		rec[i][0] = 0;
	int ans = 0;
	for (int i = 1; i <= n; ++i) {
		for (int j = 0; j <= i; ++j) {
			dp[i][j] = rec[l[i]][j];
			if (j) dp[i][j] = max(dp[i][j], rec[l[i]][j - 1] + c[l[i]] - s[i]);
			int cur = j, nx = 0, bit = 1;
			while (cur) nx += (cur >> 1) * c[l[i] + bit], ++bit, cur >>= 1;
			ans = max(ans, dp[i][j] + nx);
		}
		int mx = 0;
		for (int j = 0; j <= i; ++j) if (dp[i][j] > -1e8){
			int prof = 0, cur = j;
			for (int k = l[i]; k <= l[i] + 13; ++k) {
				rec[k][cur] = max(rec[k][cur], dp[i][j] + prof);
				prof += (cur >> 1) * c[k + 1];
				cur >>= 1;
			}
			mx = max(mx, dp[i][j] + prof);
		}
		for (int k = l[i] + 14; k <= m + (int)__lg(n); ++k)
			rec[k][0] = max(rec[k][0], mx);
	}
	cout << ans << '\n';
}
int main() {
	n = gi(), m = gi();
	rep (i, 1, n) l[i] = gi();
	rep (i, 1, n) s[i] = gi();
	rep (i, 1, n + m) c[i] = gi();
	solve();
	return 0;
}
posted @ 2021-06-01 14:12  LiM_233  阅读(77)  评论(0)    收藏  举报