「ZJOI2019」语言

传送门

Description

给定一棵\(n\)个点的树和\(m\)条链,两个点可以联会当且仅当它们同在某一条链上,求可以联会的点的方案数

\(n,m\leq10^5\)

Solution 

考虑计算每个点的贡献,然后将贡献和除以\(2\)

相当于求出对于每个点经过它的所有链端点的虚树的大小(显然这些链的并就是一个

每个虚树都先加上了根节点,可以保证以下求虚树大小的做法合法:

注意:这里的根节点是\(1\)号节点,但是我们并不看作\(1\)号节点一定在这个虚树内——它更像是一个等价于\(1\)号点的虚拟节点

先对所有端点按照\(dfs\)序进行排序,这样根肯定在第一个

然后在按顺序加入每个点的过程中

加上这个点连向当前虚树最短链的长度,也就是\(dep_{a[i]}-dep_{lca(a[i],a[i-1])}\)

\(last\)指的是上一个加入的点

对于最小的那个节点,它的贡献就是\(dep[a[1]]\)

我们考虑线段树的分治的过程,假设我们先分别求出了\(dfs\)序在\([l,mid]\)\([mid+1,r]\)中端点的虚树信息

接下来要将其合并,这时,其实只需要合并左区间中最大的和右区间中最小的即可,这时我们是不算根节点的

合并之后,发现多算了一些点,具体来说,是重复计算了从根到\(LCA\)的一段路径,将其减去

最后,求得虚树大小后,还要减去从根到\(LCA\)的长度,在合并的同时维护一下集合的\(lca\)即可

在本题中,求\(lca\)采用\(RMQ\)方法

树上差分,在端点和端点的\(lca\)的父亲处进行加点和删点操作

相当于对线段树进行单点修改

从孩子那里继承信息需要用到线段树合并

复杂度是\(O(n\log n)\)


Code 

#include<bits/stdc++.h>
#define ll long long
using namespace std;
#define reg register
inline int read()
{
	int x=0,f=1;char ch=getchar();
	while(ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
	return x*f;
}
const int MN=1e5+5,MM=4e6+5;
int n,m;ll ans;
struct ed{int to,nex;}e[MN<<1];int hr[MN],en;
inline void ins(int x,int y)
{
	e[++en]=(ed){y,hr[x]};hr[x]=en;
	e[++en]=(ed){x,hr[y]};hr[y]=en;
}
/*------------求lca------------*/
int st[MN<<1][20],eu[MN],id[MN],dfn[MN],pin,dep[MN],fa[MN],dind,lg[MN<<1];
void dfs(int x,int f)
{
	id[dfn[x]=++pin]=x;st[dind][0]=x;
	eu[x]=++dind;dep[x]=dep[fa[x]=f]+1;
	for(int i=hr[x];i;i=e[i].nex)if(e[i].to^f)
		st[dind][0]=x,++dind,dfs(e[i].to,x);
}
inline void init()
{
	dfs(1,0);register int i,j;
	for(i=2;i<=dind;++i)lg[i]=lg[i>>1]+1;
	for(j=1;j<20;++j)for(i=1;i<=dind;++i)
	{
		if(i+(1<<j)>dind) break;
		st[i][j]=dep[st[i][j-1]]<dep[st[i+(1<<j-1)][j-1]]?st[i][j-1]:st[i+(1<<j-1)][j-1];
	}
}
int LCA(int x,int y)
{
	if(!x||!y) return x|y;
	if(x==y) return x;
	x=eu[x];y=eu[y];if(x>y)swap(x,y);
	int b=lg[y-x];
	if(dep[st[x][b]]<dep[st[y-(1<<b)][b]])
		return st[x][b];
	return st[y-(1<<b)][b];
}
/*------------线段树合并------------*/
struct xxx{
	int x,y,z;
	xxx(int x,int y,int z):x(x),y(y),z(z){}
};
std::vector<xxx> opt[MN];
int ml[MM],mr[MM],ls[MM],rs[MM],v[MM],sum[MM],lca[MM],rt[MN],tot;
void up(int x)
{
	ml[x]=ml[ls[x]]?ml[ls[x]]:ml[rs[x]];
	mr[x]=mr[rs[x]]?mr[rs[x]]:mr[ls[x]];
	v[x]=v[ls[x]]+v[rs[x]];
	if(!v[ls[x]]||!v[rs[x]]) sum[x]=sum[ls[x]]+sum[rs[x]];
	else sum[x]=sum[ls[x]]+sum[rs[x]]-dep[LCA(id[mr[ls[x]]],id[ml[rs[x]]])];
	if(lca[ls[x]]&&lca[rs[x]]) lca[x]=LCA(lca[ls[x]],lca[rs[x]]);
	else lca[x]=lca[ls[x]]|lca[rs[x]];
}
int Merge(int x,int y,int l,int r)
{
	if(!x||!y) return x|y;
	if(l==r)
	{
		v[x]+=v[y];
		if(v[x]) ml[x]=mr[x]=l,lca[x]=id[l],sum[x]=dep[id[l]];
		else ml[x]=mr[x]=lca[x]=sum[x]=0;
		return x;
	}
	int mid=(l+r)>>1;
	ls[x]=Merge(ls[x],ls[y],l,mid);
	rs[x]=Merge(rs[x],rs[y],mid+1,r);
	up(x);return x;
}
void Modify(int &x,int l,int r,int a,int b)
{
	if(!x) x=++tot;
	if(l==r)
	{
		v[x]+=b;
		if(v[x]) ml[x]=mr[x]=l,lca[x]=id[l],sum[x]=dep[id[l]];
		else ml[x]=mr[x]=lca[x]=sum[x]=0;
		return;
	}
	int mid=(l+r)>>1;
	if(a<=mid) Modify(ls[x],l,mid,a,b);
	else Modify(rs[x],mid+1,r,a,b);
	up(x);
}
void Solve(int x)
{
	register int i;
	for(i=hr[x];i;i=e[i].nex)if(e[i].to^fa[x])
		Solve(e[i].to),rt[x]=Merge(rt[x],rt[e[i].to],1,n);
	for(i=opt[x].size()-1;~i;--i)
		Modify(rt[x],1,n,opt[x][i].x,opt[x][i].z),
		Modify(rt[x],1,n,opt[x][i].y,opt[x][i].z);
	ans+=sum[rt[x]]-dep[lca[rt[x]]];
}
int main()
{
	n=read();m=read();
	register int i,x,y,z;
	for(i=1;i<n;++i) x=read(),ins(x,read());
	init();
	for(i=1;i<=m;++i)
	{
		x=read(),y=read();z=LCA(x,y);
		opt[x].push_back(xxx(dfn[x],dfn[y],1));
		opt[y].push_back(xxx(dfn[x],dfn[y],1));
		opt[z].push_back(xxx(dfn[x],dfn[y],-1));
		opt[fa[z]].push_back(xxx(dfn[x],dfn[y],-1));
	}
	ans=0;Solve(1);
    return 0*printf("%lld\n",ans>>1);
}


Blog来自PaperCloud,未经允许,请勿转载,TKS!

posted @ 2019-07-27 20:13  PaperCloud  阅读(439)  评论(0编辑  收藏  举报