AT_abc432_f [ABC432F] Candy Redistribution 题解
思路
好题。
无解很好判。如果 \(\sum a_i\) 为 \(n\) 的倍数则一定有解,否则一定无解。令 \(t = \displaystyle \frac{\sum a_i}{n}\),则最终需要将每一个 \(a_i\) 都变为 \(t\)。
对于一个可行解 \((x_1, y_1, z_1), (x_2, y_2, z_2), \cdots, (x_k, y_k, z_k)\),若以 \((x_1, y_1), (x_2, y_2), \cdots, (x_k, y_k)\) 作为一张 \(n\) 个点的无向图的边,则对于每一个连通分量 \(V\),必须满足 \(\sum \limits_{u \in V} a_u = t \times |V|\),其中 \(|V|\) 为连通分量 \(V\) 中的点数。
定义 合法的 代表满足 \(\sum \limits_{u \in V} a_u = t \times |V|\) 的集合。
对于每一个合法的连通分量 \(V\) 进行考虑。设 \(v_1, \cdots, v_k\) 代表 \(V\) 中的所有元素按降序排列后的结果,则我们可以构造出一种最优方案:对于每一个 \(1 \le i \le k\) 的 \(i\),\(i\) 多了多少颗糖果,就把它给 \(i + 1\)。因为满足 \(\sum \limits_{u \in V} a_u = t \times |V|\) 的条件,所以当 \(i = n\) 时,多的糖果数一定为 \(0\)。此时,这个连通分量需要 \(|V| - 1\) 此操作。
回到原问题,其实就是把 \(1 \sim n\) 划分为若干合法子集,使得 \(V_1 \cup V_2 \cup \cdots \cup V_k = \{1, 2, \cdots, n\}\) 且 \(V_1 \cap V_2 \cap \cdots \cap V_k = \varnothing\),然后对于每一个 \(V\) 独立进行上述操作。那我们会发现这种情况下的答案为 \(ans = \sum \limits_{i=1}^k (|V_i| - 1) = \sum \limits_{i=1}^k |V_i| - k = n - k\)。所以最小化操作数 \(ans\) 即为最大化合法子集数。
这个问题可以用状压 DP 解决。设 \(dp_S\) 代表将集合 \(S\) 通过划分成合法子集的最大数量。考虑最后一个放入 \(S\) 的数 \(u\),若 \(S\) 本身为合法子集,则不需要为 \(u\) 多划分一个,否则就需要。预处理 \(f_S\) 代表 \(S\) 是否为合法子集,则:
最后如何将方案输出呢?在每次更新 \(dp\) 时存他的前继,最后倒过来向前找即可。
代码
#include <bits/stdc++.h>
using namespace std;
const int N = 25, M = (1 << 20) + 1;
struct node
{
int x, y, z;
};
int n;
int a[N], dp[M], lst[M];
bool st[M];
vector<node> ans;
int main()
{
memset(dp, -0x3f, sizeof(dp));
scanf("%d", &n);
int sum = 0;
for (int i = 0; i < n; i++)
scanf("%d", &a[i]), sum += a[i];
if (sum % n != 0)
{
printf("-1\n");
return 0;
}
sum /= n;
for (int j = 0; j < (1 << n); j++)
{
int res = 0;
for (int i = 0; i < n; i++)
if ((j >> i) & 1) res += (a[i] - sum);
if (res == 0) st[j] = 1;
}
dp[0] = 0;
for (int j = 1; j < (1 << n); j++)
{
int now = st[j];
for (int i = 0; i < n; i++)
{
if ((j >> i) & 1)
{
int tmp = dp[j - (1 << i)] + now;
if (dp[j] < tmp)
{
dp[j] = tmp;
lst[j] = i;
}
}
}
}
int S = (1 << n) - 1;
vector<pair<int, int> > tmp;
while (S > 0)
{
int t = lst[S];
tmp.push_back({a[t], t});
S -= (1 << t);
if (st[S])
{
sort(tmp.begin(), tmp.end());
int res = tmp[tmp.size() - 1].first - sum;
for (int i = tmp.size() - 1; i; i--)
{
int nx = tmp[i].second, ny = tmp[i - 1].second;
ans.push_back({nx + 1, ny + 1, res});
res += tmp[i - 1].first - sum;
}
tmp.clear();
}
}
printf("%d\n", n - dp[(1 << n) - 1]);
for (node i : ans) printf("%d %d %d\n", i.x, i.y, i.z);
return 0;
}

浙公网安备 33010602011771号