仓鼠找sugar II

因为起点终点是不确定的,因此我们可以考虑枚举起点和终点来计算贡献,但由于期望逆推的原理,可以发现当终点相同时所有起点的期望计算方式都是类似的,于是我们可以考虑枚举每个终点。那么对于每个终点 \(x\),我们令 \(dp_i\) 为以 \(i\) 为起点到这个终点 \(x\) 的期望步数,于是可以得到转移:

\[dp_u = \dfrac{\sum\limits dp_v}{deg_u} + 1 \]

可以发现这个 \(dp\) 的转移并不是一个 \(\rm DAG\) 的形式,我们大可以使用高斯消元来解出每个 \(dp_i\) 但这样显然是不能过题的。考虑到每个点的转移方向只有父亲或者儿子两种,可以使用 分手是祝愿 中的方法将这个期望设置成转移向一边的期望,即我们重新定义 \(dp_i\)\(i\) 走到其父亲的期望步数。于是有转移(\(deg\)\(d\) 简写):

\[\begin{aligned} dp_u &= \frac{1}{d} \times 1 + \frac{1}{d} \sum (dp_v + dp_u + 1)\\ &= 1 + \frac{d - 1}{d} \times dp_u + \frac{1}{d} \sum dp_v \end{aligned} \]

移项可得:

\[dp_u = d + \sum dp_v \]

但可能统计答案的时候终点和这个点在根的不同的子树中,这样显然是不能使用这个 \(dp_i\) 来简单相加的,于是我们需要枚举每个终点之后,以这个终点作为根重新构成一颗新的树。那么每个点的答案就是这个点到根路径上的 \(dp\) 值。

进一步地我们可以发现上面的枚举每个终点本质上是对这棵树进行换根操作,并且可以发现当我们的根从 \(u \rightarrow v\)\(v\)\(u\) 的儿子)时,发生改变的 \(dp\) 值只有 \(u, v\) 两个点,令 \(f_i\) 为当前树下在 \(fa_i\) 的儿子中除了 \(i\)\(dp\) 值以外的 \(dp\) 值之和,每次换根 \(u \rightarrow v, dp\) 值的改变实际上是:

\[dp_u = f_v + deg_u, dp_v = 0 \]

于是我们能很轻松的维护出换根时 \(dp\) 值的变化,再来考虑如何计算答案,显然我们不能再重新算一遍每个点到根路径上的 \(dp\) 值,因为我们改变的 \(dp\) 值只有两个,我们可以考虑每个 \(dp\) 值对答案的贡献,显然是 \(dp_i \times s_i\)\(s_i\) 为以 \(i\) 为根的子树大小),于是我们只需在沿途维护出 \(s\) 的变化即可维护出每次答案的变化,将每次算出来的答案直接相加即可。因为我们算的都是给定起点终点的期望,因此计算最终的期望时还用除以总方案数即 \(n ^ 2\)

#include<bits/stdc++.h>
using namespace std;
#define N 100000 + 5
#define Mod 998244353
#define rep(i, l, r) for(int i = l; i <= r; ++i)
#define Next(i, u) for(int i = h[u]; i; i = e[i].next)
struct edge{
    int v, next;
}e[N << 1];
int n, u, v, tot, ans, tmp, h[N], s[N], dp[N], sum[N], deg[N];
int read(){
    char c; int x = 0, f = 1;
    c = getchar();
    while(c > '9' || c < '0'){ if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}
void add(int u, int v){
    e[++tot].v = v, e[tot].next = h[u], h[u] = tot;
    e[++tot].v = u, e[tot].next = h[v], h[v] = tot;
}
int Inc(int a, int b){
    return (a += b) >= Mod ? a - Mod : a;
}
int Dec(int a, int b){
    return (a -= b) < 0 ? a + Mod : a;
}
int Mul(int a, int b){
    return 1ll * a * b % Mod;
}
int Qpow(int a, int b){
    int ans = 1;
    while(b){
        if(b & 1) ans = Mul(ans, a);
        a = Mul(a, a), b >>= 1;
    }
    return ans;
}
void dfs1(int u, int fa){
    s[u] = 1;
    Next(i, u){
        int v = e[i].v; if(v == fa) continue;
        dfs1(v, u), s[u] += s[v]; dp[u] = Inc(dp[u], dp[v]);
    }
    sum[u] = dp[u], dp[u] = Inc(dp[u], deg[u]); if(!fa) dp[u] = 0;
}
void dfs2(int u, int fa){
    sum[u] = Inc(sum[u], dp[fa]);
    Next(i, u){
        int v = e[i].v; if(v == fa) continue;
        int Dp1 = dp[u], Dp2 = dp[v], T = tmp;
        dp[u] = Inc(Dec(sum[u], dp[v]), deg[u]), dp[v] = 0;
        tmp = Inc(tmp, Dec(Mul(dp[u], n - s[v]), Mul(Dp2, s[v]))), ans = Inc(ans, tmp);
        dfs2(v, u);
        dp[u] = Dp1, dp[v] = Dp2, tmp = T;
    }
    sum[u] = Dec(sum[u], dp[fa]);
}
int main(){
    n = read();
    rep(i, 1, n - 1) u = read(), v = read(), add(u, v), ++deg[u], ++deg[v];
    dfs1(1, 0);
    rep(i, 1, n) tmp = Inc(tmp, Mul(dp[i], s[i])), ans = tmp; 
    dfs2(1, 0);
    printf("%d", Mul(ans, Qpow(Mul(n, n), Mod - 2)));
    return 0;
}
posted @ 2020-08-12 22:47  Achtoria  阅读(72)  评论(0编辑  收藏  举报