动态树(Link-Cut Tree)
前置知识Splay,可以参考这篇博文。
算法思想
动态树算法用于解决一类树上问题,涉及树边的连接和断开,本质是维护一个森林。该算法将树上的边划分为实边和虚边,每一条实边构成的实链有一个Splay维护,Splay之间用虚边连。LCT可以轻易实现实链的重新划分,因此可以借助Splay实现高效维护树上路径信息。算法细节见YangZhe的论文
代码实现
LCT将普通的树形结构转化为二叉搜索树进行维护,主要依靠两个数组,分别是f[N],ch[N][2]。f[x]表示父亲节点,ch[x][0],ch[x][1]分别表示左儿子和右儿子。二叉搜索树的key值为每个点在树上的dfs序。下面对关键函数功能进行单独解释。
notroot
判断节点x是不是根节点,方法是判断x是否存在父亲。因为可能是森林,根节点的f[x]不一定是0,所以需要判断x是不是f[x]的某个儿子。
bool notroot(int x) {
return (ch[f[x]][0] == x) || (ch[f[x]][1] == x);
}
pushrev
维护链的翻转信息,在LCT中有换根操作,所以节点的dfs序关系可能发生改变,此时需要交换二叉搜索树的左右节点。
void pushrev(int x) {
swap(ch[x][0], ch[x][1]);
rev[x] ^= 1;
}
pushdown
用于维护翻转标记。
void pushdown(int x) {
if (rev[x]) {
if (ch[x][0]) pushrev(ch[x][0]);
if (ch[x][1]) pushrev(ch[x][1]);
rev[x] = 0;
}
}
rotate,splay
与普通Splay中的实现相同,区别是使用notroot判断是否是根节点,并且在splay前需要提前下放翻转标记。splay的作用是将某个点旋转到当前Splay的根的位置。
void rotate(int x) {
int y = f[x], z = f[y];
int d = ch[y][1] == x, w = ch[x][d ^ 1];
if (notroot(y)) ch[z][ch[z][1] == y] = x; ch[x][d ^ 1] = y; ch[y][d] = w;
if (w) f[w] = y; f[y] = x; f[x] = z;
pushup(y); pushup(x);
}
void splay(int x) {
int u = x, top = 0;
stk[++top] = u;
while (notroot(u)) stk[++top] = u = f[u];
while (top) pushdown(stk[top--]);
while (notroot(x)) {
int y = f[x], z = f[y];
if (notroot(y)) (ch[z][0] == y) ^ (ch[y][0] == x) ? rotate(x) : rotate(y);
rotate(x);
}
}
access
LCT中的精髓,作用是将x到根的的链变为实链,将其中的点放进一个Splay中,方便维护链上信息。代码中的y是前一条实链的根,先将x旋至根,然后将两个Splay合并,即将两条链实间的虚边变为实边,不断重复,直到路径上的所有实链被合并到一起。
void access(int x) {
for (int y = 0; x; x = f[y = x]) {
splay(x); ch[x][1] = y; pushup(x);
}
}
makeroot
换根操作,先access,将x和根节点放在一个Splay中。再splay,让x成为Splay的根。最后打上翻转标记,将链进行翻转,实现换根。
void makeroot(int x) {
access(x); splay(x); pushrev(x);
}
findroot
寻找根节点,先access+splay,然后不断走左子树寻找dfs序最小的节点。
int findroot(int x) {
access(x); splay(x);
while (ch[x][0]) pushdown(x), x = ch[x][0];
splay(x);
return x;
}
split
将两点之间的路径放进一个Splay中,这个函数使得我们可以方便的维护任意两点的路径信息。先将x变成根,再用access+splay分离出y到根的路径,得到的就是x到y的路径。
void split(int x,int y) {
makeroot(x); access(y); splay(y);
}
link
连接树边,先将x变为根,再将x连接到y上。注意判断x和y是否本身就在一棵树中。
void link(int x,int y) {
makeroot(x);
if (findroot(y) != x) f[x] = y;
}
cut
断开树边,先将x变为根,若x与y相连,此时y就是x的右儿子,断开该边。注意判断x与y是否直接相连。
void cut(int x,int y) {
makeroot(x);
if (findroot(y) == x && f[y] == x && !ch[y][0]) {
f[y] = ch[x][1] = 0;
pushup(x);
}
}
完整代码
#include <bits/stdc++.h>
#define ll long long
using namespace std;
int read() {
int x = 0, f = 1; char ch = getchar();
while (ch < '0' || ch > '9') {if (ch == '-') f = -1; ch = getchar();}
while (ch >= '0' && ch <= '9') {x = (x << 1) + (x << 3) + (ch ^ 48); ch = getchar();}
return x * f;
}
const int N = 1e5 + 10;
struct Link_Cut_Tree {
int f[N], ch[N][2], sum[N], val[N], rev[N], stk[N];
bool notroot(int x) {
return (ch[f[x]][0] == x) || (ch[f[x]][1] == x);
}
void pushup(int x) {
sum[x] = sum[ch[x][0]] ^ sum[ch[x][1]] ^ val[x];
}
void pushrev(int x) {
swap(ch[x][0], ch[x][1]);
rev[x] ^= 1;
}
void pushdown(int x) {
if (rev[x]) {
if (ch[x][0]) pushrev(ch[x][0]);
if (ch[x][1]) pushrev(ch[x][1]);
rev[x] = 0;
}
}
void rotate(int x) {
int y = f[x], z = f[y];
int d = ch[y][1] == x, w = ch[x][d ^ 1];
if (notroot(y)) ch[z][ch[z][1] == y] = x; ch[x][d ^ 1] = y; ch[y][d] = w;
if (w) f[w] = y; f[y] = x; f[x] = z;
pushup(y); pushup(x);
}
void splay(int x) {
int u = x, top = 0;
stk[++top] = u;
while (notroot(u)) stk[++top] = u = f[u];
while (top) pushdown(stk[top--]);
while (notroot(x)) {
int y = f[x], z = f[y];
if (notroot(y)) (ch[z][0] == y) ^ (ch[y][0] == x) ? rotate(x) : rotate(y);
rotate(x);
}
}
void access(int x) {
for (int y = 0; x; x = f[y = x]) {
splay(x); ch[x][1] = y; pushup(x);
}
}
void makeroot(int x) {
access(x); splay(x); pushrev(x);
}
int findroot(int x) {
access(x); splay(x);
while (ch[x][0]) pushdown(x), x = ch[x][0];
splay(x);
return x;
}
void split(int x,int y) {
makeroot(x); access(y); splay(y);
}
void link(int x,int y) {
makeroot(x);
if (findroot(y) != x) f[x] = y;
}
void cut(int x,int y) {
makeroot(x);
if (findroot(y) == x && f[y] == x && !ch[y][0]) {
f[y] = ch[x][1] = 0;
pushup(x);
}
}
} lct;
int n, m;
int main() {
n = read(); m = read();
for (int i = 1; i <= n; i++) {
int x = read();
lct.val[i] = lct.sum[i] = x;
}
for (int i = 1; i <= m; i++) {
int op = read(), x = read(), y = read();
if (op == 0) {
lct.split(x, y);
printf("%d\n", lct.sum[y]);
}
if (op == 1) {
lct.link(x, y);
}
if (op == 2) {
lct.cut(x, y);
}
if (op == 3) {
lct.splay(x);
lct.sum[x] ^= lct.val[x] ^ y;
lct.val[x] = y;
}
}
return 0;
}

浙公网安备 33010602011771号