树链剖分

思想

我们都做过这样一道题------查询一棵树上任意两个节点间的路径上的点权之和, 多组询问. 这相当得简单, \(LCA\) 板子题. 那么, 如果我们加上修改呢?

"给出一棵结点数不超过 \(10^5\) 的树, 以及有不超过 \(10^5\) 次操作. 每次操作分为两种: 1. 将树上 \(u\)\(v\) 的路径上的点权全部加上 \(k\) (可能是负数); 2. 求树上 \(u\)\(v\) 的路径上的点权之和为多少."

很明显,这道题就不能用 \(LCA\) 去做了, 不然每次修改都要跑一遍 \(dfs\), 时间炸了.

这时, 我们想到了以前做的线段树板子题: 给出一个序列, 每次区间加上一个数或区间求和. 而这道题不过是把序列换成了一棵树. 那么, 我们可不可以把一棵树转化成一个序列,且方便我们统计两个结点之间的路径呢? 当然可以. 这就是------树链剖分!!!

首先,为了达到将树转化成一条链且时间优秀,方便统计两个结点间的信息的目的, 我们要定义如下概念:

概念 定义
重儿子 某结点的众儿子中,以该儿子为根的子树的结点数最大
轻儿子 某结点的众儿子中除了重儿子以外的所有儿子
重边 与重儿子连接的边
轻边 与轻儿子连接的边
重链 由若干条重边连接而成
轻链 由若干条轻边连接而成

然后, 我们将同一条链上的序号标成连续的,用一种数据结构去维护它(因为带修), 然后每次向上跳一整条链,链中信息可直接统计.为什么我们要这么做呢? 因为有这样一个定理:

如果(\(u\), \(v\))是一条轻边, 那么以\(v\)为根节点的子树的\(size\) \(<\)\(u\)为根节点的子树的\(size\)

为什么? 因为它是一条轻边!!! 如果\(tree[v].size>=tree[u].size\), 那么\(v\)一定是重儿子, \((u, v)\) 便不可能是轻边. 因此, 我们又可以得到一个定理:

从根节点到下面任意一个点, 所经过的链小于\(logn\)条.

又一个为什么? 因为我们有上面那个定理. 如果走重儿子, 那一定会直接跳完一整条链; 否则一定走向了轻儿子.而此时\(size\)直接缩小了一半多(定理:如果(\(u\), \(v\))是一条轻边, 那么以\(v\)为根节点的子树的\(size\) \(<\)\(u\)为根节点的子树的\(size\)).因此,这样每次变成原来的一半还要少,经过链的数量就必定小于\(logn\).

以上便是树链剖分的大体思想.

实现

为了实现树链剖分,我们需要定义以下几个变量:

变量名 用途
xh(序号) 给每个结点的序号, 使得同一条链的序号是连续的(与"结点编号"不同)
father 当前结点的父结点编号
size 以当前结点为根的子树的结点数
val 当前结点的初始点权
dep 当前结点的深度
top 当前结点所在链的顶端的结点的编号
big_son 当前结点的重儿子的结点编号

关键问题: 如何求出一个结点的序号呢?我们可以借助\(dfs\)序来完成.不过遍历顺序有一点小变化:先遍历重儿子,再遍历其他儿子. 于是, 我们可以做两遍\(dfs\), 第一次求出除\(top\)\(xh\)以外的变量,\(bigson\)可以在回溯时求(此时儿子们的\(size\)已经求好).接着,我们就可以做第二遍\(dfs\)了.这一遍\(dfs\)\(top\)\(xh\).只要满足先遍历重儿子的顺序即可.

接着, 我们要将其以序号为原序列下标的线段树维护区间信息.

