【牛客7872 J】树上启发式合并

【牛客7872 J】树上启发式合并

题意

树上启发式合并,求有多少点对满足,这两个点x和y相互之间不是祖先和后代的关系
同时满足\(val[x]+val[y]=2 * val[ lca(x,y) ]\)

题解

根据两个点不能互为祖先的要求可知:

比较可行的方式是枚举这个作为lca的结点,对于一个作为lca的结点
什么样的结点会以它为lca呢,当然是以它的不同的儿子为根结点的子树中的结点
因此,统计答案的方式也比较巧妙,对于一个作为lca的结点u

  • 首先遍历它的第一个儿子v1的那棵子树,用一个mp数组记录当前已经遍历过的结点中每个数出现的次数
    遍历第1个儿子那棵子树时把mp维护好。
  • 然后从第2个儿子开始,先对每一个结点v,获取到当前mp[2*val[u]-val[v]]的大小
    这表示能和结点v一起组成符合条件的点对有多少。
  • 这样查询完第2个儿子上所有节点后,再把第2个儿子子树上的所有结点的mp值维护好,依次循环这样一个过程

由于在做这个过程的时候必须保证mp值的准确,所以每次一个lca判断完后要清空该棵子树对mp值造成的影响。
那么考虑什么样的结点不用清空呢,那就是该结点作为父亲结点的最后一个儿子维护答案时不用清空。
那么我们怎样能使时间复杂度尽可能降低呢?那就是把所有儿子中最重的(子树大小最大的儿子)放在最后一个访问,这样就可以节省下清空它的时间复杂度,这就是启发式合并,运用最后一个儿子不需要清空的性质来降低时间复杂度。

Code

/****************************
* Author : W.A.R            *
* Date : 2020-10-31-20:44   *
****************************/
/*
*/
#include<stdio.h>
#include<string.h>
#include<math.h>
#include<algorithm>
#include<queue>
#include<map>
#include<unordered_map>
#include<stack>
#include<string>
#include<set>
#define mem(a,x) memset(a,x,sizeof(a))
using namespace std;
typedef long long ll;
const int maxn=1e6+10;
const ll mod=1e9+7;

namespace Fast_IO{
    const int MAXL((1 << 18) + 1);int iof, iotp;
    char ioif[MAXL], *ioiS, *ioiT, ioof[MAXL],*iooS=ioof,*iooT=ioof+MAXL-1,ioc,iost[55];
    char Getchar(){
        if (ioiS == ioiT){
            ioiS=ioif;ioiT=ioiS+fread(ioif,1,MAXL,stdin);return (ioiS == ioiT ? EOF : *ioiS++);
        }else return (*ioiS++);
    }
    void Write(){fwrite(ioof,1,iooS-ioof,stdout);iooS=ioof;}
    void Putchar(char x){*iooS++ = x;if (iooS == iooT)Write();}
    inline int read(){
        int x=0;for(iof=1,ioc=Getchar();(ioc<'0'||ioc>'9')&&ioc!=EOF;)iof=ioc=='-'?-1:1,ioc=Getchar();
		if(ioc==EOF)exit(0);
        for(x=0;ioc<='9'&&ioc>='0';ioc=Getchar())x=(x<<3)+(x<<1)+(ioc^48);return x*iof;
    }
    inline long long read_ll(){
        long long x=0;for(iof=1,ioc=Getchar();(ioc<'0'||ioc>'9')&&ioc!=EOF;)iof=ioc=='-'?-1:1,ioc=Getchar();
		if(ioc==EOF)exit(0);
        for(x=0;ioc<='9'&&ioc>='0';ioc=Getchar())x=(x<<3)+(x<<1)+(ioc^48);return x*iof;
    }
    template <class Int>void Print(Int x, char ch = '\0'){
        if(!x)Putchar('0');if(x<0)Putchar('-'),x=-x;while(x)iost[++iotp]=x%10+'0',x/=10;
        while(iotp)Putchar(iost[iotp--]);if (ch)Putchar(ch);
    }
    void Getstr(char *s, int &l){
        for(ioc=Getchar();ioc==' '||ioc=='\n'||ioc=='\t';)ioc=Getchar();
		if(ioc==EOF)exit(0);
        for(l=0;!(ioc==' '||ioc=='\n'||ioc=='\t'||ioc==EOF);ioc=Getchar())s[l++]=ioc;s[l] = 0;
    }
    void Putstr(const char *s){for(int i=0,n=strlen(s);i<n;++i)Putchar(s[i]);}
}
using namespace Fast_IO;
struct node{int to,nxt;}e[maxn];
int son[maxn],siz[maxn],cnt[maxn],head[maxn],val[maxn],ct;
ll ans;
unordered_map<int,int>mp;
void addE(int u,int v){e[++ct].to=v;e[ct].nxt=head[u];head[u]=ct;}
void dfs(int u,int fa){
	siz[u]=1;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;if(v==fa)continue;
		dfs(v,u);siz[u]+=siz[v];
		if(siz[v]>siz[son[u]])son[u]=v;
	}
}
void add(int u,int fa,int value){
	mp[val[u]]+=value;
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa)continue;
		add(v,u,value);
	}
}
void calc(int u,int fa,int lca){
	ans+=mp[2*val[lca]-val[u]];
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa)continue;
		calc(v,u,lca);
	}
}
void getAns(int u,int fa,bool heavy){
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa||v==son[u])continue;
		getAns(v,u,0);
	}
	if(son[u])getAns(son[u],u,1);
	for(int i=head[u];i;i=e[i].nxt){
		int v=e[i].to;
		if(v==fa||v==son[u])continue;
		calc(v,u,u);
		add(v,u,1);
	}
	mp[val[u]]++;
	if(!heavy)add(u,fa,-1);
}
int main(){
	int n=read();
	for(int i=1;i<=n;i++)val[i]=read();
	for(int i=1;i<n;i++){int u=read(),v=read();addE(u,v);addE(v,u);}
	dfs(1,0);
	getAns(1,0,0);
	printf("%lld\n",ans<<1);
	return 0;
}

posted @ 2020-10-31 21:50  AnranWu  阅读(125)  评论(0编辑  收藏  举报