黑白树

《黑白树》解题报告

题目描述

维护一棵树,他有 \(N\) 个节点,然后每个节点可能有两种颜色,黑色或白色。

你需要支持五种操作:

  • 翻转 \(x\) 点的颜色
  • \(x\) 点的同色连通块的点权加上 \(v\)
  • 将一个子树内的点的点权加上 \(v\)
  • 将一条 \(x\)\(y\) 的链的点的点权加上 \(v\)
  • 查询 \(x\) 同色连通块的点权最大值。

对于全部数据 \(n\le 2\times 10^5\),保证答案不超过 int

题解

咋整捏可以考虑,对于每个连通块,维护一个深度最小的点,称为这个连通块的管辖点。

然后判断一个点和他的某个祖先是不是一个连通块的点,可以通过它到根的异色点数量来判断。

我们考虑线段树维护,那么,我们现在其实就是不会修改 同色连通块这个条件。非常难搞。

我们考虑,如果我们对于一个线段树区间,维护这个区间的 \(lca\) 的到根节点的路径上两个颜色节点的数量。

然后考虑,修改某个连通块的时候,直接修改连通块的管辖点的子树区间,然后我们这个子树区间其实不是全都改的,只改同色的且在一个连通块的区间。

即,我们对于一个线段树节点,记录 \(cnt[2], cnttag[2], mx[2], mxtag[2], tag\),分别表示 \(lca\) 到根节点的每个颜色的点的个数,颜色点数量的标记,这个区间内和 \(lca\) 同一连通块的最大值,最大值标记,还有区间加标记。

然后考虑下传标记的时候,只有儿子内的 \(cnt[col]\) 和 我当前节点的 \(cnt[col]\) 一样才可以下传 \(mxtag[col \oplus 1]\) 的。

同理,Pushup的时候也是类似的。我们一个线段树区间 \([l,r]\) 只记录和区间 \(lca\) 相连通的块的答案。

然后查询的时候也是和这个连通块相通的区间才能计入答案就好了。

具体还是看代码实现吧。

