习题:Deblo(树形DP)
题目:
大约30年前,年轻的Krešo首次参加了全国信息学竞赛。与今天相似的,比赛的开幕都是由一系列演讲者组成,他们试图通过演讲激励参加者们并展现竞赛的重要性。观众们热情地每隔几秒钟鼓掌一次,但Krešo被其中一位发言者的一句话激怒了,因为这位发言者声称他更赞赏逻辑运算而非逻辑运算,因为无论获胜者是谁,Mirko和Slavko都会是这次竞赛的获胜者,而不是Mirko或Slavko。Krešo这时站起来,开始向大家解释一种名为“异或”的东西。在他的演讲结束后,他给尊敬的演讲者布置了这样一个任务来验证他的解释。
存在由n个节点组成的树,其中每个节点分配一个值,这个树上的路径值定义为这条路所有节点的值的异或。你的任务是确定树上所有路径的值的总和(这里的路径包括只有一个节点的路径)。
30年后,Krešo终于说服COCI的出题人将这个任务纳入其中一环,让我们恢复Krešo对编程竞赛未来的信心。
输入格式
第一行包含正整数n(1≤n≤100000),表示这棵树上的节点数。
第二行包含n个数字vi(0≤vi≤3000000)第i个数字表示第i个节点的价值。
接下来n-1行,每行输入两个数字aj和bj(1≤aj,bj≤n),表示在节点aj和bj之间有一条边。
输出格式
输出一个数,表示这棵树的价值。
样例
样例输入1
3
1 2 3
1 2
2 3
样例输出1
10
样例输入2
5
2 3 4 2 1
1 2
1 3
3 4
3 5
样例输出2
64
样例输入3
6
5 4 1 3 3 3
3 1
3 5
4 3
4 2
2 6
样例输出3
85
思路:
对于^操作,我们都知道,它就是转换成二进制后的不进位加法
之后我们就考虑对于每一个节点我们都可以知道他的答案
对于他的父节点,其实可以用它的儿子节点的值求出答案
也就是将树进行划分。
很明显,对于重心我们应该特殊考虑
代码:
#include<bits/stdc++.h>
using namespace std;
bool f[100005];
long long tot[100005],w[100005],cnt[2][25],n;
struct edge
{
long long pre,to;
}a[200005];
long long tot_edge,head[200005];
long long ans;
void adde(long long u,long long v)
{
a[++tot_edge].pre=head[u];
a[tot_edge].to=v;
head[u]=tot_edge;
}
long long solve_root(long long u,long long fat,long long all_son)
{
// cout<<"solve_root:"<<u<<" "<<fat<<" "<<S<<endl;
long long pos=0;
for (long long i=head[u]; i; i=a[i].pre)
{
long long v=a[i].to;
if (v==fat||f[v])
continue;
if (tot[v]>tot[pos]) pos=v;
}
if (pos!=0&&tot[pos]>=all_son/2)
return solve_root(pos,u,all_son);
return u;
}
void solve_distance(long long u,long long fat,long long num)
{
// cout<<"solve_dis:"<<u<<" "<<fat<<" "<<num<<endl;
num^=w[u];
tot[u]=1;
for(long long i=0;i<=22;i++)
cnt[(num & (1<<i))>0][i]++;
for (long long i=head[u]; i; i=a[i].pre)
{
long long v=a[i].to;
if (v==fat||f[v])
continue;
solve_distance(v,u,num);
tot[u]+=tot[v];
}
}
void solve_now(long long u,long long fat,long long num)
{
// cout<<"solve_now:"<<u<<" "<<fat<<" "<<num<<endl;
num^=w[u];
for(long long i=0;i<=22;i++)
ans+= (1<<i)*cnt[(num & (1<<i))==0][i];
for (long long i=head[u]; i; i=a[i].pre)
{
long long v=a[i].to;
if (v==fat||f[v])
continue;
solve_now(v,u,num);
}
}
void solve(long long u,long long fat)
{
// cout<<"solve_all:"<<u<<" "<<fat<<endl;
long long root=solve_root(u,fat,tot[u]);
f[root]=1;
memset(cnt,0,sizeof(cnt));
for(long long i=0;i<=22;i++)
cnt[(w[root]&(1<<i))>0][i]++;
ans+=w[root];
for (long long i=head[root]; i; i=a[i].pre)
{
long long v=a[i].to;
if (f[v])
continue;
solve_now(v,0,0);
solve_distance(v,0,w[root]);
}
for (long long i=head[root]; i; i=a[i].pre)
{
long long v=a[i].to;
if (f[v])
continue;
solve(v,u);
}
}
int main()
{
cin>>n;
for(int i=1;i<=n;i++)
cin>>w[i];
for(int i=2;i<=n;i++)
{
int u,v;
cin>>u;
cin>>v;
adde(u,v);
adde(v,u);
}
long long rt=solve_root(1,0,n);
solve_distance(rt,0,0);
solve(rt,0);
cout<<ans;
return 0;
}

浙公网安备 33010602011771号