洛谷P1600

洛谷P1600

点击查看代码
#include<bits/stdc++.h>
#define N 500005
#define mid (l+r)/2
using namespace std;
int n,m,w[N],ans[N];
vector<int>g[N];
int fa[N][20],dep[N];
int root[N],tot;//root存储每个节点所对应的线段树根节点编号
int ls[55*N],rs[N*55],sum[N*55];
void dfs(int x,int f) 
{
    fa[x][0]=f;
    dep[x]=dep[f]+1;
    for(int i=1;i<20;++i) 
	{
        fa[x][i]=fa[fa[x][i-1]][i-1];
    }
    for(int y:g[x]) 
	{
        if (y!=f) 
		{
            dfs(y,x);
        }
    }
}
void change(int &u,int l,int r,int p,int k)
{
	// 如果当前节点为空,创建新节点
	if(!u) u=++tot;
	if(l==r) 
	{
		sum[u]+=k;
		return;
	}
	if(p<=mid) change(ls[u],l,mid,p,k);//左子树
	else change(rs[u],mid+1,r,p,k);//右子树
}
int lca(int x,int y) 
{
    if(dep[x]<dep[y]) swap(x,y);
    for(int i=19;i>=0;--i) 
	{
        if (dep[fa[x][i]]>=dep[y]) 
		{
            x=fa[x][i];
        }
    }
    if (x==y) return x;
    for (int i=19;i>=0;--i) 
	{
        if (fa[x][i]!=fa[y][i]) 
		{
            x=fa[x][i];
            y=fa[y][i];
        }
    }
    return fa[x][0];
}
//x和y是两棵树的索引,l是左端点,r是右端点
int merge(int x,int y,int l,int r)//将两棵树合并
{
	//merge 函数的返回值是合并后的线段树的根节点索引
	if(!x||!y) return x+y;//如果一棵树为空,返回另一棵
	if(l==r) 
	{
		sum[x]+=sum[y];
		return x;
	}
	ls[x]=merge(ls[x],ls[y],l,mid);//ls[x]:树x的左子树的根节点索引
	//递归合并树x和树y的左子树,下面同理
	rs[x]=merge(rs[x],rs[y],mid+1,r);
	return x;
}
//query函数的返回值就是合法路径的数目
int query(int u,int l,int r,int p)//查询点
{
	/*u:当前线段树节点的索引
	l 和 r:当前节点管理的区间范围
	p:要查询的目标位置*/
	if(l==r) return sum[u];
	if(p<=mid) return query(ls[u],l,mid,p);
	else return query(rs[u],mid+1,r,p);
}
void dfs2(int x)
{
	for(int y:g[x])
	{
		if(y==fa[x][0]) continue;
		dfs2(y);
		root[x]=merge(root[x],root[y],1,n<<1);//将子节点 y 的线段树合并到当前节点 x 的线段树中
	//当遍历完成后,root[x] 对应的线段树将包含以 x 为根的整个子树的所有路径标记,为后续查询 x 的 ans 值提供完整的数据。
	}
	if(w[x]&&n+dep[x]+w[x]<=n<<1)//n+dep[x]+w[x]的位置在范围之内
	ans[x]+=query(root[x],1,n<<1,n+dep[x]+w[x]);
	ans[x]+=query(root[x],1,n<<1,n+dep[x]-w[x]);
}
int main()
{
	cin>>n>>m;
	int x,y;
	for(int i=1;i<n;i++)
	{
		cin>>x>>y;
		g[x].push_back(y);
		g[y].push_back(x);//建造树
	}
	for(int i=1;i<=n;i++)
	{
		cin>>w[i];
	}
	dfs(1,0);
	for(int i=1;i<=m;i++)
	{
		cin>>x>>y;
		int l=lca(x,y);
		//统一加上偏移量,所以并不会造成什么影响
		change(root[x],1,n<<1,n+dep[x],1);//+1这个标记代表"有一条路径从这里开始"
		change(root[y],1,n<<1,n+2*dep[l]-dep[x],1);//这个标记代表"有一条路径在这里结束"
		change(root[l],1,n<<1,n+dep[x],-1);//-1是取消标记
		change(root[fa[l][0]],1,n<<1,n+2*dep[l]-dep[x],-1);
	}
	dfs2(1);
	for(int i=1;i<=n;i++)
	cout<<ans[i]<<" ";
	return 0;
}


posted @ 2025-08-27 21:24  Lucian2007  阅读(6)  评论(0)    收藏  举报