AtCoder Beginner Contest 214 H Collecting
好毒瘤啊……
首先考虑一个环内的 \(a_i\) 到了就一定能拿完,所以肯定是要先缩点的。设 \(s_i\) 为点 \(i\) 所在强联通分量编号(按拓扑序编号),\(b_i\) 为第 \(i\) 个分量的点权和,\(t\) 为分量个数。
考虑最大费用最大流。对每个点拆成入点和出点 \(u, u'\)。下面用 \((u, v, x, y)\) 表示 \(u \to v\),容量为 \(x\),单位流量费用为 \(y\) 的边。考虑这样建图:
-
\(\forall u \in [1, t], (u, u', 1, b_u), (u, u', +\infty, 0)\),表示一个点只有第一个走的人能获得 \(b_u\) 的权值。
-
\(\forall (u, v) \in E, (s_u', s_v, +\infty, 0)\),表示边不产生贡献。
-
\((S, s_1, K, 0)\),表示所有人都要从 \(s_1\) 出发;
-
\(\forall u \in [1, t], (u', T, +\infty, 0)\),表示任何点都能作为终点。
最后从 \(S\) 到 \(T\) 的最大流的最大费用。求最大费用最大流,可以把费用全部乘上 \(-1\),做最小费用最大流。
但是这样是 \(O(Knm)\) 的。
因为最大流很小,所以如果我们能把费用都变成非负的,就可以用 Primal-Dual 原始对偶算法,把时间复杂度优化到 \(O(Km \log n)\) 了。因此考虑换一种建图方式,与其计算能得到的权值,不如计算损失了多少权值。
考虑因为缩点之后是一个 DAG,所以走一条边 \(x \to y\) 意味着 \([x + 1, y - 1]\) 都走不到了。基于此尝试以下连边:
-
\(\forall u \in [1, t], (u, u', 1, 0), (u, u', +\infty, b_u)\)。
-
\(\forall (u, v) \in E, (s_u', s_v, +\infty, \sum\limits_{i = u + 1}^{v - 1} b_i)\)。
-
\((S, s_1, K, \sum\limits_{i = 1}^{s_1 - 1} b_i)\)。
-
\(\forall u \in [1, t], (u', T, +\infty, \sum\limits_{i = u + 1}^{t})\)。
最后答案是 \((K \times \sum\limits_{i = 1}^t b_i) - \text{mincost}\)。
这里简单讲一下 Primal-Dual 原始对偶算法。其思想是通过设置一个每个点的势能 \(h_i\),把每条边的费用 \(w\) 变成 \(w + h_u - h_v\),使得它 \(\ge 0\),这样就能跑 Dijkstra 找出增广路。每个点的初始势能是它到 \(S\) 的最短路,每次增广,\(h_u \gets h_u + d_u\),这样能保证费用非负。如果原图所有边费用都是非负的,第一次求 \(h_i\) 也可以用 Dijkstra。对于具体证明,可以看这里。
code
// Problem: H - Collecting
// Contest: AtCoder - AtCoder Beginner Contest 214
// URL: https://atcoder.jp/contests/abc214/tasks/abc214_h
// Memory Limit: 1024 MB
// Time Limit: 4000 ms
//
// Powered by CP Editor (https://cpeditor.org)
#include <bits/stdc++.h>
#define pb emplace_back
#define fst first
#define scd second
#define mems(a, x) memset((a), (x), sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef long double ldb;
typedef pair<ll, ll> pii;
const int maxn = 1000100;
const ll inf = 0x3f3f3f3f3f3f3f3fLL;
ll n, m, K, a[maxn], b[maxn], c[maxn], head[maxn], len = 1, ntot, id[maxn][2], S, T;
ll dfn[maxn], low[maxn], scc[maxn], times, tot, stk[maxn], top;
vector<int> G[maxn];
struct edge {
ll from, to, next, cap, flow, cost;
} edges[maxn * 5];
inline void add_edge(ll u, ll v, ll c, ll f, ll co) {
edges[++len].from = u;
edges[len].to = v;
edges[len].next = head[u];
edges[len].cap = c;
edges[len].flow = f;
edges[len].cost = co;
head[u] = len;
}
void dfs(int u) {
dfn[u] = low[u] = ++times;
stk[++top] = u;
for (int v : G[u]) {
if (!dfn[v]) {
dfs(v);
low[u] = min(low[u], low[v]);
} else if (!scc[v]) {
low[u] = min(low[u], dfn[v]);
}
}
if (low[u] == dfn[u]) {
++tot;
while (1) {
int x = stk[top--];
scc[x] = tot;
b[tot] += a[x];
if (x == u) {
break;
}
}
}
}
struct node {
ll u, d;
node(ll a = 0, ll b = 0) : u(a), d(b) {}
};
inline bool operator < (const node &a, const node &b) {
return a.d > b.d;
}
struct MCMF {
ll d[maxn], h[maxn], p[maxn], f[maxn];
bool vis[maxn];
inline void add(ll u, ll v, ll c, ll co) {
add_edge(u, v, c, 0, co);
add_edge(v, u, 0, 0, -co);
}
inline bool dij() {
for (int i = 1; i <= ntot; ++i) {
vis[i] = 0;
d[i] = inf;
}
priority_queue<node> pq;
pq.emplace(S, 0);
d[S] = 0;
f[S] = inf;
while (pq.size()) {
int u = pq.top().u;
pq.pop();
if (vis[u]) {
continue;
}
vis[u] = 1;
for (int i = head[u]; i; i = edges[i].next) {
edge &e = edges[i];
if (d[e.to] > d[u] + e.cost + h[u] - h[e.to] && e.cap > e.flow) {
d[e.to] = d[u] + e.cost + h[u] - h[e.to];
p[e.to] = i;
f[e.to] = min(f[u], e.cap - e.flow);
if (!vis[e.to]) {
pq.emplace(e.to, d[e.to]);
}
}
}
}
return d[T] < inf;
}
pii solve() {
ll flow = 0, cost = 0;
while (dij()) {
for (int i = 1; i <= ntot; ++i) {
h[i] = min(h[i] + d[i], inf);
}
flow += f[T];
cost += f[T] * h[T];
for (int u = T; u != S; u = edges[p[u]].from) {
edges[p[u]].flow += f[T];
edges[p[u] ^ 1].flow -= f[T];
}
}
return make_pair(flow, cost);
}
} solver;
void solve() {
scanf("%lld%lld%lld", &n, &m, &K);
while (m--) {
int u, v;
scanf("%d%d", &u, &v);
G[u].pb(v);
}
for (int i = 1; i <= n; ++i) {
scanf("%lld", &a[i]);
}
for (int i = 1; i <= n; ++i) {
if (!dfn[i]) {
dfs(i);
}
}
reverse(b + 1, b + tot + 1);
for (int i = 1; i <= n; ++i) {
scc[i] = tot - scc[i] + 1;
}
for (int i = 1; i <= tot; ++i) {
c[i] = c[i - 1] + b[i];
}
S = ++ntot;
T = ++ntot;
for (int i = 1; i <= tot; ++i) {
id[i][0] = ++ntot;
id[i][1] = ++ntot;
}
set<pii> st;
for (int u = 1; u <= n; ++u) {
for (int v : G[u]) {
int x = scc[u], y = scc[v];
if (x != y && st.find(make_pair(x, y)) == st.end()) {
st.emplace(x, y);
solver.add(id[x][1], id[y][0], inf, c[y - 1] - c[x]);
}
}
}
for (int i = 1; i <= tot; ++i) {
solver.add(id[i][0], id[i][1], 1, 0);
solver.add(id[i][0], id[i][1], inf, b[i]);
solver.add(id[i][1], T, inf, c[tot] - c[i]);
}
solver.add(S, id[scc[1]][0], K, c[scc[1] - 1]);
pii ans = solver.solve();
printf("%lld\n", c[tot] * K - ans.scd);
}
int main() {
int T = 1;
// scanf("%d", &T);
while (T--) {
solve();
}
return 0;
}

浙公网安备 33010602011771号