2020ccpc威海站 C题 Rencontre(树形dp)

题意: 给出一棵树,有三个人,每个人都一个集合,表示他们可能会选的点,当三个人确定选的点a,b, c后,d=min(dis(a,v)+dis(b,v)+dis(c,v)),求d的期望值。
练了很多树形dp,只可惜比赛还是不会这道树形dp
题解:
画个图去试验各种情况,可以发现,当三个点固定时,d=1/2*(dis(a,b)+dis(a,c)+dis(b,c)),然后根据期望公式:
E(X+Y)=E(X)+E(Y)
E(aX)=aE(X)
可以得出:E(1/2
(dis(a,b)+dis(a,c)+dis(b,c)))=1/2*(E(dis(a,b))+E(dis(a,c))+E(dis(b,c)))。
对于dis(a,b),算出每条边的贡献,即这条边的贡献为子树中b的数量子树外a的数量+子树中a的数量 * 子树外b的数量
则E(dis(a,b))=贡献/(a的数量
b的数量)
代码:

#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<queue>
#include<map>
#include<stack>
#include<set>
#define iss ios::sync_with_stdio(false)
using namespace std;
typedef unsigned long long ull;
typedef long long ll;
typedef pair<int ,int > pii;
const ll mod=998244353;
const ll INF=0x3f3f3f3f;
const int MAXN=2e5+5;
struct node
{
    int to;
    int next;
    ll cost;
}e[MAXN<<1];
int cnt=0;
int head[MAXN];
int op[5];
map<int ,int >m1,m2,m3;
int dp[MAXN][4];
ll ans1=0,ans2=0,ans3=0;
void add(int u,int v,ll w)
{
    e[cnt].to=v;
    e[cnt].cost=w;
    e[cnt].next=head[u];
    head[u]=cnt++;
}
void dfs(int u,int f)
{
    if(m1.count(u)) dp[u][1]=1;
    if(m2.count(u)) dp[u][2]=1;
    if(m3.count(u)) dp[u][3]=1;
    for(int i=head[u];i!=-1;i=e[i].next)
    {
        int v=e[i].to;
        if(v==f) continue;
        dfs(v,u);
        dp[u][1]+=dp[v][1];
        dp[u][2]+=dp[v][2];
        dp[u][3]+=dp[v][3];
    }
}
void solve(int u,int f)
{
    for(int i=head[u];i!=-1;i=e[i].next)
    {
        int v=e[i].to;
        if(v==f) continue;
        ans1+=e[i].cost*dp[v][1]*(op[2]-dp[v][2])+e[i].cost*dp[v][2]*(op[1]-dp[v][1]);
        ans2+=e[i].cost*dp[v][2]*(op[3]-dp[v][3])+e[i].cost*dp[v][3]*(op[2]-dp[v][2]);
        ans3+=e[i].cost*dp[v][1]*(op[3]-dp[v][3])+e[i].cost*dp[v][3]*(op[1]-dp[v][1]);
        solve(v,u);
    }
}
int main()
{
    int n;
    scanf("%d",&n);
    memset(head,-1,sizeof head);
    int u,v;
    ll w;
    for(int i=1;i<=n-1;i++)
    {
        scanf("%d%d%lld",&u,&v,&w);
        add(u,v,w);
        add(v,u,w);
    }
    for(int i=1;i<=3;i++)
    {
        //cout<<i<<endl;
        scanf("%d",&op[i]);
        int x;
        for(int j=1;j<=op[i];j++)
        {
            scanf("%d",&x);
            if(i==1) m1[x]=1;
            if(i==2) m2[x]=1;
            if(i==3) m3[x]=1;
        }
    }
    dfs(1,-1);
    solve(1,-1);
    double d1=(double)ans1/((double)op[1]*(double)op[2]);
    double d2=(double)ans2/((double)op[2]*(double)op[3]);
    double d3=(double)ans3/((double)op[1]*(double)op[3]);
    double ans=(d1+d2+d3)/2;
    printf("%.9lf\n",ans);
}


posted @ 2020-10-28 21:35  TheBestQAQ  阅读(87)  评论(0)    收藏  举报