CF833B The Bakery 题解
题目传送门
题目大意
将一个长度为 \(n\) 的序列分为 \(k\) 段,使得总价值最大。
一段区间的价值表示为区间内不同数字的个数。
\(n≤35000,k≤50\)
输入输出样例 #1
输入 #1
4 1
1 2 2 1
输出 #1
2
输入输出样例 #2
输入 #2
7 2
1 3 3 1 4 4 4
输出 #2
5
输入输出样例 #3
输入 #3
8 3
7 7 8 7 7 8 1 7
输出 #3
6
说明/提示
In the first example Slastyona has only one box. She has to put all cakes in it, so that there are two types of cakes in the box, so the value is equal to $ 2 $ .
In the second example it is profitable to put the first two cakes in the first box, and all the rest in the second. There are two distinct types in the first box, and three in the second box then, so the total value is $ 5 $ .
思路
基础动规
一眼动规(区间dp),设\(\text{dp}_{i,j}\)表示用 \(i\) 个盒子装前 \(j\) 个蛋糕所能达到的最大价值。区间dp典型思路枚举决策点(此处决策为最后一个盒子装哪些蛋糕)。
枚举最后一个盒子为盒子 \(i\) ,装的是 \(a_{k+1},a_{k+2},\cdots,a_j\),此删除此盒子后的状态为则\(\text{dp}_{i-1,k}\)。
状态转移需加上最后一个盒子盒子 \(i\) 的代价,即 \(\text{w}(k + 1,j)\)。(\(\text w\)函数是用来统计给定区间\([l,r]\)里不同数字的个数的函数,即此区间的价值)
\(\text{dp}_{i,j}=\max_{k=i-1}^{j}{\text{dp}_{i-1,k}+\text{w}(k+1,j)}\)
然而! 我们注意到了数据范围,\(n\in [1,3.5\times10^4]\) ,如果使用暴力循环动规的话时间复杂度为 外层循环和w函数的乘积,即\(n\times k\times n\times n=\text{O}(n^3k)\),显然爆炸。
优化
显然时间复杂度O(nk)的循环是 有必要的,那么时间复杂度的第三维就只能是\(\log n\)了
考虑将 \(\text{dp}_{i-1,k}+\text{w}(k+1,j)\) 捆绑,视为一项,这样就可以用最简单的线段树模版(最值线段树)维护,再\(\text{O}(\log n)\)计算w函数,程序的时间复杂度就降到了\(\text O(nk\log n)\),可以通过。
\(\text O(n^2)\rightarrow\text O(\log n)\)
这里我将尽力使用通俗易懂(或者说精确)的语言解释各个题解中对于“数字贡献”的描述,并且尽量提供微量具体的代码片段,以帮助理解。
我们注意到,计算\(\text w\)函数的过程中,一个数的“影响范围”显然是连续的,可以用线段树整体log维护,而不用逐个O(1)(累计O(n))的复杂度去累加
而这点具体表示为,若一个数 \(a_i\) ,与其两侧相邻的数为 \(a_x,a_y\) ,则 \(a_i\) 能够影响到\(\text{w}(L,R)\) 需满足
我们在枚举dp第二层(蛋糕个数为j时),可以不断更新 \(a_x\) 中的下标 \(x\),记为 \(\text{pre}_{a_i}\),将 \(j\) 视为 上述 \(R\) ,这样就不需要考虑在 \(a_i\) 右侧的 \(a_y\)了。枚举到 \(a_j\) 时,自然想到使用线段树累加所有的\(\text{w}(l,j),其中l\in\{\text{pre}_{a_i}+1,j\}\)。做完这一步后,立刻更新\(\text{dp}_{i,j}\)为线段树中\(1,j\)的最值,代表了\(\max_{k=i-1}^{j}{\text{dp}_{i-1,k}+\text{w}(k+1,j)}\),与上式匹配,互相印证,说明我们的方向是正确的。
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 35005, K = 55;
int n, k, a[N];
int dp[N]; // dp[i][j]: maximum value obtained by ordering first j cakes into i boxes
int pre[N]; // pre[i]: 元素i上次出现的位置
namespace ST
{
#define ls i << 1
#define rs i << 1 | 1
// 注意此处线段树是为了区间max而服务的,因此并不是说主要思路就是线段树
struct SegTree
{
struct NODE
{
int l, r; // 节点i表示的区间左右端点,不是左右儿子
int sum; // [l,r]区间max(不用管懒标记)
int lz; // sum的懒标记(写惯了懒得改max了
} tr[N << 2];
void build(int i, int L, int R)
{
tr[i].l = L;
tr[i].r = R;
tr[i].lz = 0;
if (L == R) {
tr[i].sum = dp[L - 1];
return ;
}
int mid = (L + R) >> 1;
build(ls, L, mid);
build(rs, mid + 1, R);
tr[i].sum = max(tr[ls].sum, tr[rs].sum);
}
void pushdown(int i)
{
if (!tr[i].lz) return ;
tr[ls].sum += tr[i].lz;
tr[ls].lz += tr[i].lz;
tr[rs].sum += tr[i].lz;
tr[rs].lz += tr[i].lz;
tr[i].lz = 0;
}
void add(int i, int L, int R, int x)
{
if (R < tr[i].l || tr[i].r < L) return ;
if (L <= tr[i].l && tr[i].r <= R) {
tr[i].sum += x;
tr[i].lz += x;
return ;
}
pushdown(i);
if (tr[ls].r >= L) add(ls, L, R, x);
if (tr[rs].l <= R) add(rs, L, R, x);
tr[i].sum = max(tr[ls].sum, tr[rs].sum);
}
int query(int i, int L, int R)
{
if (R < tr[i].l || tr[i].r < L) return 0;
if (L <= tr[i].l && tr[i].r <= R) return tr[i].sum;
pushdown(i);
int res = 0;
if (tr[ls].r >= L) res = max(res, query(ls, L, R));
if (tr[rs].l <= R) res = max(res, query(rs, L, R));
return res;
}
};
#undef ls
#undef rs
}
using namespace ST;
SegTree st;
int main()
{
scanf("%d%d", &n, &k);
for (int i = 1; i <= n; ++i) scanf("%d", a + i);
/* 草稿:
// 初始化一个盒子的情况:
// for (int i = 1; i <= n; ++i) {
// dp[1][i] = c(1, i);
// }
// for (int i = 2; i <= k; ++i) { // 枚举i个盒子
// for (int j = i; j <= n; ++j) { // 枚举前j个元素
// // 考虑现有阶段由少装一个盒子的阶段得到
// // 枚举最后一个盒子装的第一个元素,注意这里最后一个盒子至少装了i-1个元素
// for (int st = i - 1; st < j; ++st) {
// // 转移方程意义:由上一个状态得到,再加上价值
// dp[i][j] = max(dp[i][j], dp[i - 1][st] + c(st + 1, j));
// }
// }
// }
*/
for (int i = 1; i <= k; ++i) {
fill(pre, pre + n + 1, 0);
st.build(1, 1, n);
for (int j = 1; j <= n; ++j) {
// 一个数的贡献影响到它上一次出现的位置到这个下标
// printf("a[%d]: %d, the last time it appeared was %d\n", j, a[j], pre[a[j]]);
st.add(1, pre[a[j]] + 1, j, 1);
pre[a[j]] = j; // 更新此值最后出现的位置
dp[j] = st.query(1, 1, j);
}
// for (int j = 1; j <= n; ++j) {
// printf("w(%d, %d)=%d\n", j, i, st.query(1, j, j));
// }
}
printf("%d\n", dp[n]);
return 0;
}

浙公网安备 33010602011771号