习题:Deblo(树形DP)

题目:

大约30年前,年轻的Krešo首次参加了全国信息学竞赛。与今天相似的,比赛的开幕都是由一系列演讲者组成,他们试图通过演讲激励参加者们并展现竞赛的重要性。观众们热情地每隔几秒钟鼓掌一次,但Krešo被其中一位发言者的一句话激怒了,因为这位发言者声称他更赞赏逻辑运算而非逻辑运算,因为无论获胜者是谁,Mirko和Slavko都会是这次竞赛的获胜者,而不是Mirko或Slavko。Krešo这时站起来,开始向大家解释一种名为“异或”的东西。在他的演讲结束后,他给尊敬的演讲者布置了这样一个任务来验证他的解释。

存在由n个节点组成的树,其中每个节点分配一个值,这个树上的路径值定义为这条路所有节点的值的异或。你的任务是确定树上所有路径的值的总和(这里的路径包括只有一个节点的路径)。

30年后,Krešo终于说服COCI的出题人将这个任务纳入其中一环,让我们恢复Krešo对编程竞赛未来的信心。

输入格式

第一行包含正整数n(1≤n≤100000),表示这棵树上的节点数。

第二行包含n个数字vi(0≤vi≤3000000)第i个数字表示第i个节点的价值。

接下来n-1行,每行输入两个数字aj和bj(1≤aj,bj≤n),表示在节点aj和bj之间有一条边。

输出格式

输出一个数,表示这棵树的价值。

样例

样例输入1

3
1 2 3
1 2
2 3

样例输出1

10

样例输入2

5
2 3 4 2 1
1 2
1 3
3 4
3 5

样例输出2

64

样例输入3

6
5 4 1 3 3 3
3 1
3 5
4 3
4 2
2 6

样例输出3

85

思路:
对于^操作,我们都知道,它就是转换成二进制后的不进位加法

之后我们就考虑对于每一个节点我们都可以知道他的答案

对于他的父节点,其实可以用它的儿子节点的值求出答案

也就是将树进行划分。

很明显,对于重心我们应该特殊考虑

代码:

#include<bits/stdc++.h>
using namespace std;
bool f[100005];
long long tot[100005],w[100005],cnt[2][25],n;
struct edge
{
	long long pre,to;
}a[200005];
long long tot_edge,head[200005];
long long ans;
void adde(long long u,long long v)
{
	a[++tot_edge].pre=head[u];
	a[tot_edge].to=v;
	head[u]=tot_edge;
}
long long solve_root(long long u,long long fat,long long all_son)
{
	// cout<<"solve_root:"<<u<<" "<<fat<<" "<<S<<endl;
	long long pos=0;
	for (long long i=head[u]; i; i=a[i].pre)
	{
		long long v=a[i].to;
		if (v==fat||f[v]) 
			continue;
		if (tot[v]>tot[pos]) pos=v;
	}
	if (pos!=0&&tot[pos]>=all_son/2) 
		return solve_root(pos,u,all_son);
	return u;
}
void solve_distance(long long u,long long fat,long long num)
{
	// cout<<"solve_dis:"<<u<<" "<<fat<<" "<<num<<endl;
	num^=w[u];
	tot[u]=1;
	for(long long i=0;i<=22;i++)
		cnt[(num & (1<<i))>0][i]++;

	for (long long i=head[u]; i; i=a[i].pre)
	{
		long long v=a[i].to;
		if (v==fat||f[v]) 
			continue;
		solve_distance(v,u,num);
		tot[u]+=tot[v];
	}
}
void solve_now(long long u,long long fat,long long num)
{
	// cout<<"solve_now:"<<u<<" "<<fat<<" "<<num<<endl;
	num^=w[u];
	for(long long i=0;i<=22;i++)
		ans+= (1<<i)*cnt[(num & (1<<i))==0][i];
	for (long long i=head[u]; i; i=a[i].pre)
	{
		long long v=a[i].to;
		if (v==fat||f[v])
		 continue;
		solve_now(v,u,num);
	}
}
void solve(long long u,long long fat)
{
	// cout<<"solve_all:"<<u<<" "<<fat<<endl;
	long long root=solve_root(u,fat,tot[u]);
	f[root]=1;
	memset(cnt,0,sizeof(cnt));
	for(long long i=0;i<=22;i++)
		cnt[(w[root]&(1<<i))>0][i]++;
	ans+=w[root];
	for (long long i=head[root]; i; i=a[i].pre)
	{
		long long v=a[i].to;
		if (f[v]) 
			continue;
		solve_now(v,0,0);
		solve_distance(v,0,w[root]);
	}
	for (long long i=head[root]; i; i=a[i].pre)
	{
		long long v=a[i].to;
		if (f[v]) 
			continue;
		solve(v,u);
	}
}
int main()
{
	cin>>n;
	for(int i=1;i<=n;i++)
		cin>>w[i];
	for(int i=2;i<=n;i++)
	{
		int u,v;
		cin>>u;
		cin>>v;
		adde(u,v);
		adde(v,u);
	}
	
	long long rt=solve_root(1,0,n);
	solve_distance(rt,0,0);
	solve(rt,0);
	cout<<ans;
	return 0;
 }

 

posted @ 2019-08-24 19:30  loney_s  阅读(137)  评论(0)    收藏  举报