AcWing 4546. 最大和加强加强版
看到数据范围 \(n\times m \le 5\times 10^7\) 就可以猜出来状态肯定是 \(f_{i, j}\) 的,然后用滚动数组优化,或者直接用类似于 \(01\) 背包的方法优化。
我们设 \(f_{i, j}\) 表示,对于前 \(i\) 个数,已经找到了 \(j\) 个组,且第 \(i\) 个数属于 \(j\) 组,那么可以列转移式:
\[\begin{aligned}f_{i, j} = \max\{\max\limits_{k=1}^{i-1}\{f_{k, j - 1}\},f_{i-1,j}\}+a_i\end{aligned}
\]
左边的东西表示,第 \(j-1\) 组结束,开始选第 \(j\) 组,第 \(j\) 组第一个是 \(a_i\),右边的表示把 \(i\) 加入到 \(j\) 组中,第 \(i\) 个数不是开头。
可以显然发现会超时,我们可以在计算左边的东西时,直接用 \(g_{j}\) 辅助一下(注意:要倒着循环,不然会错的,因为如果正着循环,会更新了在 \(i\) 时更新的数 \(g_{j}\))
// #define FILE_INPUT
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#pragma GCC optimize(2)
#pragma GCC optimize(3)
using namespace std;
// #define int long long
#define rep(i, a, b) for (int i = a, END##i = b; i <= END##i; i++)
#define per(i, a, b) for (int i = a, END##i = b; i >= END##i; i--)
void Init();
void Solve();
signed main() {
cin.sync_with_stdio(0);
cin.tie(0), cout.tie(0);
#ifdef FILE_INPUT
freopen("input.in", "r", stdin);
#endif
int T = 1;
// cin >> T;
while (T--) {
Init();
Solve();
}
return 0;
}
using LL = long long;
using ULL = unsigned long long;
const int Mod = 1e9 + 7;
const int Inf = 0x3f3f3f3f;
const LL InfLL = 0x3f3f3f3f3f3f3f3f;
const int N = 1e6 + 10, M = 1010;
LL m, n, f[2][M], a[N], g[M];
void Init() {
}
#define max(x, y) (x > y ? x : y) // 少用,递归的时候时间复杂度会爆炸
#define min(x, y) (x < y ? x : y) // 少用,递归的时候时间复杂度会爆炸
void Solve() {
while (cin >> m >> n) {
rep(i, 1, n) cin >> a[i];
memset(f, -0x3f, sizeof(f));
memset(g, -0x3f, sizeof(g));
f[0][0] = f[1][0] = 0;
g[0] = 0;
rep(i, 1, n) {
int t = min(i, m);
per(j, t, 1) {
f[i & 1][j] = max(g[j - 1], f[i - 1 & 1][j]) + a[i];
g[j] = max(g[j], f[i & 1][j]);
}
}
cout << g[m] << "\n";
}
}
然后呢我们还可以再进行优化(虽然没有必要了),用类似于 \(01\) 背包的东西优化,把 \(f\) 数组直接变成一维的。
// #define FILE_INPUT
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#pragma GCC optimize(2)
#pragma GCC optimize(3)
using namespace std;
// #define int long long
#define rep(i, a, b) for (int i = a, END##i = b; i <= END##i; i++)
#define per(i, a, b) for (int i = a, END##i = b; i >= END##i; i--)
void Init();
void Solve();
signed main() {
cin.sync_with_stdio(0);
cin.tie(0), cout.tie(0);
#ifdef FILE_INPUT
freopen("input.in", "r", stdin);
#endif
int T = 1;
// cin >> T;
while (T--) {
Init();
Solve();
}
return 0;
}
using LL = long long;
using ULL = unsigned long long;
const int Mod = 1e9 + 7;
const int Inf = 0x3f3f3f3f;
const LL InfLL = 0x3f3f3f3f3f3f3f3f;
const int N = 1e6 + 10, M = 1010;
LL m, n, f[M], a[N], g[M];
void Init() {
}
#define max(x, y) (x > y ? x : y) // 少用,递归的时候时间复杂度会爆炸
#define min(x, y) (x < y ? x : y) // 少用,递归的时候时间复杂度会爆炸
void Solve() {
while (cin >> m >> n) {
rep(i, 1, n) cin >> a[i];
memset(f, -0x3f, sizeof(f));
memset(g, -0x3f, sizeof(g));
f[0] = 0;
g[0] = 0;
rep(i, 1, n) {
int t = min(i, m);
per(j, t, 1) {
f[j] = max(g[j - 1], f[j]) + a[i];
g[j] = max(g[j], f[j]);
}
}
cout << g[m] << "\n";
}
}
注意:\(n,m\) 不要写反了,边界一定要处理好

浙公网安备 33010602011771号