省选集训 9 - 树上技巧

[NOI2021] 轻重边

路径上信息考虑树剖,每次 \(op=1\) 将路径上点染成新颜色,然后 \(op=2\) 查询路径同色相邻点对。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define N 100005
vector<int> v[N];
int t,n,m,dfn,son[N],pre[N],dep[N],sz[N],top[N],ls[N],rs[N];
void dfs1(int x,int fa){
	sz[x]=1,pre[x]=fa,dep[x]=dep[fa]+1,son[x]=0;
	for(auto y:v[x]){
		if(y==fa)  continue;
		dfs1(y,x),sz[x]+=sz[y];
		if(sz[y]>sz[son[x]]) son[x]=y;
	}
}
void dfs2(int x,int fa){
	ls[x]=++dfn;
	if(son[x])  top[son[x]]=top[x],dfs2(son[x],x);
	for(auto y:v[x])  if(y!=fa&&y!=son[x])  top[y]=y,dfs2(y,x);
	rs[x]=dfn;
}
struct Segment_tree{
	int tr[N<<1],tag[N<<1],lc[N<<1],rc[N<<1],sz[N<<1];
	void pushdown(int p,int ls,int rs){
		if(tag[p]==-1)  return;
		lc[ls]=rc[ls]=tag[ls]=tag[p],tr[ls]=sz[ls]-1;
		lc[rs]=rc[rs]=tag[rs]=tag[p],tr[rs]=sz[rs]-1,tag[p]=-1;
	}
	void build(int l,int r,int p){
		sz[p]=r-l+1,tr[p]=lc[p]=rc[p]=0,tag[p]=-1;
		if(l==r)  return;
		int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;
		build(l,mid,ls),build(mid+1,r,rs);
	}
	void update(int sl,int sr,int x,int l,int r,int p){
		if(sl<=l&&r<=sr)  return tr[p]=sz[p]-1,tag[p]=lc[p]=rc[p]=x,void();
		int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
		if(sl<=mid)  update(sl,sr,x,l,mid,ls);
		if(sr>mid)  update(sl,sr,x,mid+1,r,rs);
		lc[p]=lc[ls],rc[p]=rc[rs],tr[p]=tr[ls]+tr[rs]+(rc[ls]==lc[rs]&&rc[ls]);
	}
	int qdot(int x,int l,int r,int p){
		if(x==l)  return lc[p];
		if(x==r)  return rc[p];
		int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;
		pushdown(p,ls,rs);
		return (x<=mid?qdot(x,l,mid,ls):qdot(x,mid+1,r,rs));
	}
	int query(int sl,int sr,int l,int r,int p){
		if(sl<=l&&r<=sr)  return tr[p];
		int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
		if(sl<=mid&&sr<=mid)  return query(sl,sr,l,mid,ls);
		if(sl>mid&&sr>mid)  return query(sl,sr,mid+1,r,rs);
		return query(sl,sr,l,mid,ls)+query(sl,sr,mid+1,r,rs)+(rc[ls]==lc[rs]&&rc[ls]);
	}
}SGT;
void solve(){
	cin>>n>>m;
	for(int i=1;i<=n;i++)  v[i].clear();
	for(int i=1,x,y;i<n;i++)  cin>>x>>y,v[x].push_back(y),v[y].push_back(x);
	dfs1(1,dfn=0),dfs2(top[1]=1,0),SGT.build(1,n,1);
	for(int i=1,op,x,y;i<=m;i++){
		cin>>op>>x>>y;
		if(op==1){
			while(top[x]!=top[y]){
				if(dep[top[x]]<dep[top[y]])  swap(x,y);
				SGT.update(ls[top[x]],ls[x],i,1,n,1),x=pre[top[x]];
			}
			SGT.update(min(ls[x],ls[y]),max(ls[x],ls[y]),i,1,n,1);
		}
		else{
			int ans=0;
			while(top[x]!=top[y]){
				if(dep[top[x]]<dep[top[y]])  swap(x,y);
				int tmp1=SGT.qdot(ls[top[x]],1,n,1),tmp2=SGT.qdot(ls[pre[top[x]]],1,n,1);
				ans+=SGT.query(ls[top[x]],ls[x],1,n,1)+(tmp1==tmp2&&tmp1),x=pre[top[x]];
			}
			cout<<ans+SGT.query(min(ls[x],ls[y]),max(ls[x],ls[y]),1,n,1)<<"\n";
		}
	}
}
int main(){
	ios::sync_with_stdio(false),cin.tie(nullptr),cout.tie(nullptr);
	cin>>t;while(t--)  solve();
}

简单树剖练习题

与 E_firework 不一样的方法,我们考虑使用与轻重边一样的思路。

在线段树上区间内维护相邻点的 \((a_u+a_v)|a_u-a_v|^m\) 之和。

发现区间修改除两端点外都是加上 \(2k|a_u-a_v|^m\),所以考虑一并维护 \(|a_u-a_v|^m\) 的和。

这样区间修改时直接加上 \(2k|a_u-a_v|^m\),再在两端点单独进行单点修改就可以了。

树剖统计答案的时候也和轻重边一样,两链的分割点要单独统计答案。

值得注意的是因为包含 \(|a_u-a_v|\),所以不能对 \(a\) 数组加后的值取模,非常感谢 paper 帮我调出来这个点。

