【洛谷P6071】Treequery

题目

题目链接:https://www.luogu.com.cn/problem/P6071
给定一棵 \(n\) 个点的无根树,边有边权。
\(E(x,y)\) 表示树上 \(x,y\) 之间的简单路径上的所有边的集合,特别地,当 \(x=y\) 时,\(E(x,y) = \varnothing\)
你需要 实时 回答 \(q\) 个询问,每个询问给定 \(p,l,r\),请你求出集合 \(\bigcap_{i=l}^r E(p,i)\) 中所有边的边权和,即 \(E(p, l\dots r)\) 的交所包含的边的边权和。
通俗的讲,你需要求出 \(p\)\([l,r]\) 内每一个点的简单路径的公共部分长度。
\(n,Q\leq 2\times 10^5\)

思路

问题等价于询问 \(x\)\([l,r]\) 之间点的虚树的距离。
分类讨论:

  • 如果 \([l,r]\) 内的点都在 \(x\) 的子树内,那么答案等于 \([l,r]\) 内点的 \(\text{LCA}\)\(x\) 的距离。我们知道一个区间点的 \(\text{LCA}\) 等于 dfs 序最小最大两个点的 \(\text{LCA}\)
  • 如果 \([l,r]\) 内的点一部分在 \(x\) 的子树内,\(x\) 肯定被包含在虚树内,答案为 \(0\)
  • 如果 \([l,r]\) 内的点全部不在 \(x\) 的子树内,考虑虚树的根(也就是 \([l,r]\) 内点的 \(\text{LCA}\)):
    • 如果 \(\text{LCA}\) 不是 \(x\) 的祖先,显然 \(x\) 到虚树的距离就是 \(x\)\(\text{LCA}\) 的距离。
    • 如果 \(\text{LCA}\)\(x\) 的祖先,那么答案等于把 \(x\) 加入虚树后,\(x\) 到它虚树上父亲的距离。又因为 \(x\) 在虚树上的父亲一定是 dfs 序上 \(x\) 的前驱或后继与 \(x\)\(\text{LCA}\),所以我们只需要找到区间 \([l,r]\) 内,dfs 序为 \(x\) 的前驱和后继两个点即可。

\(\text{LCA}\) 可以用倍增,而其他操作只需要维护一个以点编号为版本,dfs 序为区间的主席树即可。
时间复杂度 \(O((n+m)\log n)\)

代码

#include <bits/stdc++.h>
using namespace std;

const int N=200010,LG=18;
int n,Q,tot,last,rt[N],head[N],dis[N],dep[N],id[N],siz[N],rk[N],f[N][LG+1];

struct edge
{
	int next,to,dis;
}e[N*2];

void add(int from,int to,int dis)
{
	e[++tot]=(edge){head[from],to,dis};
	head[from]=tot;
}

void dfs(int x,int fa)
{
	id[x]=++tot; rk[tot]=x; siz[x]=1;
	dep[x]=dep[fa]+1; f[x][0]=fa;
	for (int i=1;i<=LG;i++)
		f[x][i]=f[f[x][i-1]][i-1];
	for (int i=head[x];~i;i=e[i].next)
	{
		int v=e[i].to;
		if (v!=fa)
		{
			dis[v]=dis[x]+e[i].dis;
			dfs(v,x); siz[x]+=siz[v];
		}
	}
}

int lca(int x,int y)
{
	if (dep[x]<dep[y]) swap(x,y);
	for (int i=LG;i>=0;i--)
		if (dep[f[x][i]]>=dep[y]) x=f[x][i];
	if (x==y) return x;
	for (int i=LG;i>=0;i--)
		if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
	return f[x][0];
}

struct SegTree
{
	int tot,lc[N*LG*4],rc[N*LG*4],cnt[N*LG*4];
	
	int update(int now,int l,int r,int k)
	{
		int x=++tot;
		lc[x]=lc[now]; rc[x]=rc[now]; cnt[x]=cnt[now]+1;
		if (l==r) return x;
		int mid=(l+r)>>1;
		if (k<=mid) lc[x]=update(lc[now],l,mid,k);
			else rc[x]=update(rc[now],mid+1,r,k);
		return x;
	}
	
	int query1(int nowl,int nowr,int l,int r,int ql,int qr)
	{
		if (ql<=l && qr>=r) return cnt[nowr]-cnt[nowl];
		int mid=(l+r)>>1,res=0;
		if (ql<=mid) res+=query1(lc[nowl],lc[nowr],l,mid,ql,qr);
		if (qr>mid) res+=query1(rc[nowl],rc[nowr],mid+1,r,ql,qr);
		return res;
	}
	
	int query2(int nowl,int nowr,int l,int r,int k)
	{
		if (l==r) return rk[l];
		int mid=(l+r)>>1,res=cnt[lc[nowr]]-cnt[lc[nowl]];
		if (k<=res) return query2(lc[nowl],lc[nowr],l,mid,k);
			else return query2(rc[nowl],rc[nowr],mid+1,r,k-res);
	}
}seg;

int main()
{
	memset(head,-1,sizeof(head));
	scanf("%d%d",&n,&Q);
	for (int i=1,x,y,z;i<n;i++)
	{
		scanf("%d%d%d",&x,&y,&z);
		add(x,y,z); add(y,x,z);
	}
	tot=0; dfs(1,0);
	for (int i=1;i<=n;i++)
		rt[i]=seg.update(rt[i-1],1,n,id[i]);
	while (Q--)
	{
		int x,l,r;
		scanf("%d%d%d",&x,&l,&r);
		x^=last; l^=last; r^=last;
		int cnt=seg.query1(rt[l-1],rt[r],1,n,id[x],id[x]+siz[x]-1);
		if (cnt && cnt<r-l+1) last=0;
		else
		{
			int u=seg.query2(rt[l-1],rt[r],1,n,1);
			int v=seg.query2(rt[l-1],rt[r],1,n,r-l+1);
			int p=lca(u,v);
			if (cnt==r-l+1)
				last=dis[p]-dis[x];
			else if (id[x]<id[p] || id[x]>id[p]+siz[p]-1)
				last=dis[p]+dis[x]-2*dis[lca(p,x)];
			else
			{
				cnt=seg.query1(rt[l-1],rt[r],1,n,1,id[x]);
				if (cnt) u=seg.query2(rt[l-1],rt[r],1,n,cnt);
				if (cnt<r-l+1) v=seg.query2(rt[l-1],rt[r],1,n,cnt+1);
				last=dis[x]-max(dis[lca(u,x)],dis[lca(v,x)]);
			}
		}
		cout<<last<<"\n";
	}
	return 0;
}
posted @ 2021-06-13 20:11  stoorz  阅读(56)  评论(0编辑  收藏  举报