0601 NOI2021 模拟
A
题目链接:咕咕咕
我考试的时候从正着走的角度出发,发现选完每个数之后,会将前面一个类似于二进制数的状态 \(+1\),然后触发一系列的收益。然后注意到收益仅仅与最后 \(\log\) 个选的状态相关,就设计了一个 \(O(n^3)\) 的 dp. 后来由于发现一长串 1 的情况我可能无法处理,就又加了一维变成了 \(O(n^4)\),剪了剪枝跑的飞快,能过 \(n = 200\).
结束之后一问,发现 zjx 和 lyz 都轻松爆切了。那这个东西的难点,或者说做法的本质区别究竟是什么呢?
区别在,我做的事情是从前往后看“进位”,都知道进位可以牵连到很多位,所以这个状态比较难优化。
人家做的事情是从后往前看“进位”,虽然进位可以牵连到很多位,但是我可以把后面的东西堆起来,之后一次性往前进。
所以人家的做法是这样的:
倒着做,变成严格不降的。设 \(dp_{i, j}\) 表示当前我考虑到了第 \(i\) 个数,已经选了 \(j\) 个和我一样的数的最大收益。
再设一个辅助转移的数组,\(memd_{i, j}\) 表示当前某个子序列值为 \(i\),已经给我进了 \(j\) 位的最大收益。
这样就可以转移了(用 \(dp\) 更新 \(memd\),用 \(memd\) 更新 \(dp\))
初始状态 \(memd_{i, 0}\) 是 \(0\).
注意复杂度正确的原因:当 \(memd_i\) 中的 \(i\) 每增加 \(1\) 时,\(j\) 会 \(/2\),所以只会更新 \(\log\) 轮,综合复杂度为 \(O(n^2 \log n)\).
#include <bits/stdc++.h>
#define rep(i,l,r) for(int i = (l); i <= (r); ++i)
#define per(i,r,l) for(int i = (r); i >= (l); --i)
using namespace std;
typedef long long ll;
inline int gi() {
int f = 1, x = 0; char ch = getchar();
while (ch < '0' || ch > '9') {if (ch == '-') f = -f;ch = getchar();}
while (ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar();
return f * x;
}
const int N = 2005;
int n, m, l[N], s[N], c[N << 1];
int dp[N][N << 1];
int rec[N][N << 1];
void solve() {
reverse (l + 1, l + 1 + n);
reverse (s + 1, s + 1 + n);
memset (dp, 0xcf, sizeof (dp));
memset (rec, 0xcf, sizeof (rec));
for (int i = 1; i <= m; ++i)
rec[i][0] = 0;
int ans = 0;
for (int i = 1; i <= n; ++i) {
for (int j = 0; j <= i; ++j) {
dp[i][j] = rec[l[i]][j];
if (j) dp[i][j] = max(dp[i][j], rec[l[i]][j - 1] + c[l[i]] - s[i]);
int cur = j, nx = 0, bit = 1;
while (cur) nx += (cur >> 1) * c[l[i] + bit], ++bit, cur >>= 1;
ans = max(ans, dp[i][j] + nx);
}
int mx = 0;
for (int j = 0; j <= i; ++j) if (dp[i][j] > -1e8){
int prof = 0, cur = j;
for (int k = l[i]; k <= l[i] + 13; ++k) {
rec[k][cur] = max(rec[k][cur], dp[i][j] + prof);
prof += (cur >> 1) * c[k + 1];
cur >>= 1;
}
mx = max(mx, dp[i][j] + prof);
}
for (int k = l[i] + 14; k <= m + (int)__lg(n); ++k)
rec[k][0] = max(rec[k][0], mx);
}
cout << ans << '\n';
}
int main() {
n = gi(), m = gi();
rep (i, 1, n) l[i] = gi();
rep (i, 1, n) s[i] = gi();
rep (i, 1, n + m) c[i] = gi();
solve();
return 0;
}

浙公网安备 33010602011771号