【学习笔记】点分树

P6329

分析

把点分治所搜出的重心构建成一棵树,称之为点分树
点分树上x,y的lca,原树中lca一定在x->y的路径上
所以对于每个点u开两颗线段树,一颗表示子树中在原树上与u的距离为k的权值和,一颗表示子树中在原树上与fa[u]的距离为k的权值和
查询的时候向上跳,累加f{fa[u]}-g{u}即可
更新的时候向上跳,直接更新线段树

代码

#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10,logN=log2(N)+1;int n;
int nw[N];
vector<int> mp[N];
int st[logN][N];
int dep[N];
//求原树距离
void dfs(int u,int pa){
	st[0][u]=pa;
	dep[u]=dep[pa]+1;
	for(auto v:mp[u]){
		if(v==pa) continue;
		dfs(v,u);
	}
}
int LCA(int a,int b){
	if(dep[a]<dep[b]) swap(a,b);
	int x=dep[a]-dep[b];
	for(int i=0;x>>i;i++){
		if(x>>i&1) a=st[i][a];
	}
	if(a==b) return a;
	for(int i=logN-1;i>=0;i--){
		if(st[i][a]!=st[i][b]){
			a=st[i][a];
			b=st[i][b];
		}
	}
	return st[0][a];
}
int dis(int x,int y){
	return dep[x]+dep[y]-(dep[LCA(x,y)]<<1);
}
//求原树距离
int cnt,root;
int del[N];
int siz[N];
//构建点分树
void Size(int u,int pa){
	cnt++;
	siz[u]=1;
	for(auto v:mp[u]){
		if(v==pa||del[v]) continue;
		Size(v,u);
		siz[u]+=siz[v];
	}
}
int mson[N];
void Root(int u,int pa){
	mson[u]=cnt-siz[u];
	for(auto v:mp[u]){
		if(v==pa||del[v]) continue;
		mson[u]=max(mson[u],siz[v]);
		Root(v,u);
	}
	if(root==-1||mson[u]<mson[root]) root=u;
}
int fa[N];
void calc(int u,int pa){
	//cerr<<u<<"\n";
	cnt=0;Size(u,u);
	root=-1;Root(u,u);
	u=root;
	fa[u]=pa;
	del[u]=1;
	for(auto v:mp[u]){
		if(del[v]||v==pa) continue;
		calc(v,u);
	}
}
//构建点分树
//对于每个节点维护线段树
struct segtree{
	int rt[N],idx=0;
	struct node{
		int sum;
		int ls,rs;
		#define sum(q) tree[q].sum
		#define ls(q) tree[q].ls
		#define rs(q) tree[q].rs
	}tree[(int)5e6];
	void push_up(int q){
		sum(q)=sum(ls(q))+sum(rs(q));
	}
	void update(int &q,int l,int r,int tp,int d){
		//cerr<<q<<"\n";
		if(!q){
			q=++idx;
		}
		if(l==r){
			sum(q)+=d;
			return ;
		}
		int mid=(l+r)>>1;
		if(tp<=mid) update(ls(q),l,mid,tp,d);
		else update(rs(q),mid+1,r,tp,d);
		push_up(q);
	}
	int query(int q,int l,int r,int L,int R){
		if(!q) return 0;
		if(L<=l&&r<=R){
			return sum(q);
		}
		int mid=(l+r)>>1;
		int res=0;
		if(L<=mid) res+=query(ls(q),l,mid,L,R);
		if(mid<R) res+=query(rs(q),mid+1,r,L,R);
		return res;
	}
}w1,w2;
//对于每个节点维护线段树
//更新/查询
void update(int x,int v){
	int now=x;
	while(now){
		w1.update(w1.rt[now],0,n-1,dis(now,x),v);
		if(fa[now]) w2.update(w2.rt[now],0,n-1,dis(fa[now],x),v);
		now=fa[now];
	}
}
int query(int x,int k){
	int ans=0;
	int now=x,son=0;
	while(now){
		if(dis(now,x)>k){
			son=now;now=fa[now];continue;
		}
		ans+=w1.query(w1.rt[now],0,n-1,0,k-dis(x,now));
		if(son) ans-=w2.query(w2.rt[son],0,n-1,0,k-dis(x,now));
		son=now;now=fa[now];
	}
	return ans;
}
//更新/查询
int main(){
	ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);
	//freopen("P6329_2.in","r",stdin);
	//freopen("std.out","w",stdout);
	int m;cin>>n>>m;
	for(int i=1;i<=n;i++) cin>>nw[i];
	for(int i=1;i<n;i++){
		int a,b;cin>>a>>b;
		mp[a].push_back(b);
		mp[b].push_back(a);
	}
	dfs(1,0);
	for(int k=1;k<logN;k++){
		for(int u=1;u<=n;u++) st[k][u]=st[k-1][st[k-1][u]];
	}
	
	calc(1,0);
	
	for(int i=1;i<=n;i++) update(i,nw[i]);
	int ans=0;
	while(m--){
		int op,x,y;cin>>op>>x>>y;
		x^=ans;y^=ans;
		if(op==0){
			ans=query(x,y);
			cout<<ans<<"\n";
		}
		else{
			update(x,y-nw[x]);
			nw[x]=y;
		}
	}
	//cerr<<double(clock())/CLOCKS_PER_SEC<<"\n";
	return 0;
}
posted @ 2025-12-05 16:09  Ming3398  阅读(1)  评论(0)    收藏  举报