Three Servers 题解
Three Servers 题解
首先,设 \(f_{i,j,k}\) 表示考虑前 \(i\) 个请求,有两个服务器的时间为 \(j\) 和 \(k\) 的可达性。
此时空间复杂度为 \(O(n^3t^2)\)。无法通过此题。
优化一
发现 \(f\) 为 \(0/1\) 数组,考虑使用 bitset
进行优化。\(f\) 的转移也有规律。
前两个转移可以直接使用 bitset
的或,而最后一个转移可以先左移 \(t_{i + 1}\) 位,再或。代码如下:
f[i + 1][j] |= (f[i][j]); // 第一个转移
f[i + 1][j + t[i + 1]] |= (f[i][j]); //第二个转移
f[i + 1][j] |= (f[i][j] << t[i + 1]); //第三个转移
优化二
我们可以猜想极差不会超过 \(30\)。
下面给出简单证明。
设最终三个服务器的时间为 \(a, b, c\) (\(a \leq b \leq c\) 且 \(c-a>30\) )。那么现在我们可以将 \(c\) 中一个时间为 \(x\) 的任务移到 \(a\) 中。目前的三个数为 \(a+x,b,c-x\),重新排列后可以发现,三个数会越来越接近,即极差有可能减小。我们继续排列这三个数。可以发现,最后的极差一定会小于 \(30\)。
这对于优化空间有什么用呢?
我们考虑最坏情况,最后的三个服务器为 \(x, x + 30, x + 30\),那么 \(x\) 就是 \(O(\frac{nt}{3})\) 量级的了。
现在计算一下总空间复杂度为 \(O(\frac{n^3t^2}{9\omega})\)。仍然不足以通过此题。
目前转移代码如下:
f[i + 1][j] |= (f[i][j]); // 第一个转移
if (j + t[i + 1] < M) // 判断是否超过 nt/3
f[i + 1][j + t[i + 1]] |= (f[i][j]); //第二个转移
f[i + 1][j] |= (f[i][j] << t[i + 1]); //第三个转移
优化三
有人会说,我们不能滚动数组吗?
但是我们需要求方案,所以不能完全滚。
但是,我们发现时间十分宽裕,而空间不够,考虑使用时间换空间。我们可以滚动数组,每 \(B\) 就滚动一次,这样我们就可以省下一部分空间,而这样我们将 \(n\) 个询问分成了很多个段,我们可以从 \(0\) 跑到每一个段,然后再求方案。这里 \(B = \min(\left \lceil \frac{n}{4} \right \rceil ,2)\)。
空间多了一个 \(\frac{1}{4}\) 的常数,可过。
代码实现
Dp 部分
将 Dp
部分写成函数,能大幅缩短代码量。下面的 dp(m)
就是求 \(0\) 到 \(m\) 的Dp
值。
void dp(int m) { //
for (int i = 0; i < M; ++i)
f[0][i].reset();
f[0][0][0] = true; // 初始化
for (int i = 0; i < m; ++i) {
for (int j = 0; j < M; ++j) {
f[(i + 1) % Md][j].reset(); // 清空滚动数组
}
for (int j = 0; j < M; ++j) { // 转移
f[(i + 1) % Md][j] |= (f[i % Md][j]);
if (j + t[i + 1] < M) f[(i + 1) % Md][j + t[i + 1]] |= (f[i % Md][j]);
f[(i + 1) % Md][j] |= (f[i % Md][j] << t[i + 1]);
}
}
}
主函数求答案部分
记录总和,求 Dp
值,然后枚举每种可能结果,求 \(\min\) 即可。
for (int i = 1; i <= n; ++i) {
cin >> t[i];
sum += t[i]; // 记录总和
}
Md = max(int(ceil(n / 4.0)), 2);
int ret = INT_MAX, ri, rj; // ri, rj 为在 ret 最优情况下两个服务器的值
dp(n); // Dp
for (int i = 0; i < M; ++i)
for (int j = 0; j < M; ++j)
if (f[n % Md][i][j]) { // 如果当前状态可行
int lst = sum - i - j; // 找到除了状态中,另外一个服务器的值
int w = max({lst, i, j}) - min({lst, i, j}); // 求极差
if (ret > w) {
ret = w; // 更新
ri = i, rj = j;
}
}
主函数求方案部分
首先将 \(n\) 划分成若干个段 (不超过 \(4\) 个),然后倒推求方案。
考虑对于 \(f_{i+1,ri, rj}\) 可以推回哪些状态。
反写转移式,可以发现,\(f_{i+1,ri,rj}\) 能推到 \(f_{i,ri,rj}\)、\(f_{i,ri-t_{i+1},rj}\) 和 \(f_{i,ri, rj-t_{i+1}}\) 的任意一个,然后继续倒推即可。
pos[0] = 0;
for (ncnt = 1; ; ++ncnt) { // 分段
pos[ncnt] = pos[ncnt - 1] + Md;
if (pos[ncnt] > n) {
pos[ncnt] = n;
break;
}
}
for (int i = ncnt; i; --i) {
int tl = pos[i - 1], tr = min(pos[i] - 1, n); // 找到当前区间
if (tl > tr) continue;
if (i != ncnt)
dp(tr); // 如果是最后一段,不必重新 Dp,直接计算即可。
for (int j = tr; j >= tl; --j) { // 倒推
if (f[j % Md][ri][rj]) {
ans[2].push_back(j + 1);
continue;
}
if (ri >= t[j + 1] && f[j % Md][ri - t[j + 1]][rj]) {
ri -= t[j + 1];
ans[0].push_back(j + 1);
continue;
}
ans[1].push_back(j + 1);
rj -= t[j + 1];
}
}
总代码
#include <bits/stdc++.h>
#define FASTIO ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
using namespace std;
using ll = long long;
using pii = pair<int, int>;
const int N = 105;
const int M = 4005;
int t[N << 2], n, Md, sum[N << 2];
bitset<M> f[N][M];
void dp(int m) {
for (int i = 0; i < M; ++i)
f[0][i].reset();
f[0][0][0] = true;
for (int i = 0; i < m; ++i) {
for (int j = 0; j < M; ++j) {
f[(i + 1) % Md][j].reset();
}
for (int j = 0; j < M; ++j) {
f[(i + 1) % Md][j] |= (f[i % Md][j]);
if (j + t[i + 1] < M) f[(i + 1) % Md][j + t[i + 1]] |= (f[i % Md][j]);
f[(i + 1) % Md][j] |= (f[i % Md][j] << t[i + 1]);
}
}
}
vector<int> ans[3];
int pos[5], ncnt;
void print(int i) {
cout << ans[i].size() << ' ';
for (auto x : ans[i])
cout << x << ' ';
cout << '\n';
}
int main() {
// FASTIO;
cin >> n;
for (int i = 1; i <= n; ++i) {
cin >> t[i];
sum[i] = sum[i - 1] + t[i];
}
Md = max(int(ceil(n / 4.0)), 2);
int ret = INT_MAX, ri, rj;
dp(n);
for (int i = 0; i < M; ++i)
for (int j = 0; j < M; ++j)
if (f[n % Md][i][j]) {
int lst = sum[n] - i - j;
int w = max({lst, i, j}) - min({lst, i, j});
if (ret > w) {
ret = w;
ri = i, rj = j;
}
}
cout << ret << '\n';
pos[0] = 0;
for (ncnt = 1; ; ++ncnt) {
pos[ncnt] = pos[ncnt - 1] + Md;
if (pos[ncnt] > n) {
pos[ncnt] = n;
break;
}
}
for (int i = ncnt; i; --i) {
int tl = pos[i - 1], tr = min(pos[i] - 1, n);
if (tl > tr) continue;
if (i != ncnt)
dp(tr);
for (int j = tr; j >= tl; --j) {
if (f[j % Md][ri][rj]) {
ans[2].push_back(j + 1);
continue;
}
if (ri >= t[j + 1] && f[j % Md][ri - t[j + 1]][rj]) {
ri -= t[j + 1];
ans[0].push_back(j + 1);
continue;
}
ans[1].push_back(j + 1);
rj -= t[j + 1];
}
}
print(0);
print(1);
print(2);
return 0;
}