#include <bits/stdc++.h>
using std::cin;
using std::cout;
using std::vector;
using std::copy;
using std::reverse;
using std::sort;
using std::get;
using std::unique;
using std::swap;
using std::array;
using std::cerr;
using std::function;
using std::map;
using std::set;
using std::pair;
using std::mt19937;
using std::make_pair;
using std::tuple;
using std::make_tuple;
using std::uniform_int_distribution;
using ll = long long;
namespace qwq {
	mt19937 eng;
	void init(int Seed) { return eng.seed(Seed); }
	int rnd(int l = 1, int r = 1000000000) { return uniform_int_distribution<int> (l, r)(eng); }
}
template <typename T>
inline T min(const T &x, const T &y) { return x < y ? x : y; }
template<typename T>
inline T max(const T &x, const T &y) { return x > y ? x : y; }
template <typename T>
inline void read(T &x) {
	x = 0;
	bool f = 0;
	char ch = getchar();
	while (!isdigit(ch)) f = ch == '-', ch = getchar();
	while (isdigit(ch)) x = x * 10 + ch - '0', ch = getchar();
	if (f) x = -x;
}
template<typename T, typename ...Arg>
inline void read(T &x, Arg &... y) {
	read(x);
	read(y...);
}
#define O(x) cerr << #x << " : " << x << '\n'
const double Pi = acos(-1);
const int MAXN = 262144, MOD = 998244353, inv2 = (MOD + 1) / 2, I32_INF = 0x3f3f3f3f;
const long long I64_INF = 0x3f3f3f3f3f3f3f3f;
auto Ksm = [] (int x, int y) -> int {
	if (y < 0) {
		y %= MOD - 1;
		y += MOD - 1;
	}
	int ret = 1;
	for (; y; y /= 2, x = (long long) x * x % MOD) if (y & 1) ret = (long long) ret * x % MOD;
	return ret;
};
auto Mod = [] (int x) -> int {
	if (x >= MOD) return x - MOD;
	else if (x < 0) return x + MOD;
	else return x;
};
inline int ls(int k) { return k << 1; }
inline int rs(int k) { return k << 1 | 1; }
struct Segment_Tree {
	struct Node {
		int cnt[2], cnttag[2], mx[2], mxtag[2], tag;
	} nd[MAXN * 4];
	int SZ;
	void build(int _N) { SZ = _N; }
	inline void add_s(int u, int op, int v) {
		nd[u].cnt[op] += v;
		nd[u].cnttag[op] += v;
	}
	inline void add_mx(int u, int op, int v) {
		nd[u].mx[op] += v;
		nd[u].mxtag[op] += v;
	}
	inline void add_all(int u, int v) {
		nd[u].mx[0] += v;
		nd[u].mx[1] += v;
		nd[u].tag += v;
	}
	inline void pushup(int u) {
		nd[u].mx[0] = max((nd[ls(u)].cnt[1] == nd[u].cnt[1]) ? nd[ls(u)].mx[0] : -I32_INF, (nd[rs(u)].cnt[1] == nd[u].cnt[1]) ? nd[rs(u)].mx[0] : -I32_INF);
		nd[u].mx[1] = max((nd[ls(u)].cnt[0] == nd[u].cnt[0]) ? nd[ls(u)].mx[1] : -I32_INF, (nd[rs(u)].cnt[0] == nd[u].cnt[0]) ? nd[rs(u)].mx[1] : -I32_INF);
	}
	void pushdown(int u) {
		for (int o = 0; o < 2; ++o) {
			add_s(ls(u), o, nd[u].cnttag[o]);
			add_s(rs(u), o, nd[u].cnttag[o]);
			nd[u].cnttag[o] = 0;
			if (nd[u].cnt[o] == nd[ls(u)].cnt[o]) add_mx(ls(u), o ^ 1, nd[u].mxtag[o ^ 1]);
			if (nd[u].cnt[o] == nd[rs(u)].cnt[o]) add_mx(rs(u), o ^ 1, nd[u].mxtag[o ^ 1]);
			nd[u].mxtag[o ^ 1] = 0;
		}
		add_all(ls(u), nd[u].tag);
		add_all(rs(u), nd[u].tag);
		nd[u].tag = 0;
	}
	// 改变 u 点的颜色
	void rev(int u) {
		function<void(int, int, int)> dfs = [&] (int k, int l, int r) -> void {
			if (l == r) {
				return swap(nd[k].mx[0], nd[k].mx[1]);
			}
			int mid = (l + r) / 2;
			pushdown(k);
			u <= mid ? dfs(ls(k), l, mid) : dfs(rs(k), mid + 1, r);
			pushup(k);
		};
		return dfs(1, 1, SZ);
	}
	void mfy(int ql, int qr, int v) {
		function<void(int, int, int)> dfs = [&] (int k, int l, int r) -> void {
			if (ql <= l && r <= qr) return add_all(k, v);
			pushdown(k);
			int mid = (l + r) / 2;
			if (ql <= mid) dfs(ls(k), l, mid);
			if (mid < qr) dfs(rs(k), mid + 1, r);
			pushup(k);
		};
		return dfs(1, 1, SZ);
	}
	void mfys(int qcol, int ql, int qr, int v) {
		function<void(int, int, int)> dfs = [&] (int k, int l, int r) -> void {
			if (ql <= l && r <= qr) return add_s(k, qcol, v);
			pushdown(k);
			int mid = (l + r) / 2;
			if (ql <= mid) dfs(ls(k), l, mid);
			if (mid < qr) dfs(rs(k), mid + 1, r);
			pushup(k);
		};
		return dfs(1, 1, SZ);
	}
	void mfymx(int qcol, int ql, int qr, int v, int qm) {
		function<void(int, int, int)> dfs = [&] (int k, int l, int r) -> void {
			if (nd[k].cnt[qcol ^ 1] > qm) return;
			if (ql <= l && r <= qr) {
				return add_mx(k, qcol, v);
			}
			pushdown(k);
			int mid = (l + r) / 2;
			if (ql <= mid) dfs(ls(k), l, mid);
			if (mid < qr) dfs(rs(k), mid + 1, r);
			pushup(k);
		};
		return dfs(1, 1, SZ);
	}
	int qrymx(int qcol, int ql, int qr, int qm) {
		function<int(int, int, int)> dfs = [&] (int k, int l, int r) -> int {
			if (nd[k].cnt[qcol ^ 1] > qm) return -I32_INF;
			if (ql <= l && r <= qr) {
				return nd[k].mx[qcol];
			}
			pushdown(k);
			int ret = -I32_INF, mid = (l + r) / 2;
			if (ql <= mid) ret = max(ret, dfs(ls(k), l, mid));
			if (mid < qr) ret = max(ret, dfs(rs(k), mid + 1, r));
			return ret;
		};
		return dfs(1, 1, SZ);
	}
	int qrys(int qcol, int qpos) {
		function<int(int, int, int)> dfs = [&] (int k, int l, int r) -> int {
			if (l == r) return nd[k].cnt[qcol];
			pushdown(k);
			int mid = (l + r) / 2;
			return qpos <= mid ? dfs(ls(k), l, mid) : dfs(rs(k), mid + 1, r);
		};
		return dfs(1, 1, SZ);
	}
} cyc;
vector<int> t[MAXN];
set<pair<int, int>> cha[2][MAXN];
int N, Q, son[MAXN], dep[MAXN], top[MAXN], sz[MAXN], fa[MAXN], clk, dfn[MAXN], ed[MAXN], col[MAXN], val[MAXN];
int find(int x) {
	pair<int, int> res(dep[x], x);
	for (int c = !col[x]; x; x = fa[top[x]]) {
		auto tmp = cha[c][top[x]].upper_bound({dep[x], I32_INF});
		if (tmp != cha[c][top[x]].begin()) {
			--tmp;
			if (1 + tmp->first < res.first) res = {1 + tmp->first, son[tmp->second]};
			break;
		}
		res = {dep[top[x]], top[x]};
	}
	return res.second;
}
int main() {
	freopen("astil.in", "r", stdin);
	freopen("astil.out", "w", stdout);
	qwq::init(20050112);
	read(N, Q);
	for (int i = 1, x, y; i < N; ++i) {
		read(x, y);
		t[x].push_back(y);
		t[y].push_back(x);
	}
	for (int i = 1; i <= N; ++i) read(col[i]);
	for (int i = 1; i <= N; ++i) read(val[i]);
	function<void(int, int)> dfs1 = [&] (int u, int lst) -> void {
		dep[u] = dep[fa[u] = lst] + (sz[u] = 1);
		for (auto &i: t[u]) {
			if (i != lst) {
				dfs1(i, u);
				sz[u] += sz[i];
				if (sz[son[u]] < sz[i]) son[u] = i;
			}
		}
	};
	dfs1(1, 0);
	function<void(int, int)> dfs2 = [&] (int u, int tp) -> void {
		top[u] = tp;
		cha[col[u]][top[u]].insert({dep[u], u});
		dfn[u] = ++clk;
		if (son[u]) dfs2(son[u], tp);
		for (auto &i: t[u]) if (i != fa[u] && i != son[u]) dfs2(i, i);
		ed[u] = clk;
	};
	dfs2(1, 1);
	cyc.build(N);
	// cout << (-3 / 2);
	for (int i = 1; i <= N; ++i) cyc.mfys(col[i], dfn[i], ed[i], 1);
	for (int i = 1; i <= N; ++i) cyc.mfy(dfn[i], dfn[i], val[i]);
	// for (int i = 1;i <= N; ++i) {
	// 	if (dfn[i] == 870) {
	// 		cout << i << '\n';
	// 	}
	// }
	for (int opt, x, y, z, cnt = 0; Q--; ) {
		read(opt, x);
		if (opt == 1) {
			if (x == 602) {
				x = 602;
			}
			cyc.mfys(col[x], dfn[x], ed[x], -1);
			col[x] ^= 1;
			cha[!col[x]][top[x]].erase({dep[x], x});
			cha[col[x]][top[x]].insert({dep[x], x});
			cyc.rev(dfn[x]);
			cyc.mfys(col[x], dfn[x], ed[x], 1);
		}
		else if (opt == 2) {
			read(y);
			int p = find(x);
			cyc.mfymx(col[p], dfn[p], ed[p], y, cyc.qrys(!col[p], dfn[p]));
		}
		else if (opt == 3) {
			int p = find(x);
			++cnt;
			// if (cnt == 27) {
			// 	printf("%d\n", x);
			// }
			// printf("p%d\n", p);
			printf("%d\n", cyc.qrymx(col[p], dfn[p], ed[p], cyc.qrys(!col[p], dfn[p])));
		}
		else if (opt == 4) {
			read(y, z);
			while (top[x] != top[y]) {
				if (dep[top[x]] < dep[top[y]]) {
					cyc.mfy(dfn[top[y]], dfn[y], z);
					y = fa[top[y]];
				}
				else {
					cyc.mfy(dfn[top[x]], dfn[x], z);
					x = fa[top[x]];
				}
			}
			if (dep[x] < dep[y]) cyc.mfy(dfn[x], dfn[y], z);
			else cyc.mfy(dfn[y], dfn[x], z);
		}
		else {
			read(y);
			cyc.mfy(dfn[x], ed[x], y);
		}
	}
	cerr << ((double) clock() / CLOCKS_PER_SEC) << '\n';
	return (0-0);
}
posted @ 2022-03-27 00:07  siriehn_nx  阅读(156)  评论(0)    收藏  举报