解题报告-小 A 的树

小 A 的树

题目描述

小 A 有一棵 \(N\) 个点的树,每个点都有一个小于 \(2^{20}\) 的非负整数权值。现在小 A 从树中随机选择一个点 \(x\),再随机选择一个点 \(y\)\(x\)\(y\)可以是同一个点),并对从 \(x\)\(y\) 的路径上 所有的点的权值分别做按位与、按位或、异或运算,最终会求得三个整数。小 A 想知道,他求出的三个数的期望值分别是多少。

输入描述

输入包含多组测试数据。

第一行,一个整数 \(T\),表示测试数据的组数。

接下来 \(T\) 节,每节表示一组测试数据,格式如下:

  • 第一行,一个整数 \(N\)
  • 第二行,\(N\) 个整数,其中第 \(i\) 个整数表示第 \(i\) 个点的权值。
  • 接下来 \(N-1\) 行,每行两个整数 \(u\)\(v\),表示树中有一条连接 \(u\)\(v\) 的边。

输出描述

\(T\) 行,每行三个浮点数,保留三位小数,其中第 \(i\) 行的三个浮点数表示第 \(i\) 组数据对应的按位与、按位或、异或的期望。

输入输出描述 #1

输入样例 #1

1
4
1 2 3 4
1 2
2 3
2 4

输出样例 #1

0.875 4.250 3.375

提示/说明

数据范围

  • 对于 \(20\)% 的数据,\(1 \leq N \leq 10^3\)
  • 另外 \(20\)% 的数据,\(N\) 个点构成一条链。
  • 对于 \(100\)% 的数据,\(1 \leq N \leq 10^5\)\(1\leq T\leq 5\)

对于不同路径的判断

设有两个树上的点 \(u\)\(v\)

  • \(u=v\),则路径 \(u \rightarrow v\) 和路径 \(v \rightarrow u\) 是相同路径。
  • \(u \neq v\),则路径 \(u \rightarrow v\) 和路径 \(v \rightarrow u\) 是不同路径。

解题报告

神秘树形 DP。

首先要有一个思路:按位与、按位或、异或三个运算都是各进制位独立运算,可以考虑分开进行

所以,我们可以求出三个运算使每个二进制位为 \(1\) 的期望,最后统计总期望。

由于总路径数量一定,为 \(N \times N\),所以我们对于每个运算只需求出可以使每个二进制位为 \(1\) 的路径数。

一个很常见的思路:对于每个在以 \(u\) 为根的子树,统计以 \(u\) 为一个端点的路径的价值,再转换成每条在以 \(u\) 为根的子树的路径的价值。其实就是把 \(u\) 作为 LCA 的路径 \(s \rightarrow t\),把这条路径分成 \(s \rightarrow u\)\(u \rightarrow t\) 两天路径,单独处理出每个 \(s \rightarrow u\)\(u \rightarrow t\) 的路径的价值,既可以计算出每个路径的价值。

然后就是一个简单的树形 DP。

\(dp[u][i][0/1/2]\) 分别表示以节点 \(u\) 的子树中以 \(u\) 为端点的路径中按位与、按位或、异或后可以使第 \(i\) 个二进制位为 \(1\) 的路径数。

\(val[u][i]\) 表示节点 \(u\) 的权值第 \(i\) 个二进制位状态。

\(siz[u]\) 表示子树 \(u\) 的大小,同时也等价于子树内以 \(u\) 为端点的路径的条数。

\(v\)\(u\) 的一个子节点。

转移方程很好推,代码如下:

for(int i=1;i<M;j++)
{
    if(val[u][i])
    {
         dp[u][i][0]+=dp[v][i][0];
         dp[u][i][1]+=siz[v];
         dp[u][i][2]+=siz[v]-dp[v][i][2];
    }
    else
    {
         dp[u][i][1]+=dp[v][j][1];
         dp[u][j][2]+=dp[v][i][2];
    }
}

然后就可以对子树 \(u\) 分别统计三个运算中使第 \(i\) 个二进制位为 \(1\) 的路径总数 \(cnt[0/1/2][i]\)

// 统计u->u的路径
for(int j=1;j<M;j++)
{
    cnt[0][j]+=val[u][j];
    cnt[1][j]+=val[u][j];
    cnt[2][j]+=val[u][j];
}

