(简记)虚树
该结构是用来处理一类关键点问题的。
具体地,我们压缩树的信息,每次需要选出 \(k\) 个点 \(h_1,h_2,\dots,h_k\),它们在树上两两路径是重要信息。我们一般会在一个路径的拐点即 \(u\to v\) 的 \(\text{LCA}(u,v)\) 上处理这条路径,那么我们不妨令虚树包含 \(h_i\) 及所有 \(\text{LCA}_{i<j}(h_i,h_j)\),然后树上 DP 或者 DFS 什么的就可以解决问题。
建树
方法 1
所有关键点按 DFS 序升序排序,排序后任意两个相邻点求 \(\text{LCA}\) 并加入关键点序列然后再次排序并去重,最后在序列上枚举任意两个相邻点 \(x,y\) 并连接 \(\text{LCA}(x,y)\to y\) 即可。
这样的做法直观理解是容易的。首先有结论,加入的那一堆点后序列中包含虚树中所有点,证明是容易的,这个结论 NOIP 2024 T4 也有用到。
Proof
假设有任意一个 \(\text{LCA}(h_i,h_j)\) 没有被加入序列中,如果这个 \(\text{LCA}\) 与 \(h_i\) 或 \(h_j\) 任意其一相等那么它肯定在序列中,否则考虑 \(h_i,h_j\) DFS 序排序后中间有没有点。如果没有,根据上述方法会加入序列。如果有且有的在 \(h_i\) 的子树中(\(A\) 集合,为了全面性包括 \(h_i\)),有的挂在 \(\text{LCA}\to h_j\) 的链上(\(B\) 集合,同样包括 \(h_j\)),那么 \(A\) 和 \(B\) 一定会有一个 DFS 序排序后相邻的点贡献给 \(\text{LCA}\)。
我们按照 DFS 序排序,那么连边的过程类似于进行搜索,考虑 \(x\) 在排列中上一个相邻节点,如果其在祖先链上那么肯定就是它在虚树上的直接父亲 \(fa\),否则一定是 \(fa\) 在搜索过程中在另一个子树中的最后一个节点,其 \(\text{LCA}\) 必定是 \(fa\)。
代码请前往 OI-wiki。
方法 1.1
我在场上写出了一种建虚树方法,具体和 1 大差不差,但是减少了在新序列中再求 \(\text{LCA}(x,y)\) 然后连边所带来的常数。具体来说,对序列的顺序遍历过程即是按照 DFS 序 DFS 的过程。用栈记录当前节点祖先链,切换节点 \(i\to i+1\) 就弹栈直到其为 \(i+1\) 的祖先,此时即为其虚树上的父亲。判断是否为祖先可以预先处理 \(dfn,siz\)。
核心代码
sort(a+1,a+1+k,cmp);
k=unique(a+1,a+1+k)-(a+1);
stk[tp=1]=a[1];
for(int i=2;i<=k;i++){
while(tp&&dfn[stk[tp]]+siz[stk[tp]]-1<dfn[a[i]])tp--;
rG[stk[tp]].emplace_back(a[i]);
rG[a[i]].emplace_back(stk[tp]);
stk[++tp]=a[i];
}
方法 2
我们先所有关键点按 DFS 序升序排序,然后顺序遍历模拟原树上 DFS 过程,用单调栈记录当前链上的虚树上的点有哪些,所有虚树上边在节点出栈时连接。每次加入节点 \(u\) 时令 \(lc=\text{LCA}(u,stack_{top})\),弹出栈顶直到栈顶的下一个元素的 DFS 序 \(<dfn_{lc}\),然后如果此时栈顶就是 \(lc\),直接加入即可。否则把 \(lc\) 插入到栈中两个 DFS 序相邻的位置,然后继续弹出后面那个东西。最后如果栈顶不是 \(u\) 把 \(u\) 入栈即可。
点击查看代码
for(int i=1;i<=k;i++)
cin>>h[i],tg[h[i]]=now,head[h[i]]=0;
sort(h+1,h+1+k,cmp);
stk[++tp]=1;head[1]=0;
for(int i=1;i<=k;i++){
int lc=LCA(stk[tp],h[i]);
while(dfn[lc]<dfn[stk[tp]]){
if(tp>1){
if(dfn[stk[tp-1]]<dfn[lc]){
stk[tp+1]=stk[tp];
head[lc]=0;
stk[tp]=lc;
tp++;
}
ins(stk[tp-1],stk[tp],dis(stk[tp-1],stk[tp]));
}
tp--;
}
if(stk[tp]!=h[i])stk[++tp]=h[i];
}
while(tp>1)
ins(stk[tp-1],stk[tp],dis(stk[tp-1],stk[tp])),tp--;
例题
P4103 [HEOI2014] 大工程
树形 DP 上,我们给每个特殊点打当前询问次数的标记,这样我们就不用清空,然后记录 \(u\) 子树内有多少个关键点及其到 \(u\) 路径长度总和及到 \(u\) 路径的最小值和最大值,然后按照最开始说的,统计一条路径在其拐点上记录信息即可。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=1e6+5;
int n,q,h[N],dfn[N],f[N][20];
vector<int>G[N];
int head[N],idx,dep[N],tms;
struct Edge{int v,next,w;}e[N<<1];
int mn,mx,fmn[N],fmx[N],cnt[N];
LL sum,fsum[N];
void ins(int x,int y,int z){
e[++idx].v=y;
e[idx].next=head[x];
e[idx].w=z;
head[x]=idx;
}
void dfs0(int u,int fa){
dfn[u]=++tms;
dep[u]=dep[fa]+1;
for(int v:G[u]){
if(v==fa)continue;
f[v][0]=u;
for(int j=1;(1<<j)<=dep[u];j++)
f[v][j]=f[f[v][j-1]][j-1];
dfs0(v,u);
}
}
int dis(int x,int y){
if(dep[x]<dep[y])swap(x,y);
int res=0;
for(int i=19;i>=0;i--)
if(dep[f[x][i]]>=dep[y])
x=f[x][i],res+=(1<<i);
return res;
}
int LCA(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=19;i>=0;i--)
if(dep[f[x][i]]>=dep[y])
x=f[x][i];
if(x==y)return x;
for(int i=19;i>=0;i--)
if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}
int stk[N],tp,now,tg[N];
bool cmp(int x,int y){return dfn[x]<dfn[y];}
void dfs(int u){
fsum[u]=0;
if(tg[u]==now)fmn[u]=fmx[u]=0,cnt[u]=1;
else fmn[u]=n+1,fmx[u]=-n-1,cnt[u]=0;
for(int i=head[u];i;i=e[i].next){
int v=e[i].v,w=e[i].w;
dfs(v);
mn=min(mn,fmn[u]+fmn[v]+w);
mx=max(mx,fmx[u]+fmx[v]+w);
fmn[u]=min(fmn[u],fmn[v]+w);
fmx[u]=max(fmx[u],fmx[v]+w);
sum+=fsum[u]*cnt[v]+(fsum[v]+cnt[v]*w)*cnt[u];
cnt[u]+=cnt[v];
fsum[u]+=fsum[v]+cnt[v]*w;
}
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0);cout.tie(0);
cin>>n;
for(int i=1;i<n;i++){
int u,v;cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs0(1,0);
cin>>q;
for(now=1;now<=q;now++){
idx=0;
int k;cin>>k;
mn=n+1,mx=0;
tp=sum=0;
for(int i=1;i<=k;i++)
cin>>h[i],tg[h[i]]=now,head[h[i]]=0;
sort(h+1,h+1+k,cmp);
stk[++tp]=1;head[1]=0;
for(int i=1;i<=k;i++){
int lc=LCA(stk[tp],h[i]);
while(dfn[lc]<dfn[stk[tp]]){
if(tp>1){
if(dfn[stk[tp-1]]<dfn[lc]){
stk[tp+1]=stk[tp];
head[lc]=0;
stk[tp]=lc;
tp++;
}
ins(stk[tp-1],stk[tp],dis(stk[tp-1],stk[tp]));
}
tp--;
}
if(stk[tp]!=h[i])stk[++tp]=h[i];
}
while(tp>1)
ins(stk[tp-1],stk[tp],dis(stk[tp-1],stk[tp])),tp--;
dfs(1);
cout<<sum<<' '<<mn<<' '<<mx<<'\n';
}
return 0;
}

浙公网安备 33010602011771号