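// Luogu (洛谷) P3384 【模板】树链剖分 — template problem: heavy-light decomposition
// (path / subtree add and sum queries modulo P). Solution code follows.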
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
using namespace std;
typedef long long ll;
const int maxn=100000+2;            // array capacity; nodes are indexed 1..n, so n must stay below maxn
ll mod;int n;                       // mod: modulus applied to all sums (read in main); n: number of nodes
// Adjacency list in linked-list form: head[u] is u's first edge id, nex[e] the next
// edge leaving the same node, to[e] the edge's endpoint. Each input edge is stored in
// both directions (see main), hence 2*maxn slots. Edge ids start at 1, so 0 ends a list.
int head[maxn],nex[maxn*2],to[maxn*2];
// Heavy-light decomposition state:
//   p[u]   parent of u (-1 for the root)       dep[u] depth of u (root has depth 1)
//   siz[u] subtree size of u (main later reuses this array as the segment-tree base array)
//   son[u] heavy child of u (-1 when u has no child)
//   top[u] first node of the heavy chain containing u
//   cnt    number of stored directed edges     cnt2   DFS-order timestamp counter
int p[maxn],dep[maxn],siz[maxn],son[maxn],top[maxn];int cnt,cnt2;
// Segment tree over DFS positions 1..n: sumv[o] is the range sum of node o modulo mod.
ll sumv[maxn*4];
// lazy[o]: pending "add" tag of segment-tree node o.
// A[u]: DFS-order position of node u (st[u] holds the same value).
// ed[u]: last DFS position inside u's subtree, so u's subtree is [st[u], ed[u]].
// val[u]: initial weight of node u as read from input.
int lazy[maxn*4],A[maxn],st[maxn],ed[maxn],val[maxn];
// Append the directed edge u -> v to u's adjacency list.
// Edge ids are 1-based, so an id of 0 marks the end of a list.
void addedge(int u,int v){
    ++cnt;
    to[cnt]=v;
    nex[cnt]=head[u];
    head[u]=cnt;
}
// First DFS pass, started from the root.
// For every node it records the parent, the depth (cur) and the subtree size.
// It also selects the heavy child: the first child seen whose subtree is
// strictly the largest. son[u] is expected to be -1 on entry.
void dfs1(int u,int fa,int cur){
    p[u]=fa;
    dep[u]=cur;
    siz[u]=1;
    for(int e=head[u];e!=0;e=nex[e]){
        const int v=to[e];
        if(v==fa)continue;
        dfs1(v,u,cur+1);
        siz[u]+=siz[v];
        // Short-circuit first, so siz[-1] is never read.
        const bool heavier=(son[u]==-1)||(siz[v]>siz[son[u]]);
        if(heavier)son[u]=v;
    }
}
// Second DFS pass. It assigns DFS-order positions so that every heavy chain
// occupies a contiguous range: the heavy child is visited first and inherits
// chain head tp, and each light child starts a new chain headed by itself.
// It also records each subtree's position range [st[u], ed[u]].
void dfs2(int u,int tp){
    top[u]=tp;
    st[u]=A[u]=++cnt2;
    const int heavy=son[u];
    if(heavy!=-1)dfs2(heavy,tp);
    for(int e=head[u];e;e=nex[e]){
        const int v=to[e];
        if(v==p[u]||v==heavy)continue;
        dfs2(v,v);
    }
    ed[u]=cnt2;
}
// Push the pending add-tag of segment-tree node o (which covers [L,R]) down to
// its two children, then clear it.
// Tags and sums are kept in [0, mod). Both operands of each addition can be
// close to mod, and mod may be near INT_MAX (Luogu P3384 allows P up to
// 2^31-1 — TODO confirm). The original int arithmetic could therefore overflow:
// both (mid-L+1)*lazy[o] and lazy[child]+lazy[o]. Everything is now widened to
// long long and reduced modulo mod; the reduced result fits back into an int tag.
void down(int L,int R,int o){
    if(!lazy[o])return;
    const int mid=(L+R)/2;
    const int lc=o*2,rc=o*2+1;
    const long long tag=lazy[o];
    lazy[lc]=(lazy[lc]+tag)%mod;
    lazy[rc]=(lazy[rc]+tag)%mod;
    sumv[lc]=(sumv[lc]+(mid-L+1)*tag)%mod;   // left child covers mid-L+1 positions
    sumv[rc]=(sumv[rc]+(R-mid)*tag)%mod;     // right child covers R-mid positions
    lazy[o]=0;
}
// Build the segment-tree node o covering positions [L,R] from arr[L..R].
// Each node stores its range sum; a value is reduced only when it is >= mod.
void build(int L,int R,int o,int arr[])
{
    if(L!=R){
        const int half=(L+R)/2;
        const int lc=o*2,rc=lc+1;
        build(L,half,lc,arr);
        build(half+1,R,rc,arr);
        long long total=sumv[lc]+sumv[rc];
        if(total>=mod)total%=mod;
        sumv[o]=total;
        return;
    }
    sumv[o]=arr[L];
    if(sumv[o]>=mod)sumv[o]%=mod;
}
// Add k to every position in [l,r] (node o covers [L,R]); sums are kept modulo mod.
// Fixes compared with the original:
//  - k is reduced into [0, mod) before use. The original added it unreduced into
//    the int tag and multiplied (R-L+1)*k in int, and both can overflow.
//    A negative k now also yields a proper non-negative residue.
//  - Tag and sum updates are computed in long long and then reduced modulo mod.
void update(int l,int r,int k,int L,int R,int o){
    if(l<=L&&r>=R){
        long long add=k%mod;
        if(add<0)add+=mod;
        lazy[o]=(lazy[o]+add)%mod;
        sumv[o]=(sumv[o]+(R-L+1)*add)%mod;   // every one of the R-L+1 positions gains add
        return;
    }
    const int mid=(L+R)/2;
    down(L,R,o);
    if(l<=mid)update(l,r,k,L,mid,o*2);
    if(r>mid)update(l,r,k,mid+1,R,o*2+1);
    sumv[o]=(sumv[o*2]+sumv[o*2+1])%mod;
}
// Return the sum of positions [l,r] modulo mod (node o covers [L,R]).
// Pending tags are pushed down on the way, and node o's sum is recomputed
// from its children afterwards.
ll query(int l,int r,int L,int R,int o)
{
    if(L>=l&&R<=r)return sumv[o];
    down(L,R,o);
    const int mid=(L+R)/2;
    const int lc=o*2,rc=o*2+1;
    ll acc=0;
    if(l<=mid)acc+=query(l,r,L,mid,lc);
    if(mid<r)acc+=query(l,r,mid+1,R,rc);
    sumv[o]=sumv[lc]+sumv[rc];
    if(sumv[o]>=mod)sumv[o]%=mod;
    return acc>=mod?acc%mod:acc;
}
// Add del to every node on the tree path x..y.
// Each step jumps the endpoint whose chain head is deeper; on a tie the
// original y is the one jumped. The swap makes x always the endpoint being
// jumped, and the last segment lies inside a single heavy chain.
void up(int x,int y,int del){
    while(top[x]!=top[y]){
        if(dep[top[x]]<=dep[top[y]])std::swap(x,y);
        update(A[top[x]],A[x],del,1,n,1);
        x=p[top[x]];
    }
    if(dep[x]>=dep[y])std::swap(x,y);   // x is now the shallower endpoint
    update(A[x],A[y],del,1,n,1);
}
// Return the sum of node weights on the tree path x..y modulo mod.
// It walks chains exactly like up(): the endpoint with the deeper chain head
// jumps first (on a tie the original y jumps). The swap makes x always the
// endpoint being jumped.
ll look_up(int x,int y)
{
    ll total=0;
    while(top[x]!=top[y]){
        if(dep[top[x]]<=dep[top[y]])std::swap(x,y);
        total+=query(A[top[x]],A[x],1,n,1);
        if(total>=mod)total%=mod;
        x=p[top[x]];
    }
    if(dep[x]>=dep[y])std::swap(x,y);   // x is now the shallower endpoint
    total+=query(A[x],A[y],1,n,1);
    if(total>=mod)total%=mod;
    return total;
}
// Read the tree, build the decomposition and segment tree, then answer m operations:
//   1 x y z : add z to every node on path x..y
//   2 x y   : print the path x..y sum (mod)
//   3 x z   : add z to every node in x's subtree
//   4 x     : print the sum of x's subtree (mod)
int main()
{
    int m,r;
    scanf("%d%d%d",&n,&m,&r);
    scanf("%lld",&mod);
    for(int v=1;v<=n;++v)scanf("%d",&val[v]);
    for(int e=1;e<n;++e){
        int a,b;
        scanf("%d%d",&a,&b);
        addedge(a,b);
        addedge(b,a);
    }
    memset(son,-1,sizeof(son));   // -1 means "no heavy child"
    dfs1(r,-1,1);
    dfs2(r,r);
    // siz[] is not needed after dfs1, so it is reused as the segment-tree base
    // array: DFS position A[v] receives node v's initial weight.
    for(int v=1;v<=n;++v)siz[A[v]]=val[v];
    build(1,n,1,siz);
    for(;m!=0;--m)
    {
        int op;
        scanf("%d",&op);
        switch(op){
        case 1:{
            int x,y,z;
            scanf("%d%d%d",&x,&y,&z);
            up(x,y,z);
            break;
        }
        case 2:{
            int x,y;
            scanf("%d%d",&x,&y);
            printf("%lld\n",look_up(x,y));
            break;
        }
        case 3:{
            int x,z;
            scanf("%d%d",&x,&z);
            update(st[x],ed[x],z,1,n,1);
            break;
        }
        case 4:{
            int x;
            scanf("%d",&x);
            printf("%lld\n",query(st[x],ed[x],1,n,1));
            break;
        }
        }
    }
    return 0;
}

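// 浙公网安备 33010602011771号 — apparently the footer (Zhejiang public-security filing number) of the web page this code was copied from; not part of the program.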