// 统计子树内经过 u 且起始点不同的路径
for(auto v:e[u])
{
    if(v==fa) continue;
    for(int j=1;j<M;j++)
    {
        cnt[0][j]+=2*dp[u][i][0]*dp[v][i][0];
        cnt[1][j]+=2*(siz[u]*siz[v]-(siz[u]-dp[u][j][1])*(siz[v]-dp[v][j][1]));
        cnt[2][j]+=2*(dp[u][j][2]*(siz[v]-dp[v][j][2])+dp[v][j][2]*(siz[u]-dp[u][j][2]));
    }
}

然后统计答案就好了,总代码如下:

#include<bits/stdc++.h>
#define int long long
using namespace std;
const int INF=0x3f3f3f3f;
const int N=1001100;
const int M=32;

#define ckmax(x,y) ( x=max(x,y) )
#define ckmin(x,y) ( x=min(x,y) )

inline int read()
{
	int f=1,x=0; char ch=getchar();
	while(!isdigit(ch)) { if(ch=='-') f=-1; ch=getchar(); }
	while(isdigit(ch))  { x=x*10+ch-'0';    ch=getchar(); }
	return f*x;
}

struct node
{
    int siz;
    int f[3][M];
    bool val[M];
}p[N];
int n;
int cnt[3][M];
vector<int> e[N];

inline void addedge(int u,int v)
{
    e[u].push_back(v);
    e[v].push_back(u);
}

inline void Clear()
{
    memset(cnt,0,sizeof(cnt));
    for(int i=1;i<=n;i++)
    {
        memset(p[i].val,0,sizeof(p[i].val));
        memset(p[i].f,0,sizeof(p[i].f));
        p[i].siz=0;
        e[i].clear();
    }
}

inline void debug(int u)
{
    printf("%d\n",u);
    for(int i=1;i<M;i++)
      printf("%d ",p[u].val[i]);
    cout<<endl;
    for(int j= 0;j<3;j++,putchar('\n'))
     for(int i=1;i<M;i++,putchar(' '))
       cout<<p[u].f[j][i];
    cout<<endl<<endl;
}

void dfs(int u,int fa)
{
    p[u].siz=1;
    for(int i=0;i<e[u].size();i++)
    {
        int v=e[u][i];
        if(v==fa) continue;
        dfs(v,u);
        for(int j=1;j<M;j++)
        {
            cnt[0][j]+=2*p[u].f[0][j]*p[v].f[0][j];
            cnt[1][j]+=2*(p[u].siz*p[v].siz-(p[u].siz-p[u].f[1][j])*(p[v].siz-p[v].f[1][j]));
            cnt[2][j]+=2*(p[u].f[2][j]*(p[v].siz-p[v].f[2][j])+p[v].f[2][j]*(p[u].siz-p[u].f[2][j]));
        }
        p[u].siz+=p[v].siz;
        for(int j=1;j<M;j++)
        {
            if(p[u].val[j])
            {
                p[u].f[0][j]+=p[v].f[0][j];
                p[u].f[1][j]+=p[v].siz;
                p[u].f[2][j]+=p[v].siz-p[v].f[2][j];
            }
            else
            {
                p[u].f[1][j]+=p[v].f[1][j];
                p[u].f[2][j]+=p[v].f[2][j];
            }
        }

    }
    // debug(u);
}

signed main()
{
	freopen("tree.in","r",stdin);
	freopen("tree.out","w",stdout);
    int Q=read();
    while(Q--)
    {
        n=read();
        for(int i=1;i<=n;i++)
        {
            int x=read();
            for(int j=1;j<M;j++)
            {
                p[i].val[j]=(x>>j-1)&1;
                p[i].f[0][j]=p[i].val[j];
                p[i].f[1][j]=p[i].val[j];
                p[i].f[2][j]=p[i].val[j];
                cnt[0][j]+=p[i].val[j];
                cnt[1][j]+=p[i].val[j];
                cnt[2][j]+=p[i].val[j];
            }
        }
        for(int i=1;i<n;i++)
          addedge(read(),read());
        dfs(1,0);
        int tot=n*n;
        double ans0=0,ans1=0,ans2=0;
        for(int i=1;i<M;i++)
        {
            int tmp=(1<<i-1);
            ans0+=(double)tmp*cnt[0][i]/(double)tot;
            ans1+=(double)tmp*cnt[1][i]/(double)tot;
            ans2+=(double)tmp*cnt[2][i]/(double)tot;
        }
        printf("%.3lf %.3lf %.3lf\n",ans0,ans1,ans2);
        Clear();
    }
	return 0;
}
posted @ 2025-10-03 09:58  南北天球  阅读(9)  评论(0)    收藏  举报