【NOIP 校内模拟】T3 忘了是啥名字了(dfs序+树状数组)

对于当前新加入的一条路径 他产生的贡献分为两种

1.另一条路径的LCA在当前路径上
2.当前路径的LCA在另一条上

对于情况1:

可以维护当前点到根节点有多少个LCA,查询只需查询u,v,-2*lca(u,v),修改需要对lca的子树+1

对于情况2:

显然的树上差分,查询就是lca子树的前缀和,修改u++,v++,lca-2

即开两个树状数组,一个支持单点查询+区间修改,一个支持单点修改+区间查询,不嫌麻烦的话可以尝试线段树。

需要开栈,某OJ栈空间感人。

#include<bits/stdc++.h>
#define N 1000005
#define M 1000005
#define ll long long
using namespace std;
template<class T>
inline void read(T &x)
{
	x=0; int f=1;
	static char ch=getchar();
	while((!isdigit(ch))&&ch!='-')	ch=getchar();
	if(ch=='-')	f=-1,ch=getchar();
	while(isdigit(ch))	x=x*10+ch-'0',ch=getchar();
	x*=f;
}
//1e6
struct Edge
{
	int to,next;
}edge[2*N];
int n,m,tot,first[N];
inline void addedge(int x,int y)
{
	tot++;
	edge[tot].to=y; edge[tot].next=first[x]; first[x]=tot;
}
int up[N][27],depth[N],st[N],sign,ed[N];
ll con[N];
void dfs(int now,int fa)
{
	up[now][0]=fa;
	depth[now]=depth[fa]+1;
	st[now]=++sign;
	for(int i=1;i<=25;i++)	up[now][i]=up[up[now][i-1]][i-1];
	for(int u=first[now];u;u=edge[u].next)
	{
		int vis=edge[u].to;
		if(vis==fa)	continue;
		dfs(vis,now);
	}
	ed[now]=sign;
}
inline int getlca(int x,int y)
{
	if(depth[x]<depth[y])	swap(x,y);
	for(int i=25;i>=0;i--) if(depth[up[x][i]]>=depth[y]) x=up[x][i];
	if(x==y)	return x;
	for(int i=25;i>=0;i--) if(up[x][i]!=up[y][i])	x=up[x][i],y=up[y][i];
	return up[x][0];
}
inline int lowbit(int x)
{
	return x&(-x);
}
struct BIT
{
	int n;
	ll tree[N];
	inline void getn(int x)
	{
		n=x;
	}
	inline void update(int x,ll del)
	{
		for(int i=x;i<=n;i+=lowbit(i))	tree[i]+=del;
	}
	inline ll query(int x)
	{
		ll ans=0;
		for(int i=x;i;i-=lowbit(i))	ans+=tree[i];
		return ans;
	}
}bit1,bit2;	//区间加单点查  单点加区间查 	其实就是差分,普通 bit
int main()
{
    ll size=40<<20;//40M
    __asm__ ("movq %0,%%rsp\n"::"r"((char*)malloc(size)+size));//提交用这个 
	read(n),read(m);
	for(register int i=1;i<n;i++)
	{
		int x,y;
		read(x),read(y);
		addedge(x,y); addedge(y,x);
	}
	dfs(1,0);
	bit1.getn(n); bit2.getn(n);
	ll ans=0;
	//需要分两种情况讨论:其他的lca在这条路径上  自己的lca在其他路径上 
	for(int i=1,u,v,lca;i<=m;i++)
	{
		read(u); read(v); lca=getlca(u,v);
		ans=ans+bit1.query(st[u])+bit1.query(st[v])-2*bit1.query(st[lca]);
		ans=ans+bit2.query(ed[lca])-bit2.query(st[lca]-1);
		ans=ans+con[lca];
		con[lca]++;
		bit1.update(st[lca],1); bit1.update(ed[lca]+1,-1);
		bit2.update(st[u],1); bit2.update(st[v],1); bit2.update(st[lca],-2);
	}
	cout<<ans;
	exit(0);
}
posted @ 2018-11-05 20:30  Patrickpwq  阅读(132)  评论(0编辑  收藏  举报