Loading

uoj#418.【集训队作业2018】三角形 做题记录

主要加深一下对 Exchange Argument 的理解。

\(n\) 个元素 \(x_{1\dots n}\) 进行排列,然后求排列后的最优化答案。

Exchange Argument 是指,对与相邻两个元素 \(x_1, x_2\),比较 \(F(x_1, x_2)\)\(F(x_2, x_1)\) 以决定是否交换 \(x_1, x_2\)

然后把这个当作 cmp 进行排序就行了,这样做正确当且仅当 \(\{x_{1\dots n}\}\) 满足全序关系。其中一个集合上的全序关系为下面三者的总和:

  • 反对称性:若 \(a \le b\)\(b \le a\),那么 \(a = b\)

  • 传递性:若 \(a \le b\)\(b \le c\),那么 \(a \le c\)

  • 完全性:对于任意 \(a, b\)\(a \le b\)\(b \le a\) 两者中至少一者成立。

全序又称为线性序、简单序。全序集合又称为线序集合、简单序集合、链。注意全序关系不是偏序关系。


link

首先有一个错误的贪心:对于每个点,按某种顺序依次处理其儿子子树。可以通过反例推翻,例如 \(30 - 10 - 20 - 1 (r) - 20 - 10 - 30\)

我们相当于选择一个内向树拓扑序。对于点 \(u\) 定义其权值为二元组 \((w_u - \sum\limits_{v \in son(u)} w_v, w_u)\),然后定义两个二元组 \((a_1, b_1), (a_2, b_2)\) 合并的结果为 \((a_1 + a_2, \max(b_1, a_1 + b_2))\)

问题转化为求拓扑序每个点代表的二元组依次合并后的 \(b\) 的最小值,然后对每个点的子树都求一遍答案。

但是内向拓扑序的条件难以刻划,考虑时光倒流,改为外向树拓扑序。即每个点必须在选完父亲后选,就是 01 on tree 了。

但是离 01 on tree 还差一步,如何比较两个二元组的优先级,此时需要比较 \(\max(b_1, a_1 + b_2)\)\(\max(b_2, a_2 + b_1)\) 的值。

但是这并不全序,无法进行 Exchange Argument(事实上 01 on tree 就是 Exchange Argument 上树)。此时观察式子,可以进行定义如下比较:

  • \(a_1\ge 0\)\(a_2 \ge 0\),那么 \(a_1 + b_2 \le a_2 + b_1 \Rightarrow a_1 - b_1 \le a_2 - b_2\)

  • \(a_1 < 0\)\(a_2 < 0\),那么 \(b_1 \le b_2\)

  • \([a_1 < 0] \not = [a_2 < 0]\),那么 \(a_1 < a_2\)

这样整条链为:「\(a < 0\) 的按 \(b\) 排序」$ < $ 「\(a \ge 0\) 的按 \(a - b\) 排序」,显然满足全序关系。

这样就做完了整棵树。考虑每个点的子树的问题,其顺序应该是整棵树的顺序的子序列,可以使用线段树合并维护该子序列,时间复杂度 \(\mathcal O(n\log n)\)


这题其实转化也很重要。我一开始猜了一个错误的贪心结论,修正方法并不难,即是找一个反例来推翻这个结论。对于更复杂的情况,我们需要寻找的反例也许也会更加复杂,但是若不找反例则会在一错再错,浪费更多的时间。所以,面对一个结论,首先得检查是否依照题目给定条件正确推导出来 / 猜测出来(例如 NOIP2024 T4 中我在错误的结论上思考了 1h+ 未果),其次如果结论看着比较怪的话一定要拼尽全力找反例。

然后就是时光倒流了。对于某一个条件模型,我们可以通过各种不同的小技巧将其转化为熟悉的模型,比如这里转化为外向树变成 01 on tree。


