P2486 [SDOI2011] 染色

P2486 [SDOI2011] 染色

感觉不难想,但代码调了好久/ll

不难想到用树链剖分,线段树部分只需要区间赋值,区间查询连续段个数的操作,关键在于树上两点如何计算答案。

假设当前点 \(x\) 要跳到点 \(fa[top[x]]\),要求的即为点 \(x\) 到点 \(top[x]\),用一个变量记录这条链的最低端的点的儿子的颜色,可能有点拗口,就是这条链底下的点,随后判断颜色是否相同而决定是否加 \(1\) 了,随后变量即为点 \(top[x]\) 的颜色。当点 \(x,y\) 在同一条重链上时,假设 \(x\)\(y\) 上面,此时 \(y\) 需要向上跳才行,而 \(x,y\) 分别带的链就拼在一块了,所以别忘记考虑两个相邻的端点。

code:

#include <bits/stdc++.h>
#define ls p<<1
#define rs p<<1|1
#define fi first
#define se second
#define pb push_back
#define mk make_pair
#define ll long long
#define space putchar(' ')
#define enter putchar('\n')
using namespace std;

inline int read() {
	int x = 0, f = 1;
	char c = getchar();
	while (c < '0' || c > '9') f = c == '-' ? -1 : f, c = getchar();
	while (c >= '0' && c <= '9') x = (x<<3)+(x<<1)+(c^48), c = getchar();
	return x*f;
}

inline void write(int x) {
	if (x < 0) x = -x, putchar('-');
	if (x > 9) write(x/10);
	putchar('0'+x%10);
}

const int N = 1e5+5, M = N<<2;
int n, m, idx, lcol, rcol, a[N], d[N], fa[N], id[N], to[N], siz[N], son[N], top[N], lc[M], rc[M], tag[M], res[M];
vector <int> E[N];

inline void pushup(int p) {
	lc[p] = lc[ls], rc[p] = rc[rs];
	res[p] = res[ls]+res[rs]-(rc[ls]==lc[rs]);
}

inline void pushdown(int p) {
	if (!tag[p]) return;
	lc[ls] = lc[rs] = tag[p];
	rc[ls] = rc[rs] = tag[p];
	tag[ls] = tag[rs] = tag[p];
	res[ls] = res[rs] = 1;
	tag[p] = 0;
}

inline void build(int p, int l, int r) {
	if (l == r) return (void)(lc[p] = rc[p] = a[to[l]], res[p] = 1);
	int mid = l+r>>1;
	build(ls, l, mid);
	build(rs, mid+1, r);
	pushup(p);
}

inline void upd(int p, int l, int r, int ql, int qr, int x) {
	if (ql <= l && r <= qr) return (void)(lc[p] = rc[p] = tag[p] = x, res[p] = 1);
	pushdown(p); int mid = l+r>>1;
	if (ql <= mid) upd(ls, l, mid, ql, qr, x);
	if (mid < qr) upd(rs, mid+1, r, ql, qr, x);
	pushup(p);
}

inline int que(int p, int l, int r, int ql, int qr) {
	if (ql <= l && r <= qr) {
		if (l == ql) lcol = lc[p];
		if (r == qr) rcol = rc[p];
		return res[p];
	}
	pushdown(p); int mid = l+r>>1;
	if (qr <= mid) return que(ls, l, mid, ql, qr);
	if (mid < ql) return que(rs, mid+1, r, ql, qr);
	return que(ls, l, mid, ql, qr)+que(rs, mid+1, r, ql, qr)-(rc[ls]==lc[rs]);
}

inline void dfs1(int x, int f) {
	d[x] = d[f]+1, fa[x] = f, siz[x] = 1; int mx = 0;
	for (int y:E[x]) if (y != f) {
		dfs1(y, x);
		siz[x] += siz[y];
		if (siz[y] > mx) son[x] = y, mx = siz[y];
	}
}

inline void dfs2(int x, int topf) {
	id[x] = ++idx, to[idx] = x, top[x] = topf;
	if (!son[x]) return;
	dfs2(son[x], topf);
	for (int y:E[x]) if (y != fa[x] && y != son[x]) dfs2(y, y);
}

inline void upd_path(int x, int y, int k) {
	while (top[x] != top[y]) {
		if (d[top[x]] < d[top[y]]) swap(x, y);
		upd(1, 1, n, id[top[x]], id[x], k);
		x = fa[top[x]];
	}
	if (d[x] > d[y]) swap(x, y);
	upd(1, 1, n, id[x], id[y], k);
}

inline int que_path(int x, int y) {
	int ans = 0, c1 = 0, c2 = 0;
	while (top[x] != top[y]) {
		if (d[top[x]] < d[top[y]]) swap(x, y), swap(c1, c2);
		ans += que(1, 1, n, id[top[x]], id[x])-(rcol==c1), c1 = lcol;
		x = fa[top[x]];
	}
	if (d[x] > d[y]) swap(x, y), swap(c1, c2);
	ans += que(1, 1, n, id[x], id[y])-(lcol==c1)-(rcol==c2);
	return ans;
}

int main() {
	n = read(), m = read();
	for (int i = 1; i <= n; ++i) a[i] = read();
	for (int i = 1; i < n; ++i) {
		int x = read(), y = read();
		E[x].pb(y), E[y].pb(x);
	}
	dfs1(1, 0), dfs2(1, 1), build(1, 1, n);
	while (m--) {
		char c; cin >> c; int x = read(), y = read(), k;
		if (c == 'C') k = read(), upd_path(x, y, k);
		else write(que_path(x, y)), enter;
	}
	return 0;
}
posted @ 2023-12-27 21:17  123wwm  阅读(35)  评论(0)    收藏  举报