终于, 我们来到了操作部分. 我们的主要目的是获取任意两点之间的信息. 对此, 我们可以这样做(两个点\(u\), \(v\))(我是递归的写法):

  1. 如果两个点再同一条链上, 即\(tree[u].top == tree[v].top\), 因为一条链上的序号是连续的, 我们可以直接求出其信息, 并返回.

  2. 如果不在同一条链上, 将顶端深度更大的点向上跳一整条链, 再运用一下"同一条链上的序号是连续的一段"的条件, 就可以直接统计这一段对答案的贡献了. 然后, 递归跳了的点的顶端的\(father\)以及另外一个点.

(伪)完结撒花(伪)

模板(洛谷 P3384 【模板】轻重链剖分/树链剖分

树链剖分, 线段树维护区间信息. 上面不懂得看代码就懂了,相当简单.

#include<cstdio>
#include<vector>
#define int long long
#define Maxn 100000
using namespace std;
int read() {
	int f = 1, sum = 0;
	char ch = getchar();
	while(ch < '0' || ch > '9') {
		if(ch == '-') f = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9') {
		sum = sum * 10 + ch - '0';
		ch = getchar();
	}
	return f * sum;
}
void write(int x) {
    (x < 0) ? (putchar('-'), write(-x)) : (void)((x <= 9) ? (putchar(x + '0')) : (write(x / 10), putchar(x % 10 + '0')));
}
int n, root, cnt, a[Maxn + 9], T, Mod, tmpp[Maxn + 9], answer;
struct Svv {
	int top, size, big, dep, xh, fa;
} t[Maxn + 9];
struct XDS {
	int val, lazy, l, r;
} c[(Maxn << 3) + 9];
vector<int> V[Maxn + 9];
void Dfs1(int x, int dep, int fa) {
	t[x].fa = fa;
	t[x].top = x;
	t[x].dep = dep;
	int maxsize = 0, id;
	for(int i = 0; i < V[x].size(); ++i) {
		int y = V[x][i];
		if(y != fa) {
			Dfs1(y, dep + 1, x);
			t[x].size += t[y].size;
			if(t[y].size > maxsize) {
				maxsize = t[y].size;
				id = y;
			}
			t[x].big = id;
		}
	}
	++t[x].size;
}
void Dfs2(int x, int top, int fa) {
	if(top) t[x].top = top;
	t[x].xh = ++cnt;
	if(t[x].big) Dfs2(t[x].big, t[x].top, x);
	for(int i = 0; i < V[x].size(); ++i) {
		int y = V[x][i];
		if(y != fa && y != t[x].big) {
			Dfs2(y, 0, x);
		}
	}
}
void Swap(int &a, int &b) {
	int t = a;
	a = b;
	b = t;
}
void PushUp(int p) {
	c[p].val = c[p << 1].val + c[p << 1 | 1].val;
}
void PushDown(int p) {
	if(c[p].lazy) {
		c[p << 1].val += c[p].lazy * (c[p << 1].r - c[p << 1].l + 1);
		c[p << 1].val %= Mod;
		c[p << 1].lazy += c[p].lazy;
		c[p << 1].lazy %= Mod;
		c[p << 1 | 1].val += c[p].lazy * (c[p << 1 | 1].r - c[p << 1 | 1].l + 1);
		c[p << 1 | 1].val %= Mod;
		c[p << 1 | 1].lazy += c[p].lazy;
		c[p << 1 | 1].lazy %= Mod;
		c[p].lazy = 0;
	}
}
void Build(int p, int l, int r) {
	c[p].l = l, c[p].r = r;
	if(l == r) {
		c[p].val = a[l];
		return;
	}
	int mid = l + r >> 1;
	Build(p << 1, l, mid);
	Build(p << 1 | 1, mid + 1, r);
	PushUp(p);
}
void Update(int p, int l, int r, int L, int R, int val) {
	if(l >= L && r <= R) {
		c[p].val += val * (c[p].r - c[p].l + 1);
		c[p].val %= Mod;
		c[p].lazy += val;
		c[p].lazy %= Mod;
		return;
	}
	PushDown(p);
	int mid = l + r >> 1;
	if(L <= mid) Update(p << 1, l, mid, L, R, val);
	if(mid + 1 <= R) Update(p << 1 | 1, mid + 1, r, L, R, val);
	PushUp(p);
}
int Query(int p, int l, int r, int L, int R) {
	if(l >= L && r <= R) {
		return c[p].val;
	}
	PushDown(p);
	int mid = l + r >> 1, ans = 0;
	if(L <= mid) ans += Query(p << 1, l, mid, L, R);
	ans %= Mod;
	if(mid + 1 <= R) ans += Query(p << 1 | 1, mid + 1, r, L, R);
	ans %= Mod;
	return ans;
}
void Add1(int u, int v, int val) {
	if(t[u].dep < t[v].dep) Swap(u, v);		//用以增大码量
	if(t[u].top == t[v].top) {
		Update(1, 1, n, t[v].xh, t[u].xh, val);
	}
	else if(t[t[u].top].dep > t[t[v].top].dep) {
		Update(1, 1, n, t[t[u].top].xh, t[u].xh, val);
		Add1(t[t[u].top].fa, v, val);
	}
	else {
		Update(1, 1, n, t[t[v].top].xh, t[v].xh, val);
		Add1(u, t[t[v].top].fa, val);
	}
}
void Solve1(int u, int v) {
	if(t[u].dep < t[v].dep) Swap(u, v);		//用以增大码量
	if(t[u].top == t[v].top) {
		answer += Query(1, 1, n, t[v].xh, t[u].xh);
		answer %= Mod;
	}
	else if(t[t[u].top].dep > t[t[v].top].dep) {
		answer += Query(1, 1, n, t[t[u].top].xh, t[u].xh);
		answer %= Mod;
		Solve1(t[t[u].top].fa, v);
	}
	else {
		answer += Query(1, 1, n, t[t[v].top].xh, t[v].xh);
		answer %= Mod;
		Solve1(u, t[t[v].top].fa);
	}
}
void Add2(int x, int val) {
	Update(1, 1, n, t[x].xh, t[x].xh + t[x].size - 1, val);
}
void Solve2(int x) {
	answer += Query(1, 1, n, t[x].xh, t[x].xh + t[x].size - 1);
	answer %= Mod;
}
void Solve() {
	int opt = read();
	if(opt == 1) {
		int x = read(), y = read(), z = read();
		Add1(x, y, z);
	}
	else if(opt == 2) {
		int x = read(), y = read();
		answer = 0;
		Solve1(x, y);
		write(answer);
		putchar('\n');
	}
	else if(opt == 3) {
		int x = read(), z = read();
		Add2(x, z);
	}
	else {
		int x = read();
		answer = 0;
		Solve2(x);
		write(answer);
		putchar('\n');
	}
}
signed main() {
	n = read(), T = read(), root = read(), Mod = read();
	for(int i = 1; i <= n; ++i) {
		tmpp[i] = read();
	}
	for(int i = 1; i < n; ++i) {
		int u = read(), v = read();
		V[u].push_back(v);
		V[v].push_back(u);
	}
	Dfs1(root, 1, 0);
	Dfs2(root, 0, 0);
	for(int i = 1; i <= n; ++i) {
		a[t[i].xh] = tmpp[i];
	}
	Build(1, 1, n);
	while(T--) {
		Solve();
	}
	return 0;
}

例题1(洛谷 P2590 [ZJOI2008]树的统计

讲真的这道题比模板还简单, 一道模板的模板题

#include<cstdio>
#include<vector>
#define int long long
#define Maxn 30000
using namespace std;
int read() {
	int f = 1, sum = 0;
	char ch = getchar();
	while(ch < '0' || ch > '9') {
		if(ch == '-') f = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9') {
		sum = sum * 10 + ch - '0';
		ch = getchar();
	}
	return f * sum;
}
void write(int x) {
    (x < 0) ? (putchar('-'), write(-x)) : (void)((x <= 9) ? (putchar(x + '0')) : (write(x / 10), putchar(x % 10 + '0')));
}
int n, cnt, a[Maxn + 9], T, ans;
char opt[6];
struct Svv {
	int top, size, big, dep, xh, fa;
} t[Maxn + 9];
struct XDS {
	int valmax, valsum;
} c[(Maxn << 2) + 9];
vector<int> V[Maxn + 9];
void Dfs1(int x, int dep, int fa) {
	t[x].fa = fa;
	t[x].top = x;
	t[x].dep = dep;
	int maxsize = 0, id;
	for(int i = 0; i < V[x].size(); ++i) {
		int y = V[x][i];
		if(y != fa) {
			Dfs1(y, dep + 1, x);
			t[x].size += t[y].size;
			if(t[y].size > maxsize) {
				maxsize = t[y].size;
				id = y;
			}
			t[x].big = id;
		}
	}
	++t[x].size;
}
void Dfs2(int x, int top, int fa) {
	if(top) t[x].top = top;
	t[x].xh = ++cnt;
	if(t[x].big) Dfs2(t[x].big, t[x].top, x);
	for(int i = 0; i < V[x].size(); ++i) {
		int y = V[x][i];
		if(y != fa && y != t[x].big) {
			Dfs2(y, 0, x);
		}
	}
}
int Max(int a, int b) {
	return a > b ? a : b;
}
void Swap(int &a, int &b) {
	int t = a;
	a = b;
	b = t;
}
void PushUpMax(int p) {
	c[p].valmax = Max(c[p << 1].valmax, c[p << 1 | 1].valmax);
}
void PushUpSum(int p) {
	c[p].valsum = c[p << 1].valsum + c[p << 1 | 1].valsum;
}
void Build(int p, int l, int r) {
	if(l == r) {
		c[p].valmax = c[p].valsum = a[l];
		return;
	}
	int mid = l + r >> 1;
	Build(p << 1, l, mid);
	Build(p << 1 | 1, mid + 1, r);
	PushUpMax(p);
	PushUpSum(p);
}
void Update(int p, int l, int r, int x, int y) {
	if(l == r) {
		c[p].valmax = c[p].valsum = y;
		return;
	}
	int mid = l + r >> 1;
	if(x <= mid) Update(p << 1, l, mid, x, y);
	if(mid + 1 <= x) Update(p << 1 | 1, mid + 1, r, x, y);
	PushUpMax(p);
	PushUpSum(p);
}
int QueryMax(int p, int l, int r, int L, int R) {
	if(l >= L && r <= R) {
		return c[p].valmax;
	}
	int mid = l + r >> 1, ans = -1e9;
	if(L <= mid) ans = Max(ans, QueryMax(p << 1, l, mid, L, R));
	if(mid + 1 <= R) ans = Max(ans, QueryMax(p << 1 | 1, mid + 1, r, L, R));
	return ans;
}
int QuerySum(int p, int l, int r, int L, int R) {
	if(l >= L && r <= R) {
		return c[p].valsum;
	}
	int mid = l + r >> 1, ans = 0;
	if(L <= mid) ans += QuerySum(p << 1, l, mid, L, R);
	if(mid + 1 <= R) ans += QuerySum(p << 1 | 1, mid + 1, r, L, R);
	return ans;
}
void SolveMax(int u, int v) {
	if(t[u].dep < t[v].dep) {
		Swap(u, v);
	}
	if(t[u].top == t[v].top) {
		ans = Max(ans, QueryMax(1, 1, n, t[v].xh, t[u].xh));
		return;
	}
	else if(t[t[u].top].dep > t[t[v].top].dep) {
		ans = Max(ans, QueryMax(1, 1, n, t[t[u].top].xh, t[u].xh));
		SolveMax(t[t[u].top].fa, v);
	}
	else {
		ans = Max(ans, QueryMax(1, 1, n, t[t[v].top].xh, t[v].xh));
		SolveMax(u, t[t[v].top].fa);
	}
}
void SolveSum(int u, int v) {
	if(t[u].dep < t[v].dep) {
		Swap(u, v);
	}
	if(t[u].top == t[v].top) {
		ans += QuerySum(1, 1, n, t[v].xh, t[u].xh);
		return;
	}
	else if(t[t[u].top].dep > t[t[v].top].dep) {
		ans += QuerySum(1, 1, n, t[t[u].top].xh, t[u].xh);
		SolveSum(t[t[u].top].fa, v);
	}
	else {
		ans += QuerySum(1, 1, n, t[t[v].top].xh, t[v].xh);
		SolveSum(u, t[t[v].top].fa);
	}
}
void Solve() {
	scanf(" %s", opt);
	if(opt[0] == 'C') {
		int u = read(), tt = read();
		Update(1, 1, n, t[u].xh, tt);
	}
	else {
		int u = read(), v = read();
		if(opt[1] == 'M') {
			ans = -1e9;
			SolveMax(u, v);
		}
		else {
			ans = 0;
			SolveSum(u, v);
		}
		write(ans);
		putchar('\n');
	}
}
signed main() {
	n = read();
	for(int i = 1; i < n; ++i) {
		int u = read(), v = read();
		V[u].push_back(v);
		V[v].push_back(u);
	}
	Dfs1(1, 1, 0);
	Dfs2(1, 1, 0);
	for(int i = 1; i <= n; ++i) {
		a[t[i].xh] = read();
	}
	Build(1, 1, n);
	T = read();
	while(T--) {
		Solve();
	}
	return 0;
}

例题2(洛谷 P4114 Qtree1

将树从根节点提起, 再把每条边的边权推给其连接的儿子. 最后查询/更改的时候不统计/更改\(LCA\)的权值.

#include<cstdio>
#include<vector>
#define int long long
#define Maxn 300000
using namespace std;
int read() {
	int f = 1, sum = 0;
	char ch = getchar();
	while(ch < '0' || ch > '9') {
		if(ch == '-') f = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9') {
		sum = sum * 10 + ch - '0';
		ch = getchar();
	}
	return f * sum;
}
void write(int x) {
    (x < 0) ? (putchar('-'), write(-x)) : (void)((x <= 9) ? (putchar(x + '0')) : (write(x / 10), putchar(x % 10 + '0')));
}
int n, cnt, a[Maxn + 9], T, ans, change[Maxn + 9];
char opt[6];
struct Svv {
	int top, size, big, dep, xh, fa;
} t[Maxn + 9];
struct XDS {
	int valmax;
} c[(Maxn << 2) + 9];
struct TTT {
	int to, val, id;
};
vector<TTT> V[Maxn + 9];
void Dfs1(int x, int dep, int fa) {
	t[x].fa = fa;
	t[x].top = x;
	t[x].dep = dep;
	int maxsize = 0, id;
	for(int i = 0; i < V[x].size(); ++i) {
		int y = V[x][i].to;
		if(y != fa) {
			Dfs1(y, dep + 1, x);
			t[x].size += t[y].size;
			if(t[y].size > maxsize) {
				maxsize = t[y].size;
				id = y;
			}
			t[x].big = id;
		}
	}
	++t[x].size;
}
void Dfs2(int x, int top, int fa) {
	if(top) t[x].top = top;
	if(t[x].big) t[t[x].big].xh = ++cnt, Dfs2(t[x].big, t[x].top, x);
	for(int i = 0; i < V[x].size(); ++i) {
		int y = V[x][i].to;
		if(y != fa) {
			if(y != t[x].big) t[y].xh = ++cnt, Dfs2(y, 0, x);
			a[t[y].xh] = V[x][i].val;
			change[V[x][i].id] = y;
		}
	}
}
int Max(int a, int b) {
	return a > b ? a : b;
}
void Swap(int &a, int &b) {
	int t = a;
	a = b;
	b = t;
}
void PushUpMax(int p) {
	c[p].valmax = Max(c[p << 1].valmax, c[p << 1 | 1].valmax);
}
void Build(int p, int l, int r) {
	if(l == r) {
		c[p].valmax = a[l];
		return;
	}
	int mid = l + r >> 1;
	Build(p << 1, l, mid);
	Build(p << 1 | 1, mid + 1, r);
	PushUpMax(p);
}
void Update(int p, int l, int r, int x, int y) {
	if(l == r) {
		c[p].valmax = y;
		return;
	}
	int mid = l + r >> 1;
	if(x <= mid) Update(p << 1, l, mid, x, y);
	if(mid + 1 <= x) Update(p << 1 | 1, mid + 1, r, x, y);
	PushUpMax(p);
}
int QueryMax(int p, int l, int r, int L, int R) {
	if(l >= L && r <= R) {
		return c[p].valmax;
	}
	int mid = l + r >> 1, ans = -1e9;
	if(L <= mid) ans = Max(ans, QueryMax(p << 1, l, mid, L, R));
	if(mid + 1 <= R) ans = Max(ans, QueryMax(p << 1 | 1, mid + 1, r, L, R));
	return ans;
}
void SolveMax(int u, int v) {
	if(t[u].dep < t[v].dep) {
		Swap(u, v);
	}
	if(t[u].top == t[v].top) {
		ans = Max(ans, QueryMax(1, 1, n, t[v].xh + 1, t[u].xh));
		return;
	}
	else if(t[t[u].top].dep > t[t[v].top].dep) {
		ans = Max(ans, QueryMax(1, 1, n, t[t[u].top].xh, t[u].xh));
		SolveMax(t[t[u].top].fa, v);
	}
	else {
		ans = Max(ans, QueryMax(1, 1, n, t[t[v].top].xh, t[v].xh));
		SolveMax(u, t[t[v].top].fa);
	}
}
bool Check(int u, int v, int high) {
	if(t[u].dep < t[v].dep) return 0;
	if(t[u].top != t[v].top) return Check(t[t[u].top].fa, v, high);
	else {
		if(v == high) return 1;
		else return 0;
	}
}
void Solve() {
	scanf(" %s", opt);
	if(opt[0] == 'D') {
		return;
	}
	else if(opt[0] == 'C'){
		int x = read(), tt = read();
		Update(1, 1, n, t[change[x]].xh, tt);
	}
	else {
		int x = read(), y = read();
		if(t[x].dep < t[y].dep) Swap(x, y);
		if(x == y) {
			write(0);
			putchar('\n');
		}
		else if(Check(x, y, y)) {
			int tmpmp = QueryMax(1, 1, n, t[y].xh, t[y].xh);
			Update(1, 1, n, t[y].xh, -1e18);
			ans = -1e9;
			SolveMax(x, y);
			write(ans);
			putchar('\n');
			Update(1, 1, n, t[y].xh, tmpmp);
		}
		else {
			ans = -1e9;
			SolveMax(x, y);
			write(ans);
			putchar('\n');
		}
	}
	Solve();
}
signed main() {
	n = read();
	for(int i = 1; i < n; ++i) {
		int u = read(), v = read(), w = read();
		TTT temp;
		temp.to = v, temp.val = w, temp.id = i;
		V[u].push_back(temp);
		temp.to = u;
		V[v].push_back(temp);
	}
	Dfs1(1, 1, 0);
	t[1].xh = ++cnt, a[1] = -1e18;
	Dfs2(1, 1, 0);
	Build(1, 1, n);
	Solve();
	return 0;
}

例题3(洛谷 P4116 Qtree3

我的解法相当暴力, 就是纯粹的枚举. 我有个同学的做法是用线段树统计区间内最左边的黑点的编号, 是真正的\(n log^2 n\).

不过不知道为什么我过了他超时了???!!!

由于我的过而他没过, 所以就放我的代码吧.

#include<cstdio>
#include<vector>
#define int long long
#define Maxn 100000
using namespace std;
int read() {
	int f = 1, sum = 0;
	char ch = getchar();
	while(ch < '0' || ch > '9') {
		if(ch == '-') f = -1;
		ch = getchar();
	}
	while(ch >= '0' && ch <= '9') {
		sum = sum * 10 + ch - '0';
		ch = getchar();
	}
	return f * sum;
}
void write(int x) {
    (x < 0) ? (putchar('-'), write(-x)) : (void)((x <= 9) ? (putchar(x + '0')) : (write(x / 10), putchar(x % 10 + '0')));
}
int n, cnt, Q, ans;
char opt[6];
struct Svv {
	int top, size, big, dep, xh, fa;
} t[Maxn + 9];
struct XDS {
	int val;
} c[(Maxn << 2) + 9];
vector<int> V[Maxn + 9];
void Dfs1(int x, int dep, int fa) {
	t[x].fa = fa;
	t[x].top = x;
	t[x].dep = dep;
	int maxsize = 0, id;
	for(int i = 0; i < V[x].size(); ++i) {
		int y = V[x][i];
		if(y != fa) {
			Dfs1(y, dep + 1, x);
			t[x].size += t[y].size;
			if(t[y].size > maxsize) {
				maxsize = t[y].size;
				id = y;
			}
			t[x].big = id;
		}
	}
	++t[x].size;
}
void Dfs2(int x, int top, int fa) {
	if(top) t[x].top = top;
	t[x].xh = ++cnt;
	if(t[x].big) Dfs2(t[x].big, t[x].top, x);
	for(int i = 0; i < V[x].size(); ++i) {
		int y = V[x][i];
		if(y != fa && y != t[x].big) {
			Dfs2(y, 0, x);
		}
	}
}
void Swap(int &a, int &b) {
	int t = a;
	a = b;
	b = t;
}
void PushUpSum(int p) {
	c[p].val = c[p << 1].val + c[p << 1 | 1].val;
}
void Update(int p, int l, int r, int x) {
	if(l == r) {
		c[p].val = 1 - c[p].val;
		return;
	}
	int mid = l + r >> 1;
	if(x <= mid) Update(p << 1, l, mid, x);
	if(mid + 1 <= x) Update(p << 1 | 1, mid + 1, r, x);
	PushUpSum(p);
}
int Query(int p, int l, int r, int L, int R) {
	if(l >= L && r <= R) {
		return c[p].val;
	}
	int mid = l + r >> 1, ans = 0;
	if(L <= mid) ans += Query(p << 1, l, mid, L, R);
	if(mid + 1 <= R) ans += Query(p << 1 | 1, mid + 1, r, L, R);
	return ans;
}
void Find(int x) {
	if(!x) return;
	if(Query(1, 1, n, t[t[x].top].xh, t[x].xh)) {
		if(Query(1, 1, n, t[x].xh, t[x].xh)) ans = x;
		for(int i = t[x].top; i != x; i = t[i].big) {
			if(Query(1, 1, n, t[t[x].top].xh, t[i].xh)) {
				ans = i;
				break;
			}
		}
	}
	Find(t[t[x].top].fa);
}
void Solve() {
	int opt = read();
	if(opt == 0) {
		int x = read();
		Update(1, 1, n, t[x].xh);
	}
	else {
		int x = read();
		ans = -1;
		Find(x);
		write(ans);
		putchar('\n');
	}
}
signed main() {
	n = read(), Q = read();
	for(int i = 1; i < n; ++i) {
		int u = read(), v = read();
		V[u].push_back(v);
		V[v].push_back(u);
	}
	Dfs1(1, 1, 0);
	Dfs2(1, 1, 0);
	while(Q--) {
		Solve();
	}
	return 0;
}

(真)完结撒花(真)

posted @ 2022-07-20 16:50  TuSalcc  阅读(40)  评论(0)    收藏  举报