点击查看代码
#include <bits/stdc++.h>
#include<bits/extc++.h>
namespace Initial {
	#define ll long long
	#define ull unsigned ll
	#define fi first
	#define se second
	#define mkp make_pair
	#define pir pair <ll, ll>
	#define pb push_back
	#define i128 __int128
	using namespace std;
	const ll maxn = 2e5 + 10, inf = 1e18, mod = 998244353, iv = mod - mod / 2;
	ll power(ll a, ll b = mod - 2, ll p = mod) {
		ll s = 1;
		while(b) {
			if(b & 1) s = 1ll * s * a %p;
			a = 1ll * a * a %p, b >>= 1;
		} return s;
	}
	template <class T>
	const inline ll pls(const T x, const T y) { return x + y >= mod? x + y - mod : x + y; }
	template <class T>
	const inline void add(T &x, const T y) { x = x + y >= mod? x + y - mod : x + y; }
	template <class T>
	const inline void chkmax(T &x, const T y) { x = x < y? y : x; }
	template <class T>
	const inline void chkmin(T &x, const T y) { x = x < y? x : y; }
} using namespace Initial;

namespace Read {
	char buf[1 << 22], *p1, *p2;
	// #define getchar() (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, (1 << 22) - 10, stdin), p1 == p2)? EOF : *p1++)
	template <class T>
	const inline void rd(T &x) {
		char ch; bool neg = 0;
		while(!isdigit(ch = getchar()))
			if(ch == '-') neg = 1;
		x = ch - '0';
		while(isdigit(ch = getchar()))
			x = (x << 1) + (x << 3) + ch - '0';
		if(neg) x = -x;
	}
} using Read::rd;

ll n, w[maxn], d[maxn], par[maxn], ans[maxn]; vector <ll> to[maxn];
ll find(ll x) {return d[x] ^ x? d[x] = find(d[x]) : x;}

struct Data {ll a, b;} c[maxn], e[maxn], tr[maxn * 20];
const Data operator + (const Data A, const Data B) {
	return (Data) {A.a + B.a, max(A.b, A.a + B.b)};
}
const bool operator < (const Data A, const Data B) {
	if(A.a < 0 && B.a < 0) return A.b < B.b;
	if(A.a >= 0 && B.a >= 0) return A.a - A.b < B.a - B.b;
	return A.a < B.a;
} set <pair <Data, ll> > st;

ll tot, rt[maxn], lc[maxn * 20], rc[maxn * 20], id[maxn], rk[maxn];
ll tail[maxn], nxt[maxn];
void modify(ll &p, ll l, ll r, ll x) {
	if(!p) p = ++tot;
	if(l == r) return tr[p] = c[id[l]], void();
	ll mid = l + r >> 1;
	if(x <= mid) modify(lc[p], l, mid, x);
	else modify(rc[p], mid + 1, r, x);
	tr[p] = tr[lc[p]] + tr[rc[p]];
}
ll merge(ll p, ll q, ll l, ll r) {
	if(!p || !q) return p | q;
	ll mid = l + r >> 1;
	lc[p] = merge(lc[p], lc[q], l, mid);
	rc[p] = merge(rc[p], rc[q], mid + 1, r);
	tr[p] = tr[lc[p]] + tr[rc[p]]; return p;
}

void dfs(ll u) {
	for(ll v: to[u])
		dfs(v), rt[u] = merge(rt[u], rt[v], 1, n);
	modify(rt[u], 1, n, rk[u]);
	ans[u] = w[u] + tr[rt[u]].b;
}

int main() {
	rd(n), rd(n);
	for(ll i = 2; i <= n; i++)
		rd(par[i]), to[par[i]].pb(i);
	for(ll i = 1; i <= n; i++) rd(w[i]), d[i] = i;
	for(ll i = 2; i <= n; i++) c[par[i]].b += w[i];
	for(ll i = 1; i <= n; i++) c[i].a = c[i].b - w[i], e[i] = c[i];
	st.clear(); tail[1] = 1;
	for(ll i = 2; i <= n; i++) st.insert(mkp(e[i], i)), tail[i] = i;
	for(ll o = 1; o < n; o++) {
		ll u = st.begin() -> se; st.erase(st.begin());
		ll v = find(par[u]);
		if(v > 1) st.erase(mkp(e[v], v));
		e[v] = e[v] + e[u], d[u] = v;
		nxt[tail[v]] = u, tail[v] = tail[u];
		if(v > 1) st.insert(mkp(e[v], v));
	}
	for(ll x = 1, o = 1; x; o++, x = nxt[x]) id[rk[x] = o] = x;
	dfs(1);
	for(ll i = 1; i <= n; i++) printf("%lld ", ans[i]);
	return 0;
}
posted @ 2025-02-14 20:49  Sktn0089  阅读(120)  评论(0)    收藏  举报