点分治
- 对于解决树上k距离、与路径和最值等路径问题的有力工具。虽然dsu on tree也可解决大部分点分治题目,但点分治思路相对固定,思维难度较低。缺点是在与其他数据结构结合时代码量极大,同时并不好调,需要注意细节并且写熟练。
- 点分治作为一种分治算法,与序列分治有异曲同工之妙。序列分治是取l与r之间的mid来判断与统计答案,而树上的“mid”被称为“重心”,是树中最大子树最小的节点。不过,与序列二分不同的是,重心是用来统计其统计范围内经过这个重心的树上路径的答案,而如何统计也成为了点分治中的重点。
【模板】点分治
- 以此题为例来介绍点分治中的几个重点板块。
- 题意:给定一棵有 n 个点的树,询问树上距离为 k 的点对是否存在
- 题解:树上路径问题,考虑点分治。将某重心子树中的路径分为两类:经过当前重心的,不经过的。经过的现在处理,不经过的继续分重心递归求解
- 求重心
- 事实上求中心就是遍历一遍子树,统计每个节点最大子树最小的节点
void getroot(int u,int fa,int sumsiz)
//sumsiz->当前统计区域的总节点数。因为除u节点的每个子树外,其父亲及以上的节点也算做统计重心时的“子树”
{
maxsiz[u]=0,siz[u]=1;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
getroot(v,u,sumsiz);
maxsiz[u]=max(siz[v],maxsiz[u]);
siz[u]+=siz[v];
}
maxsiz[u]=max(sumsiz-siz[u],maxsiz[u]);
if(maxsiz[u]<maxsiz[zx]) zx=u;
}
- 分治
- dfs整棵树,求重心后分而治之,统计经过重心的答案
void solve(int u)
{
vis[u]=1,tong[0]=1;cacl(u);
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(vis[v]) continue;
maxsiz[0]=n,zx=0;
getroot(v,0,siz[v]);solve(zx);
}
}
- 统计重心的答案
- 点分治考察的重点,其他的部分都比较固定。同时,需要极其注重其中的细节。
- 对于这道题,考虑开数组rev存重心当前处理的子树中所有节点与重心存在的距离。再存个桶,存已经处理了的子树中的所有距离,便可\(O(1)\)处理所有经过重心的可能距离。同时注意先处理后再将rev存入桶中,细节详见代码。
void cacl(int u)
{
int c=0;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(vis[v]) continue;
cnt=0,dis[v]=q[u][i].w;getdis(v,u); //求v及其子树所有可能的距离,存在rev中,详见后文
for(int j=1;j<=cnt;j++)
for(int k=1;k<=m;k++)
if(query[k]>=rev[j])
ans[k] |= tong[query[k]-rev[j]];
//统计答案,对于每个询问,若桶中存有与rev匹配(和为query)的值,则说明此询问有解。
//或等于不影响之前的答案。
for(int j=1;j<=cnt;j++)
{
qq[++c]=rev[j];
if(rev[j]<=10000000)
//注意越界,超过后由于询问都没这么大,一定不会有匹配的值了,直接舍弃
tong[rev[j]]=1;
//统计完答案后将rev存入桶中,若先存入桶中,会出现自己与自己匹配的情况。
}
}
for(int i=1;i<=c;i++)
tong[qq[i]]=0; //注意记得将桶清空
}
- 求重心子树中每个节点与中心的距离
- 将新加入的节点求距离后加入rev中递归继续求解
void getdis(int u,int fa)
{
rev[++cnt]=dis[u];
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+q[u][i].w;
getdis(v,u);
}
}
-
主函数套路非常固定,在此不多赘述
-
但事实上,这样求重心并不全对,但并不影响复杂度
-
完整代码
#include<bits/stdc++.h>
using namespace std;
const int N=10000005;
const int M=4e4+10;
int siz[N],tot[N],maxsiz[N],zx,n,cnt=0,rev[N],tong[N],m,ans[N],qq[N],query[N],dis[N];
bool vis[N];
struct node
{
int v,w;
};
vector <node> q[N];
void getroot(int u,int fa,int sumsiz)
{
maxsiz[u]=0,siz[u]=1;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
getroot(v,u,sumsiz);
maxsiz[u]=max(siz[v],maxsiz[u]);
siz[u]+=siz[v];
}
maxsiz[u]=max(sumsiz-siz[u],maxsiz[u]);
if(maxsiz[u]<maxsiz[zx]) zx=u;
}
void getdis(int u,int fa)
{
rev[++cnt]=dis[u];
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
dis[v]=dis[u]+q[u][i].w;
getdis(v,u);
}
}
void cacl(int u)
{
int c=0;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(vis[v]) continue;
cnt=0,dis[v]=q[u][i].w;getdis(v,u);
for(int j=1;j<=cnt;j++)
for(int k=1;k<=m;k++)
if(query[k]>=rev[j])
ans[k] |= tong[query[k]-rev[j]];
for(int j=1;j<=cnt;j++)
{
qq[++c]=rev[j];
if(rev[j]<=10000000)
tong[rev[j]]=1;
}
}
for(int i=1;i<=c;i++)
tong[qq[i]]=0;
}
void solve(int u)
{
vis[u]=1,tong[0]=1;cacl(u);
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(vis[v]) continue;
maxsiz[0]=n,zx=0;
getroot(v,0,siz[v]);solve(zx);
}
}
int main()
{
cin>>n>>m;
for(int i=1,u,v,w;i<n;i++)
{
cin>>u>>v>>w;
q[u].push_back({v,w}),tot[u]++;
q[v].push_back({u,w}),tot[v]++;
}
for(int i=1;i<=m;i++) cin>>query[i];
maxsiz[0]=n;
getroot(1,0,n);solve(zx);
for(int i=1;i<=m;i++)
if(ans[i]) cout<<"AYE"<<endl;
else cout<<"NAY"<<endl;
}
注意细节!注意细节!注意细节!重要的事情说三遍!!!
点分治例题
- 由于点分治代码很相似,在此只叙述主要思路以及统计答案的过程
Tree
- 题意:求树上两点间边权和小于等于k的路径数量
- 统计答案:对于每个重心,由于是求小于等于k的路径数,因此直接抛弃掉桶,将每种路径长存在rev中,将rev排序,有些取巧地用双指针l,r,表示rev[l+1]到rev[r]之间所有边(不包括l本身)都可以与rev[l]匹配满足小于等于k
int cacl(int u,int w)
{
cnt=0;dis[u]=w;getdis(u,0);
int l=1,r=cnt,res=0;
sort(rev+1,rev+cnt+1);
while(l<=r)
{
if(rev[l]+rev[r]<=k) res+=r-l,l++;
//rev[l]能与rev[l+1]到rev[r]之间的所有数匹配,方案数+r-l.
else r--;
}
return res;
}
总的代码
#include<bits/stdc++.h>
using namespace std;
const int N=1000005;
int n,tot[N],k,zx,maxsiz[N],siz[N],ans=0,rev[N],cnt=0,dis[N];
bool vis[N];
struct node
{
int v,w;
};
vector <node> q[N];
void getzx(int u,int fa,int sumsiz)
{
maxsiz[u]=0,siz[u]=1;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
getzx(v,u,sumsiz);
siz[u]+=siz[v];
maxsiz[u]=max(maxsiz[u],siz[v]);
}
// cout<<u<<' '<<siz[u]<<endl;
maxsiz[u]=max(maxsiz[u],sumsiz-siz[u]);
if(maxsiz[u]<maxsiz[zx]) zx=u;
}
void getdis(int u,int fa)
{
rev[++cnt]=dis[u];
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v,w=q[u][i].w;
if(vis[v]||v==fa) continue;
dis[v]=dis[u]+w;
getdis(v,u);
}
}
int cacl(int u,int w)
{
cnt=0;dis[u]=w;getdis(u,0);
int l=1,r=cnt,res=0;
sort(rev+1,rev+cnt+1);
while(l<=r)
{
if(rev[l]+rev[r]<=k) res+=r-l,l++;
else r--;
}
return res;
}
void solve(int u)
{
vis[u]=1;ans+=cacl(u,0);
// cout<<u<<endl;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v,w=q[u][i].w;
if(vis[v]) continue;
ans-=cacl(v,w);
maxsiz[0]=siz[v],zx=0;
getzx(v,u,siz[v]);solve(zx);
}
}
int main()
{
cin>>n;
for(int i=1,u,v,w;i<n;i++)
{
cin>>u>>v>>w;
q[u].push_back({v,w}),tot[u]++;
q[v].push_back({u,w}),tot[v]++;
}
cin>>k;
maxsiz[0]=n+1,zx=0;
getzx(1,0,n);
// cout<<zx<<endl;
solve(zx);
cout<<ans<<endl;
}
树的难题
- 题意:给一棵树,每个点有颜色,每种颜色有权值,连续相同颜色权值和为其颜色权值本身。求边数在L到R之间的路径的最大权值和。
- 统计答案:同样的分治求重心,然而不同颜色的最大权值和并不好统计,同时还要考虑重心的不同与相同颜色分界的问题。考虑对重心的所有v按颜色编号排序,使所有颜色一样的边在一起连续统计。而对已经处理过的边维护最大权值,考虑线段树。
- 建两棵线段树,一棵维护当前相同颜色但已经处理过的边的最大权值,一棵维护其他已经处理且颜色与当前边不同的最大权值。
- 同时建一个栈,将相同颜色集中处理完后,再将这些节点全部扔进维护颜色不同的线段树中并将维护相同颜色的线段树清空。
- 线段树
struct tree
{
#define ls ((u<<1)) //左儿子
#define rs ((u<<1)+1)//右儿子
struct edge
{
int maxnum,tag; //最大权值、是否清空的标记
}tr[N<<3];
void clear() //清空
{
tr[1].tag=1;tr[1].maxnum=-inf;
}
void push_down(int u)//下放清空标记
{
if(tr[u].tag)
{
tr[ls].tag=tr[rs].tag=1;
tr[ls].maxnum=tr[rs].maxnum=-inf;
tr[u].tag=0;
}
}
void push_up(int u)
{
tr[u].maxnum=max(tr[ls].maxnum,tr[rs].maxnum);
}
int query(int u,int l,int r,int nl,int nr)//区间询问最大值
{
if(l>nr||r<nl) return -inf;
if(l>=nl&&r<=nr) return tr[u].maxnum;
if(tr[u].tag) return -inf;
push_down(u);int mid=(l+r)>>1;
return max(query(ls,l,mid,nl,nr),query(rs,mid+1,r,nl,nr));
}
void change(int u,int l,int r,int p,int x)//单点修改
{
if(l>p||r<p) return;
if(l==r)
{
tr[u].tag=0;
tr[u].maxnum=max(tr[u].maxnum,x);
return;
}
push_down(u);
int mid=(l+r)>>1;
if(p<=mid) change(ls,l,mid,p,x);
else change(rs,mid+1,r,p,x);
push_up(u);
}
}diff,same;
- 统计答案
void getdis(int u,int fa,int val,int lastcol)
//这里的getdis是用来统计答案的,同时处理dep满足L与R的限制
{
if(u==fa) return;
dep[u]=dep[fa]+1;
// cout<<"getdis "<<u<<" fa is "<<fa<<" dep is "<<dep[u]<<" val is "<<val<<endl;
if(dep[u]>R) return;
if(dep[u]>=L&&dep[u]<=R) ans=max(ans,val);
// cout<<"change ans "<<ans;
ans=max({ans,val+same.query(1,0,n,max(0,L-dep[u]),R-dep[u])-c[nowcolor],val+diff.query(1,0,n,max(L-dep[u],0),R-dep[u])});
// cout<<" to "<<ans<<endl;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
if(q[u][i].color==lastcol) getdis(v,u,val,lastcol);
else getdis(v,u,val+c[q[u][i].color],q[u][i].color);
}
}
void addsame(int u,int fa,int val,int lastcolor) //same树中插入权值
{
dep[u]=dep[fa]+1;
if(dep[u]>R) return;
same.change(1,0,n,dep[u],val);
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
if(q[u][i].color==lastcolor) addsame(v,u,val,lastcolor);//区分同色与异色
else addsame(v,u,val+c[q[u][i].color],q[u][i].color);
}
}
void adddiff(int u,int fa,int val,int lastcolor)//diff树中插入权值
{
dep[u]=dep[fa]+1;
if(dep[u]>R) return;
diff.change(1,0,n,dep[u],val);
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
if(q[u][i].color==lastcolor) adddiff(v,u,val,lastcolor);//同same
else adddiff(v,u,val+c[q[u][i].color],q[u][i].color);
}
}
//这里cacl 用来处理重心所有边
void cacl(int u)//记得在主函数中排序
{
// cout<<"cacl "<<u<<endl;
diff.clear(),same.clear();//注意清空
dep[u]=0; //一定别忘了初始化
int top=0;
for(int i=0;i<tot[u];i++)
{
dep[u]=0;
int v=q[u][i].v;
if(vis[v]||v==u) continue;
// cout<<v<<' '<<q[u][i].color<<endl;
if(i==0||q[u][i].color==q[u][i-1].color)
{
sta[++top]=v; //同色节点入栈
// cout<<top<<endl;
continue;
}
nowcolor=q[u][i-1].color;//记录颜色
for(int j=1;j<=top;j++)
{
//注意,与模板题相同,先统计答案再插入,避免自己匹配自己
getdis(sta[j],u,c[q[u][i-1].color],q[u][i-1].color);
addsame(sta[j],u,c[q[u][i-1].color],q[u][i-1].color);
}
same.clear();//记得清空
for(int j=1;j<=top;j++)
adddiff(sta[j],u,c[q[u][i-1].color],q[u][i-1].color);//扔进diff树中
top=0;
sta[++top]=q[u][i].v;
}
same.clear();
nowcolor=q[u][tot[u]-1].color;//注意,最后一种颜色统计不到,要单独处理
for(int j=1;j<=top;j++)
{
getdis(sta[j],u,c[q[u][tot[u]-1].color],q[u][tot[u]-1].color);
addsame(sta[j],u,c[q[u][tot[u]-1].color],q[u][tot[u]-1].color);
}
top=0;
diff.clear(),same.clear();//保险
}
- 看看用来调试的注释数量,就知道点分治不是什么好写好调的东西,细节真的多,注释全是坑啊
总的代码
#include<bits/stdc++.h>
using namespace std;
const int N=8e6+10;
const int inf=2e9+10;
typedef long long ll;
int c[N],n,m,L,R,tot[N],ans=-inf,zx=0,maxp[N],siz[N],dep[N],sta[N],nowcolor;
bool vis[N];
struct node
{
int v,color;
};
vector <node> q[N];
struct tree
{
#define ls ((u<<1))
#define rs ((u<<1)+1)
struct edge
{
int maxnum,tag;
}tr[N<<3];
void clear()
{
tr[1].tag=1;tr[1].maxnum=-inf;
}
void push_down(int u)
{
if(tr[u].tag)
{
tr[ls].tag=tr[rs].tag=1;
tr[ls].maxnum=tr[rs].maxnum=-inf;
tr[u].tag=0;
}
}
void push_up(int u)
{
tr[u].maxnum=max(tr[ls].maxnum,tr[rs].maxnum);
}
int query(int u,int l,int r,int nl,int nr)
{
if(l>nr||r<nl) return -inf;
if(l>=nl&&r<=nr) return tr[u].maxnum;
if(tr[u].tag) return -inf;
push_down(u);int mid=(l+r)>>1;
return max(query(ls,l,mid,nl,nr),query(rs,mid+1,r,nl,nr));
}
void change(int u,int l,int r,int p,int x)
{
if(l>p||r<p) return;
if(l==r)
{
tr[u].tag=0;
tr[u].maxnum=max(tr[u].maxnum,x);
return;
}
push_down(u);
int mid=(l+r)>>1;
if(p<=mid) change(ls,l,mid,p,x);
else change(rs,mid+1,r,p,x);
push_up(u);
}
}diff,same;
bool cmp(node x,node y){return x.color<y.color;}
void getroot(int u,int fa,int sizsum)
{
// cout<<"getroot "<<u<<' '<<fa<<endl;
siz[u]=1,maxp[u]=0;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
getroot(v,u,sizsum);
siz[u]+=siz[v];
maxp[u]=max(maxp[u],siz[v]);
}
maxp[u]=max(maxp[u],sizsum-siz[u]);
// cout<<"maxp u "<<u<<' '<<maxp[u]<<endl;
// cout<<"maxp zx "<<zx<<' '<<maxp[zx]<<endl;
// cout<<"siz "<<u<<' '<<siz[u]<<" maxp "<<maxp[u]<<endl;
if(maxp[zx]>maxp[u])
{
zx=u;
// cout<<"changezx "<<u<<endl;
}
}
void getdis(int u,int fa,int val,int lastcol)
{
if(u==fa) return;
dep[u]=dep[fa]+1;
// cout<<"getdis "<<u<<" fa is "<<fa<<" dep is "<<dep[u]<<" val is "<<val<<endl;
if(dep[u]>R) return;
if(dep[u]>=L&&dep[u]<=R) ans=max(ans,val);
// cout<<"change ans "<<ans;
ans=max({ans,val+same.query(1,0,n,max(0,L-dep[u]),R-dep[u])-c[nowcolor],val+diff.query(1,0,n,max(L-dep[u],0),R-dep[u])});
// cout<<" to "<<ans<<endl;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
if(q[u][i].color==lastcol) getdis(v,u,val,lastcol);
else getdis(v,u,val+c[q[u][i].color],q[u][i].color);
}
}
void addsame(int u,int fa,int val,int lastcolor)
{
dep[u]=dep[fa]+1;
if(dep[u]>R) return;
same.change(1,0,n,dep[u],val);
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
if(q[u][i].color==lastcolor) addsame(v,u,val,lastcolor);
else addsame(v,u,val+c[q[u][i].color],q[u][i].color);
}
}
void adddiff(int u,int fa,int val,int lastcolor)
{
dep[u]=dep[fa]+1;
if(dep[u]>R) return;
diff.change(1,0,n,dep[u],val);
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(v==fa||vis[v]) continue;
if(q[u][i].color==lastcolor) adddiff(v,u,val,lastcolor);
else adddiff(v,u,val+c[q[u][i].color],q[u][i].color);
}
}
void cacl(int u)
{
// cout<<"cacl "<<u<<endl;
diff.clear(),same.clear();
dep[u]=0;
int top=0;
for(int i=0;i<tot[u];i++)
{
dep[u]=0;
int v=q[u][i].v;
if(vis[v]||v==u) continue;
// cout<<v<<' '<<q[u][i].color<<endl;
if(i==0||q[u][i].color==q[u][i-1].color)
{
sta[++top]=v;
// cout<<top<<endl;
continue;
}
nowcolor=q[u][i-1].color;
for(int j=1;j<=top;j++)
{
getdis(sta[j],u,c[q[u][i-1].color],q[u][i-1].color);
addsame(sta[j],u,c[q[u][i-1].color],q[u][i-1].color);
}
same.clear();
for(int j=1;j<=top;j++)
adddiff(sta[j],u,c[q[u][i-1].color],q[u][i-1].color);
top=0;
sta[++top]=q[u][i].v;
}
same.clear();
nowcolor=q[u][tot[u]-1].color;
for(int j=1;j<=top;j++)
{
getdis(sta[j],u,c[q[u][tot[u]-1].color],q[u][tot[u]-1].color);
addsame(sta[j],u,c[q[u][tot[u]-1].color],q[u][tot[u]-1].color);
}
top=0;
diff.clear(),same.clear();
}
void solve(int u)
{
// cout<<"solve "<<u<<endl;
vis[u]=true;cacl(u);
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(vis[v]) continue;
zx=0,maxp[zx]=n;
//cout<<u<<' '<<v<<endl;
getroot(v,u,siz[v]);
// cout<<"final zx "<<zx<<endl;
// if(zx==0) continue;
solve(zx);
}
}
int main()
{
diff.tr[0].maxnum=-inf;same.tr[0].maxnum=-inf;
cin>>n>>m>>L>>R;
for(int i=1;i<=m;i++) scanf("%d",&c[i]);
for(int i=1,u,v,color;i<=n-1;i++)
{
scanf("%d%d%d",&u,&v,&color);
q[u].push_back({v,color}),tot[u]++;
q[v].push_back({u,color}),tot[v]++;
}
for(int i=1;i<=n;i++) sort(q[i].begin(),q[i].end(),cmp);
zx=0,maxp[0]=n;getroot(1,0,n);
// cout<<"final zx "<<zx<<endl;
solve(zx);
cout<<ans<<endl;
return 0;
}
/*
8 4 3 4
-7 9 6 1
1 2 1
1 3 2
1 4 1
2 5 1
5 6 2
3 7 1
3 8 3
*/
Tree MST
- 题意:给定一棵 \(n\) 个节点的树,现有有一张完全图,两点 \(x,y\) 之间的边长为 \(w_{x}+w_{y}+dis_{x,y}\),其中 \(dis\) 表示给定的树上两点间的距离。求这个完全图的最小生成树。
\(n\le2e5\) - 这题乍一看与点分治没有任何关系,由于完全图上点关系相当复杂同时边数过多,存与算都不可能,因此考虑怎么拆分一下问题。
这里有一个结论
对于完全图 \((V,E)\),将 \(E\) 分成 \(E_1,E_2,\cdots, E_k(E_1 \cup E_2 \cup \cdots \cup E_k=E)\)对每个边集求最小生成树(MST),对于新的图再求MST,两步所组成的点集的并等同于直接对原图求最小生成树。 - 由于 \((u,v)=w_{u}+w_{v}+dis_{u,v}=(w_u+dis_{u,lca_{u,v}})+(w_v+dis_{v,lca_{u,v}})\),由此想到点分治,对于每个 \(lca\) 即分治中心,求出每个点的\(w_u+dis_{u}\) (这里的 \(dis_u\) 是相对于分治中心而言的),由于新图是完全图,所以对于这部分边集,最小生成树一定是其中最短的边与其他所有边连起来后所组成的边集。因此将得到的所有边排序后将最小的边作为一条,将其与其他所有点组合后放入一个vector中。但其实有些问题,就是可能有些重复的边。我们原来处理点分治的时候都是先统计一个子树中的再与其它的子树中的答案比较或合并,这样就可以避免重复。但这里显然相当暴力
void cacl(int u)
{
t.clear();dis[u]=0;cnt=0,getdis(u,0);
sort(t.begin(),t.end(),cmp);
int f=t[0];
for(int i=1;i<cnt;i++) ku.add(f,t[i],dis[f]+dis[t[i]]+a[f]+a[t[i]]);
}
- 这样就会有很多重复的边被加到vetcor中。至于为什么可以这样干,待会合并的时候再说。
- 合并时,由于我们已经将边集由若干个分治中心分开后求出了这些边集单独的最小生成树,其实就是相对较优的一堆边。按照上面的结论,我们直接合并就可以得到原本的最小生成树了。直接考虑贪心,由于原本的最小生成树的算法本身也就是贪心,因此我们直接对算出来的边集进行排序,再通过一个并查集来维护是否两个点己经联通就可以了。
struct eee
{
int len=0,fa[N];
vector <edge> e;
void add(int u,int v,int w)
{
e.push_back({u,v,w}),len++;
}
int get(int x)
{
if(x==fa[x]) return x;
return fa[x]=get(fa[x]);
}
int solve()
{
for(int i=1;i<=n;i++) fa[i]=i;
int ans=0;
sort(e.begin(),e.end(),cmp1);
for(int i=0;i<len;i++)
{
if(get(e[i].u)==get(e[i].v)) continue;
ans+=e[i].w;fa[get(e[i].u)]=get(e[i].v);
}
return ans;
}
}ku;
- 时间复杂度实际上相当神奇。点分治时,时间是传统的 \(O(nlog_{2}n)\),因此,一共最多也只有 \(nlog_2n\)条边加入vector中。在合并的时候,由于每个点都只会加入一次并查集,因此复杂度是均摊 \(O(1)\) 的
大概,因此时间的上界就是排序。对于 \(nlong_2n\) 条边的排序时间复杂度应该是 \(O(nlog_2^2n)\) 的,因此总复杂度也就是\(O(nlog_2^2n)\) ,时间有5s相当充裕。其实应该是跑不满的,因为最慢的点都只有1s
总的代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=5e5+10;
struct node
{
int v,w;
};
vector <node> q[N];
int n,a[N],tot[N],zx,siz[N],maxsiz[N],dis[N],cnt=0;
bool vis[N];
vector <int> t;
struct edge
{
int u,v,w;
};
bool cmp1(edge x,edge y)
{
return x.w<y.w;
}
struct eee
{
int len=0,fa[N];
vector <edge> e;
void add(int u,int v,int w)
{
e.push_back({u,v,w}),len++;
}
int get(int x)
{
if(x==fa[x]) return x;
return fa[x]=get(fa[x]);
}
int solve()
{
for(int i=1;i<=n;i++) fa[i]=i;
int ans=0;
sort(e.begin(),e.end(),cmp1);
for(int i=0;i<len;i++)
{
if(get(e[i].u)==get(e[i].v)) continue;
ans+=e[i].w;fa[get(e[i].u)]=get(e[i].v);
}
return ans;
}
}ku;
void getzx(int u,int fa,int sumsiz)
{
siz[u]=1;maxsiz[u]=0;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(vis[v]||v==fa) continue;
getzx(v,u,sumsiz);siz[u]+=siz[v];
if(siz[v]>maxsiz[u]) maxsiz[u]=siz[v];
}
if(sumsiz-siz[u]>maxsiz[u]) maxsiz[u]=sumsiz-siz[u];
if(maxsiz[u]<maxsiz[zx]) zx=u;
}
void getdis(int u,int fa)
{
t.push_back(u);cnt++;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v,w=q[u][i].w;
if(vis[v]||v==fa) continue;
dis[v]=dis[u]+w;
getdis(v,u);
}
}
bool cmp(int x,int y)
{
return dis[x]+a[x]<dis[y]+a[y];
}
void cacl(int u)
{
t.clear();dis[u]=0;cnt=0,getdis(u,0);
sort(t.begin(),t.end(),cmp);
int f=t[0];
for(int i=1;i<cnt;i++) ku.add(f,t[i],dis[f]+dis[t[i]]+a[f]+a[t[i]]);
}
void solve(int u)
{
cacl(u);vis[u]=1;
for(int i=0;i<tot[u];i++)
{
int v=q[u][i].v;
if(vis[v]) continue;zx=0;
getzx(v,0,siz[v]);
solve(zx);
}
}
signed main()
{
ios::sync_with_stdio(false);cin.tie(0),cout.tie(0);
cin>>n;
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1,u,v,w;i<=n-1;i++)
{
cin>>u>>v>>w;
q[u].push_back({v,w}),tot[u]++;
q[v].push_back({u,w}),tot[v]++;
}
maxsiz[0]=n;getzx(1,0,n);
solve(zx);
cout<<ku.solve()<<endl;
return 0;
}

浙公网安备 33010602011771号