CodeForces 766E Mahmoud and a xor trip

树形$dp$,位运算。

按位统计贡献。

每一位:统计以$u$为根的子树中,一端为$u$,另一端为子树中节点的链异或为$1$和$0$的方案数,$dp$一下就可以得到。还要统计$u$的子孙$v$跨过$u$进行组合的情况。

#pragma comment(linker, "/STACK:1024000000,1024000000")
#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#include<vector>
#include<map>
#include<set>
#include<queue>
#include<stack>
#include<ctime>
#include<iostream>
using namespace std;
typedef long long LL;
const double pi=acos(-1.0);
void File()
{
    freopen("D:\\in.txt","r",stdin);
    freopen("D:\\out.txt","w",stdout);
}
template <class T>
inline void read(T &x)
{
    char c = getchar();
    x = 0;
    while(!isdigit(c)) c = getchar();
    while(isdigit(c))
    {
        x = x * 10 + c - '0';
        c = getchar();
    }
}

int n,a[100010];
long long f[100010][2],p0[100010],p1[100010],ans;
int flag[100010];
vector<int>v[100010],t[100010];

void pre(int x)
{
    flag[x]=1;
    for(int i=0;i<t[x].size();i++)
    {
        if(flag[t[x][i]]==1) continue;
        v[x].push_back(t[x][i]);
        pre(t[x][i]);
    }
}

void dfs(int x,int y)
{
    int g =(a[x]&(1<<y))?1:0;

    for(int i=0;i<v[x].size();i++)
    {
        dfs(v[x][i],y);
        f[x][g^0]=f[x][g^0]+f[v[x][i]][0];
        f[x][g^1]=f[x][g^1]+f[v[x][i]][1];
    }

    f[x][g]++;

    ans=ans+f[x][1]*(LL)(1<<y);

    if(v[x].size()<=1) return ;

    p0[v[x].size()]=p1[v[x].size()]=0;

    for(int i=v[x].size()-1;i>=0;i--)
    {
        p0[i]=p0[i+1]+f[v[x][i]][0];
        p1[i]=p1[i+1]+f[v[x][i]][1];
    }

    for(int i=0;i<v[x].size()-1;i++)
    {
        ans=ans
                +f[v[x][i]][0]*p0[i+1]*(LL)(0^g)*(LL)(1<<y)
                +f[v[x][i]][0]*p1[i+1]*(LL)(1^g)*(LL)(1<<y)
                +f[v[x][i]][1]*p0[i+1]*(LL)(1^g)*(LL)(1<<y)
                +f[v[x][i]][1]*p1[i+1]*(LL)(0^g)*(LL)(1<<y);
    }

}

int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++) scanf("%d",&a[i]);
    for(int i=1;i<=n-1;i++)
    {
        int x,y; scanf("%d%d",&x,&y);
        t[x].push_back(y);
        t[y].push_back(x);
    }

    pre(1);

    for(int i=0;i<23;i++)
    {
        memset(flag,0,sizeof flag);
        memset(f,0,sizeof f);
        dfs(1,i);
    }

    printf("%lld\n",ans);
    return 0;
}

 

posted @ 2017-02-20 20:58  Fighting_Heart  阅读(187)  评论(0编辑  收藏  举报