P5666 [CSP-S2019] 树的重心 题解
P5666 [CSP-S2019] 树的重心 题解
题目描述
定义:
在树中删去一个结点及与它关联的边,树将分裂为若干个子树;而在树中删去一条边(保留关联结点,下同),树将分裂为恰好两个子树。
树的重心:在树中删去树的重心后,剩下的几棵子树大小都不超过树原大小的一半 ps:一棵树能有不超过两个重心
树的大小:一棵树的节点个数(
by baidu)
求:
把整棵树单独删去每一条边后,求出剩下两个子树的重心的和的和
解题思路
刷题经验:题目读懂后先看数据范围:
测试点编号 | \(n =\) | 特殊性质 |
---|---|---|
\(1 \sim 2\) | \(7\) | 无 |
\(3 \sim 5\) | \(199\) | 无 |
\(6 \sim 8\) | \(1999\) | 无 |
\(9 \sim 11\) | \(49991\) | A |
\(12 \sim 15\) | \(262143\) | B |
\(16\) | \(99995\) | 无 |
\(17 \sim 18\) | \(199995\) | 无 |
\(19 \sim 20\) | \(299995\) | 无 |
表中特殊性质一栏,两个变量的含义为存在一个 \(1 \sim n\) 的排列 \(p_i (1 \leq i \leq n)\),使得:
- A:树的形态是一条链。
- B:树的形态是一个完美二叉树。
对于所有测试点:\(1 \leq T \leq 5 , 1 \leq u_i,v_i \leq n\)。保证给出的图是一个树。
1,考虑暴力 \(O (n^2)\)
方法很简单,只需要对于每一条边切断后 $ O(n) $ 枚举重心即可
这样就能拿到 40pts
2,考虑正解
由于 \(n \leq 3*10^5\) 所以
可以使用nlogn或线性的复杂度来完成
又因为无法避免枚举每一条边
因此,对于求出每一条边切断后的重心,要用\(log_n\)或\(1\)的时间完成
突然想到,对于每一颗子树的重心,(突破口)
要么在根节点上
要么在最重链上
证:
要是这个重心不再最重子树上,
而在轻子树上的话
那么重子树一定会超过轻子树,即超过了总数的一半
对于每一条被切断的边的下面,
直接从重链开始向上倍增即可
那么对于上面呢?
第一种思路(错误)
从根节点开始递归寻找能够变成最重的
此部分无法预处理
当数据是一条链的时候就会出事退化为\(O(n)\)
怎么办呢
第二种思路 换根 (突破口)
倍增可以不一定往上做+预处理,
而是可以顺着重链从根节点往下做
每次递归前用\(O(log_n)\)的时间复杂度换根,
递归完成后再用\(O(log_n)\)的时间回溯(复原)
因此思路就很明确了
- 对于每次dfs到某个节点时,把它往上的那条边断掉,更新答案
void update(int x,int s) {
bool p=(x==5&&s==4);
for(int i=18;i>=0;i--) {
if(s-sum[st[st[x][i]][0]]<=s/2) x=st[x][i];//每次保证以上都满足
}//极易写错和理解错
if(sum[st[x][0]]<=s/2&&s-sum[x]<=s/2) ans+=x;
x=st[x][0];
if(sum[st[x][0]]<=s/2&&s-sum[x]<=s/2) ans+=x;//万一有两个重心,它往后的那个也可以
}
void dfs2(int f,int x) {
if(f) {
update(x,sum[x]);
update(f,sum[f]);
}
}
-
每次继续dfs前用\(log_n\)更新st
-
继续递归
-
回来之后重新弄st
-
(重点) 特判重子数的情况:
if(mx==maxx[x]) {
mx=0;
for(int v,i=head[x];i;i=e[i].next) {
v=e[i].to;
if(v==maxx[x]) continue;
if(v==f) continue;
if(sum[v]>sum[mx]) mx=v;
}//重新找
if(sum[f]>sum[mx]) {
mx=f;
}
sum[x]=n-sum[maxx[x]];
st[x][0]=mx;
for(int i=1;i<=18;i++) {
st[x][i]=st[st[x][i-1]][i-1];//易错
}
dfs2(x,maxx[x]);
}
那么,只要这样,我们就能统计答案了
具体实现
#include<iostream>
#include<cstdio>
#include<cstring>
#define int long long
using namespace std;
const int N=300001;
struct E {
int to,next;
}e[2*N];
int head[N],num;
void add(int u,int v) {
e[++num].to=v;
e[num].next=head[u];
head[u]=num;
}
int n,m;
int sum[N];
int st[N][19] ;//每个点到最重链的倍增
int maxx[N];
int ans=0;
void dfs(int f,int x) {
sum[x]=1;
for(int v,i=head[x];i;i=e[i].next) {
v=e[i].to;
if(v==f) continue;
dfs(x,v);
sum[x]+=sum[v];
if(sum[v]>sum[maxx[x]]) {
maxx[x]=v;
}
}
st[x][0]=maxx[x];
for(int i=1;i<=18;i++) {
st[x][i]=st[st[x][i-1]][i-1];
}
}
void update(int x,int s) {
bool p=(x==5&&s==4);
for(int i=18;i>=0;i--) {
if(s-sum[st[st[x][i]][0]]<=s/2) x=st[x][i];
}
if(sum[st[x][0]]<=s/2&&s-sum[x]<=s/2) ans+=x;
x=st[x][0];
if(sum[st[x][0]]<=s/2&&s-sum[x]<=s/2) ans+=x;
}
void dfs2(int f,int x) {
if(f) {
update(x,sum[x]);
update(f,sum[f]);
}
int mx=maxx[x];
if(n-sum[x]>sum[mx]) {
mx=f;
st[x][0]=mx;
for(int i=1;i<=18;i++) {
st[x][i]=st[st[x][i-1]][i-1];
}
}
for(int v,i=head[x];i;i=e[i].next) {
v=e[i].to;
if(v==f) continue;
if(v==mx) continue;
sum[x]=n-sum[v];
dfs2(x,v);
}
if(mx==maxx[x]) {
mx=0;
for(int v,i=head[x];i;i=e[i].next) {
v=e[i].to;
if(v==maxx[x]) continue;
if(v==f) continue;
if(sum[v]>sum[mx]) mx=v;
}
if(sum[f]>sum[mx]) {
mx=f;
}
sum[x]=n-sum[maxx[x]];
st[x][0]=mx;
for(int i=1;i<=18;i++) {
st[x][i]=st[st[x][i-1]][i-1];
}
dfs2(x,maxx[x]);
}
st[x][0]=maxx[x];
for(int i=1;i<=18;i++) {
st[x][i]=st[st[x][i-1]][i-1];
}
sum[x]=1;
for(int v,i=head[x];i;i=e[i].next) {
v=e[i].to;
if(v==f) continue;
sum[x]+=sum[v];
}
return;
}
void wee() {
ans=0;
memset(head,0,sizeof(head));
num=0;
memset(st,0,sizeof(st));
memset(sum,0,sizeof(sum));
memset(maxx,0,sizeof(maxx));
}
void work() {
wee();
cin>>n;
for(int i=1;i<n;i++) {
int a,b;cin>>a>>b;
add(a,b);add(b,a);
}
dfs(0,1);
dfs2(0,1);
cout<<ans<<'\n';
return;
}
signed main() {
int T;cin>>T;
while(T--) {
work();
}
return 0;
}