[虚树][lca][dfs] 洛谷 P3233 世界树
题解
-
对于每次询问建一棵虚树
-
然后在虚树上dfs一遍,得到每个点从属于哪个节点
-
这样的话我们只要统计不在虚树中的点
-
考虑虚树上的某一条边,如果两个点同属于一个节点,那么只要加上两点之间的未在虚树中的点数即可
-
假如两个点不属于同一节点,那么显然中间会存在分界点,倍增地找出这个分界点mid,然后两边分别计算贡献就可以了
-
还需要记一个g数组,表示未在上述统计到的节点数,因为有一些点没有在任何一次讨论中被考虑,那么显然将会与他们在虚树上的父亲节点属于同一节点,只要把初值设为size,每次在上面讨论一次,就把讨论的部分删掉即可
代码
1 #include <cstdio> 2 #include <iostream> 3 #include <cstring> 4 #include <algorithm> 5 using namespace std; 6 const int N=3e5+10; 7 struct edge{ int from,to;}e[N<<1]; 8 int n,tot,cnt,be,top,Q,m,head[N],f[N][20],size[N],dep[N],dfn[N],ans[N],g[N],c[N],h[N],bel[N],stack[N<<1],a[N]; 9 void insert(int x,int y) { e[++tot].to=y,e[tot].from=head[x],head[x]=tot; } 10 bool cmp(int x,int y) { return dfn[x]<dfn[y]; } 11 int lca(int x,int y) 12 { 13 if (!x&&!y) return 0; if (!x||!y) return x+y; 14 if (dep[x]<dep[y]) swap(x,y); 15 if (dep[x]!=dep[y]) for (int i=19;i>=0;i--) if (dep[f[x][i]]>=dep[y]) x=f[x][i]; 16 if (x==y) return x; 17 for (int i=19;i>=0;i--) if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i]; 18 return f[x][0]; 19 } 20 void add(int x) 21 { 22 if (top==1) { stack[++top]=x; return; } 23 int LCA=lca(x,stack[top]); 24 while (top>1&&dfn[stack[top-1]]>=dfn[LCA]) insert(stack[top-1],stack[top]),top--; 25 if (stack[top]!=LCA) insert(LCA,stack[top]),stack[top]=LCA; 26 stack[++top]=x; 27 } 28 int getdis(int x,int y) { return dep[x]+dep[y]-2*dep[lca(x,y)]; } 29 void dfs(int x,int fa) 30 { 31 f[x][0]=fa,size[x]=1,dep[x]=dep[fa]+1,dfn[x]=++cnt; 32 for (int i=1;i<20;i++) f[x][i]=f[f[x][i-1]][i-1]; 33 for (int i=head[x];i;i=e[i].from) if (e[i].to!=fa) dfs(e[i].to,x),size[x]+=size[e[i].to]; 34 } 35 void dfs1(int x) 36 { 37 g[x]=size[x],c[++cnt]=x; 38 for (int i=head[x];i;i=e[i].from) 39 { 40 dfs1(e[i].to); 41 if (!bel[e[i].to]) continue; 42 if (!bel[x]) { bel[x]=bel[e[i].to]; continue; } 43 int dis1=getdis(bel[x],x),dis2=getdis(bel[e[i].to],x); 44 if (dis1>dis2||(dis1==dis2&&bel[x]>bel[e[i].to])) bel[x]=bel[e[i].to]; 45 } 46 } 47 void dfs2(int x) 48 { 49 for (int i=head[x];i;i=e[i].from) 50 { 51 if (!bel[e[i].to]) bel[e[i].to]=bel[x]; 52 else 53 { 54 int dis1=getdis(bel[e[i].to],e[i].to),dis2=getdis(bel[x],e[i].to); 55 if (dis1>dis2||(dis1==dis2&&bel[x]<bel[e[i].to])) bel[e[i].to]=bel[x]; 56 } 57 dfs2(e[i].to); 58 } 59 } 60 void calc(int fa,int x) 61 { 62 int son=x,mid=x; 63 for (int i=19;i>=0;i--) if (dep[f[son][i]]>dep[fa]) son=f[son][i]; 64 g[fa]-=size[son]; 65 if (bel[fa]==bel[x]) { ans[bel[fa]]+=size[son]-size[x]; return ; } 66 for (int i=19;i>=0;i--) 67 { 68 int r=f[mid][i]; 69 if (dep[r]<=dep[fa]) continue; 70 int dis1=getdis(bel[x],r),dis2=getdis(bel[fa],r); 71 if (dis1<dis2||(dis1==dis2&&bel[x]<bel[fa])) mid=r; 72 } 73 ans[bel[fa]]+=size[son]-size[mid],ans[bel[x]]+=size[mid]-size[x]; 74 } 75 int main() 76 { 77 scanf("%d",&n); 78 for (int i=1,x,y;i<n;i++) scanf("%d%d",&x,&y),insert(x,y),insert(y,x); 79 dfs(1,0),scanf("%d",&Q),tot=0,memset(head,0,sizeof(head)); 80 while (Q--) 81 { 82 scanf("%d",&m); 83 for (int i=1;i<=m;i++) scanf("%d",&a[i]),h[i]=a[i]; 84 for (int i=1;i<=m;i++) bel[h[i]]=h[i]; 85 sort(h+1,h+m+1,cmp),stack[top=1]=1,cnt=0; 86 if (h[1]==1) be=2; else be=1; 87 for (int i=be;i<=m;i++) add(h[i]); 88 while (top>1) insert(stack[top-1],stack[top]),top--; 89 dfs1(1),dfs2(1); 90 for (int i=1;i<=cnt;i++) for (int j=head[c[i]];j;j=e[j].from) calc(c[i],e[j].to); 91 for (int i=1;i<=cnt;i++) ans[bel[c[i]]]+=g[c[i]]; 92 for (int i=1;i<=m;i++) printf("%d ",ans[a[i]]); 93 printf("\n"); 94 for (int i=1;i<=cnt;i++) ans[c[i]]=head[c[i]]=g[c[i]]=bel[c[i]]=0; 95 tot=0; 96 } 97 }