uoj#418.【集训队作业2018】三角形 做题记录
主要加深一下对 Exchange Argument 的理解。
对 \(n\) 个元素 \(x_{1\dots n}\) 进行排列,然后求排列后的最优化答案。
Exchange Argument 是指,对与相邻两个元素 \(x_1, x_2\),比较 \(F(x_1, x_2)\) 和 \(F(x_2, x_1)\) 以决定是否交换 \(x_1, x_2\)。
然后把这个当作 cmp 进行排序就行了,这样做正确当且仅当 \(\{x_{1\dots n}\}\) 满足全序关系。其中一个集合上的全序关系为下面三者的总和:
-
反对称性:若 \(a \le b\) 且 \(b \le a\),那么 \(a = b\)。
-
传递性:若 \(a \le b\) 且 \(b \le c\),那么 \(a \le c\)。
-
完全性:对于任意 \(a, b\),\(a \le b\) 和 \(b \le a\) 两者中至少一者成立。
全序又称为线性序、简单序。全序集合又称为线序集合、简单序集合、链。注意全序关系不是偏序关系。
首先有一个错误的贪心:对于每个点,按某种顺序依次处理其儿子子树。可以通过反例推翻,例如 \(30 - 10 - 20 - 1 (r) - 20 - 10 - 30\)。
我们相当于选择一个内向树拓扑序。对于点 \(u\) 定义其权值为二元组 \((w_u - \sum\limits_{v \in son(u)} w_v, w_u)\),然后定义两个二元组 \((a_1, b_1), (a_2, b_2)\) 合并的结果为 \((a_1 + a_2, \max(b_1, a_1 + b_2))\)。
问题转化为求拓扑序每个点代表的二元组依次合并后的 \(b\) 的最小值,然后对每个点的子树都求一遍答案。
但是内向拓扑序的条件难以刻划,考虑时光倒流,改为外向树拓扑序。即每个点必须在选完父亲后选,就是 01 on tree 了。
但是离 01 on tree 还差一步,如何比较两个二元组的优先级,此时需要比较 \(\max(b_1, a_1 + b_2)\) 和 \(\max(b_2, a_2 + b_1)\) 的值。
但是这并不全序,无法进行 Exchange Argument(事实上 01 on tree 就是 Exchange Argument 上树)。此时观察式子,可以进行定义如下比较:
-
若 \(a_1\ge 0\) 且 \(a_2 \ge 0\),那么 \(a_1 + b_2 \le a_2 + b_1 \Rightarrow a_1 - b_1 \le a_2 - b_2\)
-
若 \(a_1 < 0\) 且 \(a_2 < 0\),那么 \(b_1 \le b_2\)。
-
若 \([a_1 < 0] \not = [a_2 < 0]\),那么 \(a_1 < a_2\)。
这样整条链为:「\(a < 0\) 的按 \(b\) 排序」$ < $ 「\(a \ge 0\) 的按 \(a - b\) 排序」,显然满足全序关系。
这样就做完了整棵树。考虑每个点的子树的问题,其顺序应该是整棵树的顺序的子序列,可以使用线段树合并维护该子序列,时间复杂度 \(\mathcal O(n\log n)\)。
这题其实转化也很重要。我一开始猜了一个错误的贪心结论,修正方法并不难,即是找一个反例来推翻这个结论。对于更复杂的情况,我们需要寻找的反例也许也会更加复杂,但是若不找反例则会在一错再错,浪费更多的时间。所以,面对一个结论,首先得检查是否依照题目给定条件正确推导出来 / 猜测出来(例如 NOIP2024 T4 中我在错误的结论上思考了 1h+ 未果),其次如果结论看着比较怪的话一定要拼尽全力找反例。
然后就是时光倒流了。对于某一个条件模型,我们可以通过各种不同的小技巧将其转化为熟悉的模型,比如这里转化为外向树变成 01 on tree。
点击查看代码
#include <bits/stdc++.h>
#include<bits/extc++.h>
namespace Initial {
#define ll long long
#define ull unsigned ll
#define fi first
#define se second
#define mkp make_pair
#define pir pair <ll, ll>
#define pb push_back
#define i128 __int128
using namespace std;
const ll maxn = 2e5 + 10, inf = 1e18, mod = 998244353, iv = mod - mod / 2;
ll power(ll a, ll b = mod - 2, ll p = mod) {
ll s = 1;
while(b) {
if(b & 1) s = 1ll * s * a %p;
a = 1ll * a * a %p, b >>= 1;
} return s;
}
template <class T>
const inline ll pls(const T x, const T y) { return x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void add(T &x, const T y) { x = x + y >= mod? x + y - mod : x + y; }
template <class T>
const inline void chkmax(T &x, const T y) { x = x < y? y : x; }
template <class T>
const inline void chkmin(T &x, const T y) { x = x < y? x : y; }
} using namespace Initial;
namespace Read {
char buf[1 << 22], *p1, *p2;
// #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
template <class T>
const inline void rd(T &x) {
char ch; bool neg = 0;
while(!isdigit(ch = getchar()))
if(ch == '-') neg = 1;
x = ch - '0';
while(isdigit(ch = getchar()))
x = (x << 1) + (x << 3) + ch - '0';
if(neg) x = -x;
}
} using Read::rd;
ll n, w[maxn], d[maxn], par[maxn], ans[maxn]; vector <ll> to[maxn];
ll find(ll x) {return d[x] ^ x? d[x] = find(d[x]) : x;}
struct Data {ll a, b;} c[maxn], e[maxn], tr[maxn * 20];
const Data operator + (const Data A, const Data B) {
return (Data) {A.a + B.a, max(A.b, A.a + B.b)};
}
const bool operator < (const Data A, const Data B) {
if(A.a < 0 && B.a < 0) return A.b < B.b;
if(A.a >= 0 && B.a >= 0) return A.a - A.b < B.a - B.b;
return A.a < B.a;
} set <pair <Data, ll> > st;
ll tot, rt[maxn], lc[maxn * 20], rc[maxn * 20], id[maxn], rk[maxn];
ll tail[maxn], nxt[maxn];
void modify(ll &p, ll l, ll r, ll x) {
if(!p) p = ++tot;
if(l == r) return tr[p] = c[id[l]], void();
ll mid = l + r >> 1;
if(x <= mid) modify(lc[p], l, mid, x);
else modify(rc[p], mid + 1, r, x);
tr[p] = tr[lc[p]] + tr[rc[p]];
}
ll merge(ll p, ll q, ll l, ll r) {
if(!p || !q) return p | q;
ll mid = l + r >> 1;
lc[p] = merge(lc[p], lc[q], l, mid);
rc[p] = merge(rc[p], rc[q], mid + 1, r);
tr[p] = tr[lc[p]] + tr[rc[p]]; return p;
}
void dfs(ll u) {
for(ll v: to[u])
dfs(v), rt[u] = merge(rt[u], rt[v], 1, n);
modify(rt[u], 1, n, rk[u]);
ans[u] = w[u] + tr[rt[u]].b;
}
int main() {
rd(n), rd(n);
for(ll i = 2; i <= n; i++)
rd(par[i]), to[par[i]].pb(i);
for(ll i = 1; i <= n; i++) rd(w[i]), d[i] = i;
for(ll i = 2; i <= n; i++) c[par[i]].b += w[i];
for(ll i = 1; i <= n; i++) c[i].a = c[i].b - w[i], e[i] = c[i];
st.clear(); tail[1] = 1;
for(ll i = 2; i <= n; i++) st.insert(mkp(e[i], i)), tail[i] = i;
for(ll o = 1; o < n; o++) {
ll u = st.begin() -> se; st.erase(st.begin());
ll v = find(par[u]);
if(v > 1) st.erase(mkp(e[v], v));
e[v] = e[v] + e[u], d[u] = v;
nxt[tail[v]] = u, tail[v] = tail[u];
if(v > 1) st.insert(mkp(e[v], v));
}
for(ll x = 1, o = 1; x; o++, x = nxt[x]) id[rk[x] = o] = x;
dfs(1);
for(ll i = 1; i <= n; i++) printf("%lld ", ans[i]);
return 0;
}

浙公网安备 33010602011771号