省选集训 9 - 树上技巧
[NOI2021] 轻重边
路径上信息考虑树剖,每次 \(op=1\) 将路径上点染成新颜色,然后 \(op=2\) 查询路径同色相邻点对。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define N 100005
vector<int> v[N];
int t,n,m,dfn,son[N],pre[N],dep[N],sz[N],top[N],ls[N],rs[N];
void dfs1(int x,int fa){
sz[x]=1,pre[x]=fa,dep[x]=dep[fa]+1,son[x]=0;
for(auto y:v[x]){
if(y==fa) continue;
dfs1(y,x),sz[x]+=sz[y];
if(sz[y]>sz[son[x]]) son[x]=y;
}
}
void dfs2(int x,int fa){
ls[x]=++dfn;
if(son[x]) top[son[x]]=top[x],dfs2(son[x],x);
for(auto y:v[x]) if(y!=fa&&y!=son[x]) top[y]=y,dfs2(y,x);
rs[x]=dfn;
}
struct Segment_tree{
int tr[N<<1],tag[N<<1],lc[N<<1],rc[N<<1],sz[N<<1];
void pushdown(int p,int ls,int rs){
if(tag[p]==-1) return;
lc[ls]=rc[ls]=tag[ls]=tag[p],tr[ls]=sz[ls]-1;
lc[rs]=rc[rs]=tag[rs]=tag[p],tr[rs]=sz[rs]-1,tag[p]=-1;
}
void build(int l,int r,int p){
sz[p]=r-l+1,tr[p]=lc[p]=rc[p]=0,tag[p]=-1;
if(l==r) return;
int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;
build(l,mid,ls),build(mid+1,r,rs);
}
void update(int sl,int sr,int x,int l,int r,int p){
if(sl<=l&&r<=sr) return tr[p]=sz[p]-1,tag[p]=lc[p]=rc[p]=x,void();
int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
if(sl<=mid) update(sl,sr,x,l,mid,ls);
if(sr>mid) update(sl,sr,x,mid+1,r,rs);
lc[p]=lc[ls],rc[p]=rc[rs],tr[p]=tr[ls]+tr[rs]+(rc[ls]==lc[rs]&&rc[ls]);
}
int qdot(int x,int l,int r,int p){
if(x==l) return lc[p];
if(x==r) return rc[p];
int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;
pushdown(p,ls,rs);
return (x<=mid?qdot(x,l,mid,ls):qdot(x,mid+1,r,rs));
}
int query(int sl,int sr,int l,int r,int p){
if(sl<=l&&r<=sr) return tr[p];
int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
if(sl<=mid&&sr<=mid) return query(sl,sr,l,mid,ls);
if(sl>mid&&sr>mid) return query(sl,sr,mid+1,r,rs);
return query(sl,sr,l,mid,ls)+query(sl,sr,mid+1,r,rs)+(rc[ls]==lc[rs]&&rc[ls]);
}
}SGT;
void solve(){
cin>>n>>m;
for(int i=1;i<=n;i++) v[i].clear();
for(int i=1,x,y;i<n;i++) cin>>x>>y,v[x].push_back(y),v[y].push_back(x);
dfs1(1,dfn=0),dfs2(top[1]=1,0),SGT.build(1,n,1);
for(int i=1,op,x,y;i<=m;i++){
cin>>op>>x>>y;
if(op==1){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
SGT.update(ls[top[x]],ls[x],i,1,n,1),x=pre[top[x]];
}
SGT.update(min(ls[x],ls[y]),max(ls[x],ls[y]),i,1,n,1);
}
else{
int ans=0;
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
int tmp1=SGT.qdot(ls[top[x]],1,n,1),tmp2=SGT.qdot(ls[pre[top[x]]],1,n,1);
ans+=SGT.query(ls[top[x]],ls[x],1,n,1)+(tmp1==tmp2&&tmp1),x=pre[top[x]];
}
cout<<ans+SGT.query(min(ls[x],ls[y]),max(ls[x],ls[y]),1,n,1)<<"\n";
}
}
}
int main(){
ios::sync_with_stdio(false),cin.tie(nullptr),cout.tie(nullptr);
cin>>t;while(t--) solve();
}
简单树剖练习题
与 E_firework 不一样的方法,我们考虑使用与轻重边一样的思路。
在线段树上区间内维护相邻点的 \((a_u+a_v)|a_u-a_v|^m\) 之和。
发现区间修改除两端点外都是加上 \(2k|a_u-a_v|^m\),所以考虑一并维护 \(|a_u-a_v|^m\) 的和。
这样区间修改时直接加上 \(2k|a_u-a_v|^m\),再在两端点单独进行单点修改就可以了。
树剖统计答案的时候也和轻重边一样,两链的分割点要单独统计答案。
值得注意的是因为包含 \(|a_u-a_v|\),所以不能对 \(a\) 数组加后的值取模,非常感谢 paper 帮我调出来这个点。
点击查看代码
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=200005,mod=7667713;
vector<int> v[N];
int a[N],ls[N],rs[N],dy[N],po[mod],p[mod],pr[mod];
int n,q,m,cnt,dfn,son[N],pre[N],dep[N],sz[N],top[N];
int quick_pow(int x,int y,int res=1){
for(;y;x=x*x%mod,y>>=1) if(y&1) res=res*x%mod;
return res;
}
void dfs1(int x,int fa){
sz[x]=1,pre[x]=fa,dep[x]=dep[fa]+1,son[x]=0;
for(auto y:v[x]){
if(y==fa) continue;
dfs1(y,x),sz[x]+=sz[y];
if(sz[y]>sz[son[x]]) son[x]=y;
}
}
void dfs2(int x,int fa){
ls[x]=++dfn,dy[dfn]=x;
if(son[x]) top[son[x]]=top[x],dfs2(son[x],x);
for(auto y:v[x]) if(y!=fa&&y!=son[x]) top[y]=y,dfs2(y,x);
rs[x]=dfn;
}
struct Segment_tree{
int lc[N<<1],rc[N<<1],sum[N<<1],tr[N<<1],tag[N<<1];
void work(int p,int x){
tr[p]=(tr[p]+x*2*sum[p])%mod,lc[p]+=x,rc[p]+=x,tag[p]+=x;
}
void pushdown(int p,int ls,int rs){
if(!tag[p]) return;
work(ls,tag[p]),work(rs,tag[p]),tag[p]=0;
}
void pushup(int p,int ls,int rs){
lc[p]=lc[ls],rc[p]=rc[rs];
sum[p]=(sum[ls]+sum[rs]+po[abs(rc[ls]-lc[rs])%mod])%mod;
tr[p]=(tr[ls]+tr[rs]+(rc[ls]+lc[rs])*po[abs(rc[ls]-lc[rs])%mod])%mod;
}
void build(int l=1,int r=n,int p=1){
if(l==r) return lc[p]=rc[p]=a[dy[l]],void();
int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;
build(l,mid,ls),build(mid+1,r,rs),pushup(p,ls,rs);
}
int geta(int x,int l=1,int r=n,int p=1){
if(l==r) return lc[p];
int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
return x<=mid?geta(x,l,mid,ls):geta(x,mid+1,r,rs);
}
void upddot(int x,int l=1,int r=n,int p=1){
if(l==r) return;
int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
x<=mid?upddot(x,l,mid,ls):upddot(x,mid+1,r,rs),pushup(p,ls,rs);
}
void update(int sl,int sr,int x,int l=1,int r=n,int p=1){
if(sl<=l&&r<=sr) return work(p,x);
int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
if(sl<=mid) update(sl,sr,x,l,mid,ls);
if(sr>mid) update(sl,sr,x,mid+1,r,rs);pushup(p,ls,rs);
}
int query(int sl,int sr,int l=1,int r=n,int p=1){
if(sl<=l&&r<=sr) return tr[p];
int mid=(l+r)>>1,ls=mid<<1,rs=mid<<1|1;pushdown(p,ls,rs);
if(sr<=mid) return query(sl,sr,l,mid,ls);
if(sl>mid) return query(sl,sr,mid+1,r,rs);
int res=(query(sl,sr,l,mid,ls)+query(sl,sr,mid+1,r,rs))%mod;
return (res+(rc[ls]+lc[rs])*po[abs(rc[ls]-lc[rs])%mod])%mod;
}
}SGT;
signed main(){
ios::sync_with_stdio(false),cin.tie(nullptr),cout.tie(nullptr);
cin>>n>>q>>m,po[0]=0,po[1]=1;
for(int i=2;i<mod;i++){
if(!p[i]) po[i]=quick_pow(i,m),pr[++cnt]=i;
for(int j=1;j<=cnt&&i*pr[j]<mod;i++){
p[i*pr[j]]=1,po[i*pr[j]]=po[i]*po[pr[j]]%mod;
if(i%pr[j]==0) break;
}
}
for(int i=1;i<=n;i++) cin>>a[i];
for(int i=1,x,y;i<n;i++) cin>>x>>y,v[x].push_back(y),v[y].push_back(x);
dfs1(1,0),top[1]=1,dfs2(1,0),SGT.build();
for(int i=1,x,y,k,res=0;i<=q;i++,res=0){
cin>>x>>y>>k;
if(k){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
SGT.update(ls[top[x]],ls[x],k);
SGT.upddot(ls[top[x]]),SGT.upddot(ls[x]),x=pre[top[x]];
}
if(ls[x]>ls[y]) swap(x,y);
SGT.update(ls[x],ls[y],k),SGT.upddot(ls[x]),SGT.upddot(ls[y]);
}
else{
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
res=(res+SGT.query(ls[top[x]],ls[x]))%mod;
int tmp1=SGT.geta(ls[top[x]]),tmp2=SGT.geta(ls[pre[top[x]]]);
res=(res+(tmp1+tmp2)*po[abs(tmp1-tmp2)%mod])%mod,x=pre[top[x]];
}
cout<<(res+SGT.query(min(ls[x],ls[y]),max(ls[x],ls[y])))%mod<<"\n";
}
}
}

浙公网安备 33010602011771号