QOJ #5421. Factories Once More 题解
Description
有一片由 \(n\) 座城市组成的王国。城市的编号从 \(1\) 到 \(n\)(含两端)且有 \((n − 1)\) 条道路连接各个城市。对于任意两座城市,居民们都可以沿着这些道路互相访问。
皇后最近决定建设 \(k\) 座新的工厂。为了防止污染,她规定每座城市最多只能建立一座工厂。
您作为皇家设计师,需要在规划建设的同时,求出两两工厂之间距离之和的最大值。
两座工厂之间的距离,即为两座工厂所在的两座城市之间的最短路径长度。路径的长度即为路径中所有边的长度之和。
\(n\leq 10^5\)。
Solution
对于每条边算贡献,假设一条边下面那个点的子树里选了 \(c\) 个点,那么贡献为 \(c(k-c)\),所以考虑对于子树树从下往上 dp。
设 \(f_{u,i}\) 表示 \(u\) 的子树里选了 \(i\) 个点,\(u\) 子树里的边(包括 \(u\) 与其父亲的边)的最大总贡献。
那么首先有转移:\(f_{u,i+j}\leftarrow f'_{u,i}+f_{v,j}\)。
得到子树后的贡献后,再加上 \(u\) 父亲上的边的贡献,即:\(f_{u,i}\leftarrow f_{u,i}+wi(k-i)\)。
但是这里有个 \((\max,+)\) 卷积,不能暴力转移。注意到 \(wi(k-i)\) 是下凸的,其差分为 \(w(k-2i+1)\),所以 dp 数组也是下凸的。用平衡树维护差分数组,需要支持启发式合并以及加等差数列,直接打标记即可。
时间复杂度:\(O(n\log^2n)\)。
Code
#include <bits/stdc++.h>
#define int int64_t
const int kMaxN = 1e5 + 5;
int n, k;
int rt[kMaxN];
std::vector<std::pair<int, int>> G[kMaxN];
std::mt19937 rnd(114514);
struct FHQTreap {
int tot, ls[kMaxN], rs[kMaxN], sz[kMaxN], val[kMaxN], rd[kMaxN], tag1[kMaxN], tag2[kMaxN];
int newnode(int v) {
sz[++tot] = 1, val[tot] = v, ls[tot] = rs[tot] = tag1[tot] = tag2[tot] = 0, rd[tot] = rnd();
return tot;
}
void pushup(int x) {
sz[x] = sz[ls[x]] + sz[rs[x]] + 1;
}
void addtag1(int x, int v) { val[x] += v, tag1[x] += v; }
void addtag2(int x, int v) {
val[x] += v * (sz[ls[x]] + 1), tag2[x] += v;
}
void pushdown(int x) {
if (tag1[x]) {
if (ls[x]) addtag1(ls[x], tag1[x]);
if (rs[x]) addtag1(rs[x], tag1[x]);
tag1[x] = 0;
}
if (tag2[x]) {
if (ls[x]) addtag2(ls[x], tag2[x]);
if (rs[x]) addtag1(rs[x], tag2[x] * (sz[ls[x]] + 1)), addtag2(rs[x], tag2[x]);
tag2[x] = 0;
}
}
int merge(int x, int y) {
// std::cerr << "??? " << x << ' ' << y << ' ' << ls[x] << ' ' << ls[y] << ' ' << rs[x] << ' ' << rs[y] << '\n';
if (!x || !y) return x + y;
pushdown(x), pushdown(y);
if (rd[x] < rd[y]) {
rs[x] = merge(rs[x], y), pushup(x);
return x;
} else {
ls[y] = merge(x, ls[y]), pushup(y);
return y;
}
}
void split(int x, int v, int &a, int &b) {
if (!x) return void(a = b = 0);
pushdown(x);
if (val[x] >= v) {
a = x, split(rs[x], v, rs[x], b);
pushup(a);
} else {
b = x, split(ls[x], v, a, ls[x]);
pushup(b);
}
}
void ins(int &rt, int x) {
int a, b;
// std::cerr << "??? " << rt << ' ' << x << '\n';
split(rt, val[x], a, b);
rt = merge(a, merge(x, b));
}
void update(int rt, int v1, int v2) {
// std::cerr << "fuck " << rt << ' ' << v1 << ' ' << v2 << '\n';
addtag1(rt, v1), addtag2(rt, v2);
}
void insall(int &x, int y) {
if (sz[x] < sz[y]) std::swap(x, y);
// std::cerr << sz[x] << ' ' << sz[y] << ' ';
std::vector<int> id;
std::function<void(int)> dfs = [&] (int x) {
if (!x) return;
pushdown(x);
if (ls[x]) dfs(ls[x]);
if (rs[x]) dfs(rs[x]);
ls[x] = rs[x] = tag1[x] = tag2[x] = 0, sz[x] = 1;
id.emplace_back(x);
};
dfs(y);
for (auto i : id) ins(x, i);
// std::cerr << id.size() << ' ' << sz[x] << '\n';
}
void print(int x) {
if (!x) return;
pushdown(x);
if (ls[x]) print(ls[x]);
std::cerr << val[x] << ' ';
if (rs[x]) print(rs[x]);
}
int getsum(int x, int k) {
if (!x) return 0;
// std::cerr << "??? " << x << ' ' << val[x] << ' ' << k << ' ' << sz[ls[x]] << '\n';
pushdown(x);
assert(sz[x] >= k);
if (k <= sz[ls[x]]) return getsum(ls[x], k);
else if (k <= sz[ls[x]] + 1) return getsum(ls[x], sz[ls[x]]) + val[x];
else return getsum(ls[x], sz[ls[x]]) + val[x] + getsum(rs[x], k - sz[ls[x]] - 1);
}
} t;
void dfs(int u, int fa, int faw) {
rt[u] = t.newnode(0);
// std::cerr << t.sz[rt[u]] << '\n';
for (auto [v, w] : G[u]) {
if (v == fa) continue;
dfs(v, u, w);
t.insall(rt[u], rt[v]);
}
t.update(rt[u], faw * (k + 1), -2 * faw);
// t.print(rt[u]), std::cerr << '\n';
}
void dickdreamer() {
std::cin >> n >> k;
for (int i = 1; i < n; ++i) {
int u, v, w;
std::cin >> u >> v >> w;
G[u].emplace_back(v, w), G[v].emplace_back(u, w);
}
dfs(1, 0, 0);
// t.print(rt[1]), std::cerr << '\n';
std::cout << t.getsum(rt[1], k) << '\n';
// std::cerr << t.sz[rt[1]] << '\n';
}
int32_t main() {
#ifdef ORZXKR
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#endif
std::ios::sync_with_stdio(0), std::cin.tie(0), std::cout.tie(0);
int T = 1;
// std::cin >> T;
while (T--) dickdreamer();
// std::cerr << 1.0 * clock() / CLOCKS_PER_SEC << "s\n";
return 0;
}