LuoguP5327 [ZJOI2019]语言 线段树合并+树链求并
比较好的一道数据结构题.
对于 $i$,我们希望求出所有经过 $i$ 号点的路径所构成的树链之并.
考虑一个求解树链的并的经典做法就是 $\sum_{i=1}^{n} dep[i]-\sum_{i=2}^{n} dep[LCA(i,i-1)]$.
这里要求所有点都要按照 $dfs$ 序排好.
那么这道题中我们就基于 DFS 序对每个点建立动态开点线段树.
加入一条路径就是在 $x,y$ 处添加 $(x,y)$,然后在 $fa[lca]$ 处将这 4 个点再都删掉.
儿子向父亲合并的时候直接用线段树合并就行.
pushup 的时候维护:$(s,t,f,si)$ 分别表示区间最靠左/右的节点编号,区间树链之并长度,以及叶节点开的桶.
用 $RMQ-O(1)$ 求 $LCA$ 可做到 $O(n \log n)$.
代码:
#include <cstdio>
#include <vector>
#include <cstring>
#include <algorithm>
#define N 100009
#define ll long long
#define pb push_back
#define lson s[x].ls
#define rson s[x].rs
#define setIO(s) freopen(s".in","r",stdin)
using namespace std;
ll ans;
int n,m,edges,tim,tot;
vector<int>ADD[N],DEL[N];
int hd[N],to[N<<1],nex[N<<1],rt[N];
int fa[N],dfn[N],seq[20][N<<1],dep[N],Log[N<<1];
void add(int u,int v) {
nex[++edges]=hd[u];
hd[u]=edges,to[edges]=v;
}
void dfs(int x,int ff) {
dep[x]=dep[ff]+1;
fa[x]=ff,dfn[x]=++tim,seq[0][tim]=x;
for(int i=hd[x];i;i=nex[i]) {
int y=to[i];
if(y==ff) continue;
dfs(y,x);
seq[0][++tim]=x;
}
}
void RMQ() {
Log[1]=0;
for(int i=2;i<=2*n;++i) {
Log[i]=Log[i>>1]+1;
}
for(int i=1;(1<<i)<=2*n;++i) {
for(int j=1;j+(1<<i)-1<=2*n;++j) {
int x=seq[i-1][j],y=seq[i-1][j+(1<<(i-1))];
if(dep[x]<dep[y]) seq[i][j]=x;
else seq[i][j]=y;
}
}
}
int get_lca(int x,int y) {
if(!x||!y) {
return 0;
}
x=dfn[x],y=dfn[y];
if(x>y) swap(x,y);
int p=Log[y-x+1];
return dep[seq[p][x]]<dep[seq[p][y-(1<<p)+1]]?seq[p][x]:seq[p][y-(1<<p)+1];
}
struct data {
int ls,rs,si,s,t,f;
data() { ls=rs=si=s=t=f=0;}
int getans() {
return f-dep[get_lca(s,t)];
}
}s[N*80];
void pushup(int x) {
s[x].f=s[lson].f+s[rson].f-dep[get_lca(s[lson].t,s[rson].s)];
s[x].s=s[lson].s?s[lson].s:s[rson].s;
s[x].t=s[rson].t?s[rson].t:s[lson].t;
}
void update(int &x,int l,int r,int p,int v) {
if(!x) x=++tot;
if(l==r) {
s[x].si+=v;
if(s[x].si) {
s[x].s=s[x].t=seq[0][l];
s[x].f=dep[seq[0][l]];
}
else {
s[x].s=s[x].t=s[x].f=0;
}
return;
}
int mid=(l+r)>>1;
if(p<=mid) {
update(lson,l,mid,p,v);
}
else {
update(rson,mid+1,r,p,v);
}
pushup(x);
}
int merge(int l,int r,int x,int y) {
if(!x||!y) {
return x+y;
}
int now=++tot,mid=(l+r)>>1;
if(l==r) {
s[now].si=s[x].si+s[y].si;
s[now].s=s[x].s|s[y].s;
s[now].t=s[x].t|s[y].t;
s[now].f=s[x].f|s[y].f;
return now;
}
s[now].ls=merge(l,mid,s[x].ls,s[y].ls);
s[now].rs=merge(mid+1,r,s[x].rs,s[y].rs);
pushup(now);
return now;
}
void solve(int x,int ff) {
for(int i=hd[x];i;i=nex[i]) {
int y=to[i];
if(y==ff) continue;
solve(y,x);
rt[x]=merge(1,n<<1,rt[x],rt[y]);
}
for(int i=0;i<ADD[x].size();++i) {
update(rt[x],1,n<<1,dfn[ADD[x][i]],1);
}
for(int i=0;i<DEL[x].size();++i) {
update(rt[x],1,n<<1,dfn[DEL[x][i]],-1);
}
ans+=s[rt[x]].getans();
}
int main() {
// setIO("input");
int x,y,z;
scanf("%d%d",&n,&m);
for(int i=1;i<n;++i) {
scanf("%d%d",&x,&y);
add(x,y),add(y,x);
}
dfs(1,0),RMQ();
for(int i=1;i<=m;++i) {
scanf("%d%d",&x,&y);
int lca=get_lca(x,y);
ADD[x].pb(x),ADD[x].pb(y);
ADD[y].pb(x),ADD[y].pb(y);
if(fa[lca]) {
DEL[fa[lca]].pb(x);
DEL[fa[lca]].pb(y);
DEL[fa[lca]].pb(x);
DEL[fa[lca]].pb(y);
}
}
solve(1,0);
printf("%lld\n",ans>>1);
return 0;
}

浙公网安备 33010602011771号