洛谷P1600
点击查看代码
#include<bits/stdc++.h>
#define N 500005
#define mid (l+r)/2
using namespace std;
int n,m,w[N],ans[N];
vector<int>g[N];
int fa[N][20],dep[N];
int root[N],tot;//root存储每个节点所对应的线段树根节点编号
int ls[55*N],rs[N*55],sum[N*55];
void dfs(int x,int f)
{
fa[x][0]=f;
dep[x]=dep[f]+1;
for(int i=1;i<20;++i)
{
fa[x][i]=fa[fa[x][i-1]][i-1];
}
for(int y:g[x])
{
if (y!=f)
{
dfs(y,x);
}
}
}
void change(int &u,int l,int r,int p,int k)
{
// 如果当前节点为空,创建新节点
if(!u) u=++tot;
if(l==r)
{
sum[u]+=k;
return;
}
if(p<=mid) change(ls[u],l,mid,p,k);//左子树
else change(rs[u],mid+1,r,p,k);//右子树
}
int lca(int x,int y)
{
if(dep[x]<dep[y]) swap(x,y);
for(int i=19;i>=0;--i)
{
if (dep[fa[x][i]]>=dep[y])
{
x=fa[x][i];
}
}
if (x==y) return x;
for (int i=19;i>=0;--i)
{
if (fa[x][i]!=fa[y][i])
{
x=fa[x][i];
y=fa[y][i];
}
}
return fa[x][0];
}
//x和y是两棵树的索引,l是左端点,r是右端点
int merge(int x,int y,int l,int r)//将两棵树合并
{
//merge 函数的返回值是合并后的线段树的根节点索引
if(!x||!y) return x+y;//如果一棵树为空,返回另一棵
if(l==r)
{
sum[x]+=sum[y];
return x;
}
ls[x]=merge(ls[x],ls[y],l,mid);//ls[x]:树x的左子树的根节点索引
//递归合并树x和树y的左子树,下面同理
rs[x]=merge(rs[x],rs[y],mid+1,r);
return x;
}
//query函数的返回值就是合法路径的数目
int query(int u,int l,int r,int p)//查询点
{
/*u:当前线段树节点的索引
l 和 r:当前节点管理的区间范围
p:要查询的目标位置*/
if(l==r) return sum[u];
if(p<=mid) return query(ls[u],l,mid,p);
else return query(rs[u],mid+1,r,p);
}
void dfs2(int x)
{
for(int y:g[x])
{
if(y==fa[x][0]) continue;
dfs2(y);
root[x]=merge(root[x],root[y],1,n<<1);//将子节点 y 的线段树合并到当前节点 x 的线段树中
//当遍历完成后,root[x] 对应的线段树将包含以 x 为根的整个子树的所有路径标记,为后续查询 x 的 ans 值提供完整的数据。
}
if(w[x]&&n+dep[x]+w[x]<=n<<1)//n+dep[x]+w[x]的位置在范围之内
ans[x]+=query(root[x],1,n<<1,n+dep[x]+w[x]);
ans[x]+=query(root[x],1,n<<1,n+dep[x]-w[x]);
}
int main()
{
cin>>n>>m;
int x,y;
for(int i=1;i<n;i++)
{
cin>>x>>y;
g[x].push_back(y);
g[y].push_back(x);//建造树
}
for(int i=1;i<=n;i++)
{
cin>>w[i];
}
dfs(1,0);
for(int i=1;i<=m;i++)
{
cin>>x>>y;
int l=lca(x,y);
//统一加上偏移量,所以并不会造成什么影响
change(root[x],1,n<<1,n+dep[x],1);//+1这个标记代表"有一条路径从这里开始"
change(root[y],1,n<<1,n+2*dep[l]-dep[x],1);//这个标记代表"有一条路径在这里结束"
change(root[l],1,n<<1,n+dep[x],-1);//-1是取消标记
change(root[fa[l][0]],1,n<<1,n+2*dep[l]-dep[x],-1);
}
dfs2(1);
for(int i=1;i<=n;i++)
cout<<ans[i]<<" ";
return 0;
}