点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=200005,mod=7667713;
vector<int> v[N];
int a[N],ls[N],rs[N],dy[N],po[mod],p[mod],pr[mod];
int n,q,m,cnt,dfn,son[N],pre[N],dep[N],sz[N],top[N];
int quick_pow(int x,int y,int res=1){
	for(;y;x=x*x%mod,y>>=1)  if(y&1)  res=res*x%mod;
	return res;
}
void dfs1(int x,int fa){
	sz[x]=1,pre[x]=fa,dep[x]=dep[fa]+1,son[x]=0;
	for(auto y:v[x]){
		if(y==fa)  continue;
		dfs1(y,x),sz[x]+=sz[y];
		if(sz[y]>sz[son[x]]) son[x]=y;
	}
}
void dfs2(int x,int fa){
	ls[x]=++dfn,dy[dfn]=x;
	if(son[x])  top[son[x]]=top[x],dfs2(son[x],x);
	for(auto y:v[x])  if(y!=fa&&y!=son[x])  top[y]=y,dfs2(y,x);
	rs[x]=dfn;
}
struct Segment_tree{
	int lc[N<<1],rc[N<<1],sum[N<<1],tr[N<<1],tag[N<<1];
	void work(int p,int x){
		tr[p]=(tr[p]+x*2*sum[p])%mod,lc[p]+=x,rc[p]+=x,tag[p]+=x;
	}
	void pushdown(int p,int ls,int rs){
		if(!tag[p])  return;
		work(ls,tag[p]),work(rs,tag[p]),tag[p]=0;
	}
	void pushup(int p,int ls,int rs){
		lc[p]=lc[ls],rc[p]=rc[rs];
		sum[p]=(sum[ls]+sum[rs]+po[abs(rc[ls]-lc[rs])%mod])%mod;
		tr[p]=(tr[ls]+tr[rs]+(rc[ls]+lc[rs])*po[abs(rc[ls]-lc[rs])%mod])%mod;
	}
	void build(int l=1,int r=n,int p=1){
		if(l==r)  return lc[p]=rc[p]=a[dy[l]],void();
		int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;
		build(l,mid,ls),build(mid+1,r,rs),pushup(p,ls,rs);
	}
	int geta(int x,int l=1,int r=n,int p=1){
		if(l==r)  return lc[p];
		int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
		return x<=mid?geta(x,l,mid,ls):geta(x,mid+1,r,rs);
	}
	void upddot(int x,int l=1,int r=n,int p=1){
		if(l==r)  return;
		int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
		x<=mid?upddot(x,l,mid,ls):upddot(x,mid+1,r,rs),pushup(p,ls,rs);
	}
	void update(int sl,int sr,int x,int l=1,int r=n,int p=1){
		if(sl<=l&&r<=sr)  return work(p,x);
		int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
		if(sl<=mid)  update(sl,sr,x,l,mid,ls);
		if(sr>mid)  update(sl,sr,x,mid+1,r,rs);pushup(p,ls,rs);
	}
	int query(int sl,int sr,int l=1,int r=n,int p=1){
		if(sl<=l&&r<=sr)  return tr[p];
		int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
		if(sr<=mid)  return query(sl,sr,l,mid,ls);
		if(sl>mid)  return query(sl,sr,mid+1,r,rs);
		int res=(query(sl,sr,l,mid,ls)+query(sl,sr,mid+1,r,rs))%mod;
		return (res+(rc[ls]+lc[rs])*po[abs(rc[ls]-lc[rs])%mod])%mod;
	}
}SGT;
signed main(){
	ios::sync_with_stdio(false),cin.tie(nullptr),cout.tie(nullptr);
	cin>>n>>q>>m,po[0]=0,po[1]=1;
	for(int i=2;i<mod;i++){
		if(!p[i])  po[i]=quick_pow(i,m),pr[++cnt]=i;
		for(int j=1;j<=cnt&&i*pr[j]<mod;i++){
			p[i*pr[j]]=1,po[i*pr[j]]=po[i]*po[pr[j]]%mod;
			if(i%pr[j]==0)  break;
		}
	}
	for(int i=1;i<=n;i++)  cin>>a[i];
	for(int i=1,x,y;i<n;i++)  cin>>x>>y,v[x].push_back(y),v[y].push_back(x);
	dfs1(1,0),top[1]=1,dfs2(1,0),SGT.build();
	for(int i=1,x,y,k,res=0;i<=q;i++,res=0){
		cin>>x>>y>>k;
		if(k){
			while(top[x]!=top[y]){
				if(dep[top[x]]<dep[top[y]])  swap(x,y);
				SGT.update(ls[top[x]],ls[x],k);
				SGT.upddot(ls[top[x]]),SGT.upddot(ls[x]),x=pre[top[x]];	
			}
			if(ls[x]>ls[y])  swap(x,y);
			SGT.update(ls[x],ls[y],k),SGT.upddot(ls[x]),SGT.upddot(ls[y]);
		}
		else{
			while(top[x]!=top[y]){
				if(dep[top[x]]<dep[top[y]])  swap(x,y);
				res=(res+SGT.query(ls[top[x]],ls[x]))%mod;
				int tmp1=SGT.geta(ls[top[x]]),tmp2=SGT.geta(ls[pre[top[x]]]);
				res=(res+(tmp1+tmp2)*po[abs(tmp1-tmp2)%mod])%mod,x=pre[top[x]];
			}
			cout<<(res+SGT.query(min(ls[x],ls[y]),max(ls[x],ls[y])))%mod<<"\n";
		}
	}
}
posted @ 2026-01-12 14:29  tkdqmx  阅读(5)  评论(0)    收藏  举报