P3412 仓鼠找sugar II
这种随机游走题用 \(DP\) 求的一般思路就是固定一个方向,然后考虑走到那个方向或从那个方向走过来的期望。
设 \(f_i\) 表示 \(i\) 走向自己的父亲的期望步数, \(g_i\) 表示 \(i\) 的父亲走向自己的期望步数。
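列一下方程(设 \(deg_u\) 表示 \(u\) 的度数):从 \(u\) 出发走一步,有 \(\frac{1}{deg_u}\) 的概率直接走到父亲;否则走到某个儿子 \(v\),此时要先花 \(f_v\) 步走回 \(u\),再期望花 \(f_u\) 步走到父亲。每种情况都先花了 \(1\) 步,因此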
\[f_u = 1 + \Sigma_{fa_v = u} (f_v+f_u) \times \frac{1}{deg_u}
\]
解一下方程
\[f_u = deg_u + \Sigma_{fa_v=u} f_v
\]
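同理对 \(g_u\) 列方程:从 \(fa_u\) 出发走一步,以 \(\frac{1}{deg_{fa_u}}\) 的概率走到它的每个邻居。走到 \(u\) 就结束;走到 \(fa_u\) 的父亲,要先花 \(g_{fa_u}\) 步走回 \(fa_u\),再花 \(g_u\) 步;走到 \(u\) 的某个兄弟 \(v\),要先花 \(f_v\) 步走回 \(fa_u\),再花 \(g_u\) 步。(若 \(fa_u\) 是根,则没有走向父亲这一项,可以验证最终结果等价于令 \(g_{root} = 0\)。)
\[g_u = 1 + \frac{1}{deg_{fa_u}} \times \left((g_{fa_u} + g_u) + \Sigma_{fa_v = fa_u, v \ne u} (f_v + g_u)\right)
\]
利用 \(f_{fa_u} = deg_{fa_u} + \Sigma_{fa_v = fa_u} f_v\) 解得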
\[g_u = f_{fa_u} - f_u + g_{fa_u}
\]
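最后统计每个 \(f_u\) 和 \(g_u\) 被用到的次数。起点 \(s\) 到终点 \(t\) 的树上路径向上经过边 \((u, fa_u)\),当且仅当 \(s\) 在 \(u\) 的子树内、\(t\) 在子树外,共 \(sz_u \times (n - sz_u)\) 个(起点, 终点)对;向下经过该边的对数同样是 \(sz_u \times (n - sz_u)\)。对全部 \(n^2\) 个(起点, 终点)对取平均,答案为 \(\frac{1}{n^2}\Sigma_{u} (f_u + g_u) \times sz_u \times (n - sz_u)\)(根满足 \(n - sz_u = 0\),不产生贡献)。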
#include<bits/stdc++.h>
// Historically expanded to `register`. That storage class has no effect in
// modern compilers and is ill-formed since C++17, so RG now expands to nothing.
// The macro is kept so existing `RG int ...` declarations still compile.
#define RG
#define LL long long
// Inclusive loops:
//   U(i, a, b) runs i = a, a+1, ..., b
//   D(i, a, b) runs i = a, a-1, ..., b
// The bounds are parenthesised so compound expressions (e.g. ternaries) bind
// as a whole instead of being split by operator precedence.
#define U(x, y, z) for(RG int x = (y); x <= (z); ++x)
#define D(x, y, z) for(RG int x = (y); x >= (z); --x)
// x = (x + y) mod `mod`, assuming both are already in [0, mod) so a single
// conditional subtraction suffices. Arguments are parenthesised against
// precedence surprises. Note that x and y are evaluated more than once.
#define update(x, y) ((x) = (x) + (y) >= mod ? (x) + (y) - mod : (x) + (y))
using namespace std;
const int mod = 998244353;
namespace FastIO {
#define il inline
// Capacity of the input and output buffers: 1 << 25 bytes (32 MiB) each.
const int iL = 1 << 25;
// Input buffer. Bytes in [iS, iT) have been read from stdin but not consumed
// yet. Both pointers start at the end, so the first gc() call refills the buffer.
char ibuf[iL], *iS = ibuf + iL, *iT = ibuf + iL;
// Return the next input byte as a value in [0, 255], or EOF once stdin is
// exhausted. Returning the byte as unsigned keeps it valid for <cctype>
// functions (passing a negative char is undefined) and distinct from EOF.
il int gc() {
    if (iS == iT) {
        iT = (iS = ibuf) + std::fread(ibuf, 1, iL, stdin);
        if (iS == iT) return EOF;
    }
    return static_cast<unsigned char>(*iS++);
}
// Macro spelling kept for existing callers. It is fully parenthesised, so it is
// safe inside any expression.
#define GC() (FastIO::gc())
// Base case of the variadic read().
void read(){}
// Read integers into x, Ar... in order.
// - Non-digit characters before a number are skipped.
// - A '-' among those skipped characters makes the number negative.
// - If stdin ends before a digit appears, the variable is left at 0 instead of
//   looping forever.
template<typename _Tp, typename... _Tps>
void read(_Tp &x, _Tps &...Ar) {
    x = 0;
    int ch = GC();
    bool flg = false;
    for (; ch != EOF && !std::isdigit(ch); ch = GC()) flg |= (ch == '-');
    for (; ch != EOF && std::isdigit(ch); ch = GC()) x = x * 10 + (ch - '0');
    if (flg) x = -x;
    read(Ar...);
}
// Output buffer. Every write() flushes it immediately.
char Out[iL], *iter = Out;
// Write the buffered bytes to stdout and reset the buffer. Wrapped in
// do/while(0) so `Flush();` behaves as a single statement (e.g. after `if`).
#define Flush() do { fwrite(Out, 1, iter - Out, stdout); iter = Out; } while (0)
// Print x in decimal followed by LastChar, then flush.
// Each digit is taken from x % 10 with its sign stripped, so the most negative
// value of T prints correctly. Negating x first (`x = -x`) would overflow there.
template <class T>il void write(T x, char LastChar = '\n') {
    int c[35], len = 0;
    const bool neg = x < 0;
    if (neg) *iter++ = '-';
    do {
        const int d = static_cast<int>(x % 10);
        c[++len] = neg ? -d : d;
        x /= 10;
    } while (x);
    while (len) *iter++ = c[len--] + '0';
    *iter++ = LastChar;
    Flush();
}
template <typename T> inline void writeln(T n){write(n, '\n');}
template <typename T> inline void writesp(T n){write(n, ' ');}
// Skip to the next alphabetic character and return it. If the input ends
// first, returns EOF converted to char.
inline char Getchar(){
    int ch = GC();
    while (ch != EOF && !std::isalpha(ch)) ch = GC();
    return static_cast<char>(ch);
}
// Read the next whitespace-delimited token into s; s is empty if the input is
// exhausted. The one-character lookahead lives in a function-local static, so
// it survives between calls.
inline void readstr(std::string &s) {
    s.clear();
    static int c = GC();
    while (c != EOF && std::isspace(c)) c = GC();
    while (c != EOF && !std::isspace(c)) {
        s += static_cast<char>(c);
        c = GC();
    }
}
}
using namespace FastIO;
// Residue modulo the prime `mod`. The stored value `x` is always kept in the
// canonical range [0, mod). Every plain int entering the type is reduced first
// (constructor, int assignment, int-operand operators, comparisons), so
// negative or out-of-range ints behave like their residues.
struct modint{
    int x;
    // Map any int to its residue in [0, mod).
    static int normalize(int o){ o %= mod; return o < 0 ? o + mod : o; }
    modint(int o=0){x=normalize(o);}
    modint &operator = (int o){return x=normalize(o),*this;}
    // Both operands are in [0, mod), so one conditional correction suffices,
    // and x + o.x < 2 * mod still fits in an int.
    modint &operator +=(modint o){return x=x+o.x>=mod?x+o.x-mod:x+o.x,*this;}
    modint &operator -=(modint o){return x=x-o.x<0?x-o.x+mod:x-o.x,*this;}
    modint &operator *=(modint o){return x=1ll*x*o.x%mod,*this;}
    // In-place power by binary exponentiation.
    // - A negative exponent sets the value to 0.
    // - The exponent is reduced modulo mod-1 (Fermat's little theorem). That
    //   only holds for a non-zero base, so 0^b with b > 0 is answered directly.
    modint &operator ^=(int b){
        if(b<0)return x=0,*this;
        if(b==0)return x=1,*this;
        if(x==0)return *this;
        b%=mod-1;
        modint a=*this,c=1;
        for(;b;b>>=1,a*=a)if(b&1)c*=a;
        return x=c.x,*this;
    }
    // Division by multiplying with o^(mod-2), the inverse of o. Dividing by 0 gives 0.
    modint &operator /=(modint o){return *this *=o^=mod-2;}
    // int-operand overloads: reduce the int first, then reuse the modint versions.
    modint &operator +=(int o){return *this+=modint(o);}
    modint &operator -=(int o){return *this-=modint(o);}
    modint &operator *=(int o){return *this*=modint(o);}
    modint &operator /=(int o){return *this/=modint(o);}
    template<class I>friend modint operator +(modint a,I b){return a+=b;}
    template<class I>friend modint operator -(modint a,I b){return a-=b;}
    template<class I>friend modint operator *(modint a,I b){return a*=b;}
    template<class I>friend modint operator /(modint a,I b){return a/=b;}
    friend modint operator ^(modint a,int b){return a^=b;}
    // Compare against the residue of b, consistent with how ints are stored.
    friend bool operator ==(modint a,int b){return a.x==normalize(b);}
    friend bool operator !=(modint a,int b){return a.x!=normalize(b);}
    bool operator ! () const {return !x;}
    modint operator - () const {return x?mod-x:0;}
};
// Smaller of x and y. When x < y is false (a tie or unordered values), y is returned.
template <typename T> inline T Min(T x, T y) {
    if (x < y) return x;
    return y;
}
// Larger of x and y. When x > y is false (a tie or unordered values), y is returned.
template <typename T> inline T Max(T x, T y) {
    if (x > y) return x;
    return y;
}
// Overwrite x with Min(x, y).
template <typename T> inline void chkmin(T &x, T y) { x = Min(x, y); }
// Overwrite x with Max(x, y).
template <typename T> inline void chkmax(T &x, T y) { x = Max(x, y); }
// Local-testing helper: redirect stdin from "<s>.in" and stdout to "<s>.out".
// The results of freopen are not checked.
inline void FO(string s) {
    const string in_name = s + ".in";
    const string out_name = s + ".out";
    freopen(in_name.c_str(), "r", stdin);
    freopen(out_name.c_str(), "w", stdout);
}
// Array capacity. Node ids 1..n are used directly as indices, so n must be < N
// (presumably n <= 1e5 per the problem limits; confirm).
const int N = 100010;
// f[u]: expected steps for a walk starting at u to first reach u's parent.
// g[u]: expected steps for a walk starting at u's parent to first reach u.
// ans:  running sum over nodes of (f[u] + g[u]) * sz[u] * (n - sz[u]).
modint f[N], g[N], ans;
// n: number of nodes.
// sz[u]: size of u's subtree when the tree is rooted at 1.
// deg[u]: degree of u.
int n, sz[N], deg[N];
// Adjacency lists; each input edge is stored in both directions.
vector<int> e[N];
// Binary exponentiation: returns a raised to the power b, modulo `mod`.
// b is expected to be non-negative.
inline modint Qpow(modint a, int b) {
    modint acc = 1;
    while (b) {
        if (b & 1) acc *= a;  // this bit of b contributes the current power of a
        a *= a;
        b >>= 1;
    }
    return acc;
}
// Bottom-up pass over the subtree of u, where fa is u's parent (0 for the root).
// For every node x in that subtree it computes:
//   sz[x] = size of x's subtree
//   f[x]  = deg[x] + sum of f[child]   (expected steps from x to its parent)
// It uses an explicit stack instead of recursion. With N allowing ~1e5 nodes, a
// path-shaped tree would otherwise need ~1e5 nested calls, which can overflow
// the default call stack.
inline void dfs(int u, int fa) {
    // (node, parent) pairs in DFS pre-order; every node appears after its parent.
    std::vector<std::pair<int, int>> order;
    std::vector<std::pair<int, int>> stk;
    stk.emplace_back(u, fa);
    while (!stk.empty()) {
        const std::pair<int, int> cur = stk.back();
        stk.pop_back();
        order.push_back(cur);
        sz[cur.first] = 1;
        f[cur.first] = deg[cur.first];
        for (int v : e[cur.first])
            if (v != cur.second) stk.emplace_back(v, cur.first);
    }
    // Walking the pre-order backwards handles every child before its parent,
    // so each node's totals are final when they are folded into the parent.
    // Index 0 is u itself; its parent fa lies outside the subtree and is left
    // untouched.
    for (std::size_t i = order.size(); i-- > 1;) {
        const int x = order[i].first;
        const int p = order[i].second;
        sz[p] += sz[x];
        f[p] += f[x];
    }
}
// Top-down pass over the subtree of u, where fa is u's parent (0 for the root).
// It fills g[] and accumulates the answer:
// - Every node x with a real parent p (p != 0) gets g[x] = f[p] - f[x] + g[p].
//   The root keeps its current g value, which is 0 for the global array.
// - Edge (x, parent) is crossed upward by sz[x] * (n - sz[x]) (start, end)
//   pairs and downward by the same number, so x adds
//   (f[x] + g[x]) * sz[x] * (n - sz[x]) to ans.
// It uses an explicit stack instead of recursion, so a path-shaped tree of
// ~1e5 nodes cannot overflow the call stack. A node is popped only after its
// parent, so g[p] is ready when g[x] needs it.
inline void dp(int u, int fa) {
    std::vector<std::pair<int, int>> stk;
    stk.emplace_back(u, fa);
    while (!stk.empty()) {
        const int x = stk.back().first;
        const int p = stk.back().second;
        stk.pop_back();
        if (p != 0) g[x] = f[p] - f[x] + g[p];
        ans += (f[x] + g[x]) * sz[x] * (n - sz[x]);
        for (int v : e[x])
            if (v != p) stk.emplace_back(v, x);
    }
}
int main(){
// FO("P3412");
read(n);
U(i, 2, n) {
int u, v;
read(u, v);
e[u].push_back(v);
e[v].push_back(u);
deg[u]++;
deg[v]++;
}
dfs(1, 0);
dp(1, 0);
modint B = n;
B = B * n;
// cerr << ans.x << "\n";
// cerr << B.x + mod << " " << (LL) n * n % mod << "\n";
writeln((Qpow(B, mod - 2) * ans).x);
return 0;
}