点分治
点分治
树
给定一个有N个点(编号0,1,…,N-1)的树,每条边都有一个权值(不超过1000)。
树上两个节点x与y之间的路径长度就是路径上各条边的权值之和。
求长度不超过K的路径有多少条。
输入格式
输入包含多组测试用例。
每组测试用例的第一行包含两个整数N和K。
接下来N-1行,每行包含三个整数u,v,l,表示节点u与v之间存在一条边,且边的权值为l。
当输入用例N=0,K=0时,表示输入终止,且该用例无需处理。
输出格式
每个测试用例输出一个结果。
每个结果占一行。
数据范围
N≤10000
输入样例:
5 4
0 1 3
0 2 1
0 3 2
2 4 1
0 0
输出样例:
8
题面要求我们统计长度小于等于K的路径数
假如我们以u为根
那么满足条件的路径就可以分为两种:
1.经过u
2.不经过u
对于经过u的部分,我们直接统计
而不经过u的部分,显然就是在删去u之后,形成的各个连通块的相同子问题
于是分治的思路就出来了。
在对数组进行操作时,我们知道每次二分进行操作的效率是很高的,对于本题也可以采用类似的思想。
但由于本题是树,每次我们要保证效率尽量地高,那么我们进行分治的树删除根节点后,分出的子树大小要尽量一致,以保证我们操作的次数尽量地小,从这个思路出发,我们可以很容易想到——对正在操作的树找其重心作为树根。
可以先随意取一点为根,通过\(DFS\)找到每个子树的大小,那么当以该节点为根时,最大子树就应该是它子树中最大的一颗或者是总节点数减去该子树的大小(即该节点朝上走)。
可结合下图理解:
代码为
void find(int u,int p,int tot)
{
Size[u]=1;
mx[u]=0;
for(int i=0;i<Q[u].size();i++)
{
int v=Q[u][i];
if(vis[v]||v==p) continue;
find(v,u,tot);
Size[u]+=Size[v];
mx[u]=max(mx[u],Size[v]);
}
mx[u]=max(mx[u],tot-Size[u]);
if(mx[u]<mx[root]) root=u;
return;
}
找到重心后,我们需要计算一遍以重心为根,这棵树每个节点真正的大小。
可以通过\(DFS\)实现。
void calc(int u,int p)
{
Size[u]=1;
for(int i=0;i<Q[u].size();i++)
{
int v=Q[u][i];
if(vis[v]||v==p) continue;
calc(v,u);
Size[u]+=Size[v];
}
}
考虑如何设计\(Getans(u)\)
我们求出u的第i棵子树中所有链的长度
那么有两种组合方式可能对答案产生贡献:
1.这条链本身就比K短
2.和前i-1棵子树中某条链拼接(当然长度不能大于K)
第一种直接暴力就好
第二种开个树状数组枚举大力统计就行了
但是由于可能存在零边,我们需要加一个偏移量,把零边变成非零边。
void getdis(int u,int p)
{
//将距离加入数组
ls[++tt]=dis[u];
for(int i=0;i<Q[u].size();i++)
{
int v=Q[u][i],w=W[u][i];
if(vis[v]||v==p) continue;
dis[v]=dis[u]+w;
getdis(v,u);
}
}
void getans(int u)
{
dis[u]=0;
int q=0;
//寻找每个子树
for(int i=0;i<Q[u].size();i++)
{
int v=Q[u][i],w=W[u][i];
if(vis[v]) continue;
tt=0;
dis[v]=dis[u]+w;
//统计子树内的每个节点到u的距离
getdis(v,u);
for(int j=1;j<=tt;++j)
{
//将零边变为非零边
++ls[j];
//两个边都被加了1,所以K也应该+2
int nex=k+2-ls[j];
//如果该条边本就符合条件,则ans+1
ans+=know(nex)+(ls[j]<=k+1);
}
for(int j=1;j<=tt;++j)
{
//将本子树的所有距离加入树状数组
ql[++q]=ls[j];
add(ls[j],1);
}
}
//还原树状数组
for(int i=1;i<=q;++i) add(ql[i],-1);
}
完整代码:
#include<bits/stdc++.h>
using namespace std;
#define INF 0x3fffffff
#define maxn 10005
int tot,ans,n,k,root,tt,Size[maxn],mx[maxn],dis[maxn];
int ql[maxn],ls[maxn],c[maxn*1000];
// while.clear()root.clear()
bool vis[maxn];
vector<int>Q[maxn],W[maxn];
int know(int x)
{
if(x<=0) return 0;
int res=0;
for(;x;x-=x&-x) res+=c[x];
return res;
}
void add(int x,int y)
{
for(;x<=n*1000;x+=x&-x) c[x]+=y;
}
void find(int u,int p,int tot)
{
Size[u]=1;
mx[u]=0;
for(int i=0;i<Q[u].size();i++)
{
int v=Q[u][i];
if(vis[v]||v==p) continue;
find(v,u,tot);
Size[u]+=Size[v];
mx[u]=max(mx[u],Size[v]);
}
mx[u]=max(mx[u],tot-Size[u]);
if(mx[u]<mx[root]) root=u;
return;
}
void calc(int u,int p)
{
Size[u]=1;
for(int i=0;i<Q[u].size();i++)
{
int v=Q[u][i];
if(vis[v]||v==p) continue;
calc(v,u);
Size[u]+=Size[v];
}
}
void getdis(int u,int p)
{
ls[++tt]=dis[u];
for(int i=0;i<Q[u].size();i++)
{
int v=Q[u][i],w=W[u][i];
if(vis[v]||v==p) continue;
dis[v]=dis[u]+w;
getdis(v,u);
}
}
void getans(int u)
{
dis[u]=0;
int q=0;
for(int i=0;i<Q[u].size();i++)
{
int v=Q[u][i],w=W[u][i];
if(vis[v]) continue;
tt=0;
dis[v]=dis[u]+w;
getdis(v,u);
for(int j=1;j<=tt;++j)
{
++ls[j];
int nex=k+2-ls[j];
ans+=know(nex)+(ls[j]<=k+1);
}
for(int j=1;j<=tt;++j)
{
ql[++q]=ls[j];
add(ls[j],1);
}
}
for(int i=1;i<=q;++i) add(ql[i],-1);
}
void devide(int u)
{
vis[u]=1;
getans(u);
for(int i=0;i<Q[u].size();i++)
{
int v=Q[u][i];
if(vis[v]) continue;
mx[root=v]=INF;
find(v,u,Size[v]);
calc(v,u);
devide(v);
}
}
int main()
{
while(~scanf("%d %d",&n,&k))
{
if(!n&&!k) return 0;
ans=0;
tot=0;
for(int i=0;i<n;++i) {vis[i]=0;Q[i].clear();W[i].clear();}
for(int i=1,a,b,c;i<n;++i)
{
scanf("%d %d %d",&a,&b,&c);
Q[b].push_back(a);
Q[a].push_back(b);
W[a].push_back(c);
W[b].push_back(c);
}
mx[root=0]=INF;
find(root,0,n);
calc(root,0);
devide(root);
printf("%d\n",ans);
}
}

浙公网安备 33010602011771号