# [ BZOJ 3451 ] Normal

## Solution

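Let $depth[x]$ be the depth of $x$ in the (random) point-divide tree, so the total cost is $\sum_x depth[x]$ and $E(depth[x])=\sum_{y=1}^n P(x\in subtree[y])$. Vertex $x$ lies in the divide subtree of $y$ exactly when $y$ is the first vertex chosen among the $dis(x,y)+1$ vertices on the path between them, so $P(x\in subtree[y])=\frac{1}{dis(x,y)+1}$.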

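Summing over all $x$: $\sum_{x=1}^n\sum_{y=1}^n \frac{1}{dis(x,y) + 1}=\sum_{len = 0}^{n-1} \frac{cnt[len]}{len + 1}$, where $cnt[len]$ is the number of ordered pairs $(x, y)$ with $dis(x, y) = len$. The array $cnt$ is computed with point divide-and-conquer plus NTT.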

### 容斥做法 (inclusion–exclusion approach)

#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 65537
#define mod 998244353
using namespace std;
typedef long long ll;

// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before the number are skipped.
// Returns 0 if EOF is reached before a digit; the original looped forever there.
inline int rd() {
    int x = 0;
    int c = getchar();  // int, not char: must hold EOF and be a valid isdigit() argument
    while (c != EOF && !isdigit(c)) c = getchar();
    while (c != EOF && isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}

// Writes x in decimal to stdout, followed by a newline.
// Digits are collected into a local buffer, least significant first.
// This avoids the original's power-of-ten counter `int y`, which overflowed for x >= 10^9.
// Negative values are printed with a leading '-'.
inline void print(ll x) {
    char buf[24];
    int len = 0;
    // Unsigned magnitude; well-defined even for the most negative value.
    unsigned long long u = x < 0 ? 0ULL - static_cast<unsigned long long>(x)
                                 : static_cast<unsigned long long>(x);
    do {
        buf[len++] = static_cast<char>('0' + u % 10);
        u /= 10;
    } while (u);
    if (x < 0) putchar('-');
    while (len) putchar(buf[--len]);
    putchar('\n');
}

// Modular exponentiation by repeated squaring: returns x^t mod `mod`.
// With the default exponent mod - 2 the result is the modular inverse of x
// (Fermat's little theorem, since mod is prime).
inline int fpow(int x, int t = mod - 2) {
    long long base = x;
    long long acc = 1;
    for (; t != 0; t >>= 1, base = base * base % mod) {
        if (t & 1) acc = acc * base % mod;
    }
    return (int) acc;
}

// NTT tables.
// mxlen: the largest transform length supported (2^16).
// w[0][k] / w[1][k]: k-th powers of a primitive mxlen-th root of unity and of its inverse.
// rev: the bit-reversal permutation prepared by Rev().
int mxlen = (1 << 16);
int w[2][N];
int rev[N];

// Brings a value in [0, 2*mod) back into [0, mod) with one conditional subtraction.
inline int mo(int x) {
    if (x >= mod) x -= mod;
    return x;
}

// Fills the twiddle tables: w[0][k] = g^k and w[1][k] = g^(-k) for 0 <= k < mxlen.
// Here g = 3^((mod - 1) / mxlen) is a primitive mxlen-th root of unity modulo mod.
inline void init() {
    const int root = fpow(3, (mod - 1) / mxlen);
    const int iroot = fpow(root);
    int *fwd = w[0];
    int *inv = w[1];
    fwd[0] = inv[0] = 1;
    for (int k = 1; k < mxlen; ++k) {
        fwd[k] = (int) (1ll * fwd[k - 1] * root % mod);
        inv[k] = (int) (1ll * inv[k - 1] * iroot % mod);
    }
}

// Prepares the bit-reversal permutation for an NTT whose length is the
// smallest power of two strictly greater than n, and returns that length.
// Afterwards rev[i] is i with its low `bit` bits reversed, for 0 <= i < len.
inline int Rev(int n) {
    int len = 1, bit = 0;
    while (len <= n) len <<= 1, ++bit;
    // With n == 0 (len == 1) there is nothing to reverse. The original
    // computed `x << (bit - 1)` with bit == 0, a shift by -1, which is UB.
    const int high = bit > 0 ? 1 << (bit - 1) : 0;
    for (int i = 0; i < len; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? high : 0);
    return len;
}

// In-place iterative NTT of f[0..len) (bit-reversal reorder, then butterflies).
// Preconditions: len is a power of two no larger than mxlen, and rev[] has
// been prepared for this len by Rev().
// o == 0 uses the forward roots w[0].
// o == 1 uses the inverse roots w[1] and then multiplies by 1/len, which
// gives the inverse transform.
// Assumes the input values are already in [0, mod).
inline void NTT(int *f, int len, int o) {
// Move every element to its bit-reversed position (each pair swapped once).
for (int i = 0; i < len; ++i)
if (i > rev[i]) swap(f[i], f[rev[i]]);
// i is the half-length of the blocks being combined in this pass.
for (int i = 1; i < len; i <<= 1) {
// Stride through the mxlen-entry root table, so that w[o][k * wn] is the
// k-th power of a primitive (2i)-th root of unity.
int wn = mxlen / (i << 1);
for (int j = 0; j < len; j += (i << 1)) {
int nw = 0, x, y;
for (int k = 0; k < i; ++k, nw += wn) {
x = f[j + k];
y = 1ll * w[o][nw] * f[i + j + k] % mod;
// Butterfly (x, y) -> (x + y, x - y) mod p.
// x - y + mod stays below 2*mod, which fits in int.
f[j + k] = mo(x + y);
f[i + j + k] = mo(x - y + mod);
}
}
}
// Inverse transform: scale every coefficient by 1/len.
if (o == 1) {
int invl = fpow(len);
for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod;
}
}

// vis[u]: u has already been used as a centroid (removed from the tree).
bool vis[N];

// Scalar state.
// n: vertex count; m: unused; tot: edge counter.
// totn / mx / rt: centroid search state.
// mxd: deepest depth reached by the current dfs.
int n, m;
int tot;
int totn, mx, rt;
int mxd;

// bkt[d]: number of vertices at depth d in the current component (scratch for mul).
// cnt[d + 1]: number of ordered vertex pairs at distance d.
// sz[u]: subtree size; hd[u]: head of u's adjacency list.
int bkt[N];
int cnt[N];
int sz[N];
int hd[N];

// Adjacency-list arc: destination vertex and index of the next arc.
struct edge {
    int to;
    int nxt;
};
edge e[N << 1];

inline void add(int u, int v) {
e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot;
e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot;
}

void getrt(int u, int fa) {
sz[u] = 1;
int mxs = 0;
for (int i = hd[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != fa && !vis[v]) {
getrt(v, u);
sz[u] += sz[v];
mxs = max(mxs, sz[v]);
}
mxs = max(mxs, totn - sz[u]);
if (mxs < mx) {mx = mxs; rt = u;}
}

// Recomputes sz[] for the subtree rooted at u (parent fa), skipping removed vertices.
void getsz(int u, int fa) {
    int total = 1;
    for (int i = hd[u]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == fa || vis[v]) continue;
        getsz(v, u);
        total += sz[v];
    }
    sz[u] = total;
}

// Builds a depth histogram of the component: bkt[dep] is incremented for
// every vertex reached from u, without crossing fa or removed vertices.
// mxd keeps the largest depth seen.
void dfs(int u, int fa, int dep) {
    bkt[dep] += 1;
    if (dep > mxd) mxd = dep;
    for (int i = hd[u]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == fa || vis[v]) continue;
        dfs(v, u, dep + 1);
    }
}

// Squares the depth histogram a[0..len] with NTT and folds the result into cnt.
// After squaring, a[k] is the number of ordered vertex pairs whose depths sum to k.
//   o > 0:  the histogram is rooted at the centroid (depth 0), so a depth sum k
//           is a path of length k and is added to cnt[k + 1].
//   o <= 0: the histogram is rooted at a child of the centroid, so both endpoints
//           are one edge further from the centroid. Those pairs were counted at
//           distance k + 2 and are subtracted from cnt[k + 3].
// a is zeroed on return so it can be reused as a scratch buffer.
inline void mul(int *a, int len, int o) {
    const int deg = len;        // highest index of a that may be non-zero
    len = Rev(deg << 1);        // transform size > 2 * deg, so no cyclic wrap-around
    NTT(a, len, 0);
    for (int i = 0; i < len; ++i) a[i] = 1ll * a[i] * a[i] % mod;
    NTT(a, len, 1);
    // Only indices 0..2*deg of the square can be non-zero, so the writes are
    // confined to them. The original looped to len - 1 and touched cnt[len + 2],
    // which is past the end of cnt[] once len reaches mxlen.
    const int top = deg << 1;
    if (o > 0) for (int i = 0; i <= top; ++i) cnt[i + 1] += a[i];
    else for (int i = 0; i <= top; ++i) cnt[i + 3] -= a[i];
    for (int i = 0; i < len; ++i) a[i] = 0;
}

// Builds the depth histogram of u's component in bkt (u at depth 0) and
// folds its pair counts into cnt via mul.
// o > 0 adds the counts; o <= 0 subtracts them at the shifted index.
inline void calc(int u, int o) {
    mxd = 0;
    dfs(u, 0, 0);
    const int deepest = mxd;
    mul(bkt, deepest, o);
}

// Centroid decomposition step at centroid u:
//   1. Remove u and count every pair in its component via u's depth histogram.
//   2. For each child, subtract the pairs lying inside that child's subtree
//      (they are counted again at a deeper centroid).
//   3. Recurse into the centroid of that child's component.
void divide(int u) {
    vis[u] = 1;
    calc(u, 1);
    for (int i = hd[u]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (vis[v]) continue;
        calc(v, -1);
        getsz(v, u);
        mx = totn = sz[v];
        rt = v;
        getrt(v, 0);
        divide(rt);
    }
}

// Reads n and the n - 1 edges (vertex ids are shifted by +1 on input) and
// runs the centroid decomposition.
// Prints sum over d of cnt[d + 1] / (d + 1) with four decimals.
int main() {
    init();
    n = rd();
    for (int k = 1; k < n; ++k) add(rd() + 1, rd() + 1);
    totn = n;
    mx = n;
    getrt(1, 0);
    divide(rt);
    double ans = 0.0;
    for (int len = 1; len <= n + 1; ++len) ans += (double) cnt[len] / len;
    printf("%.4lf", ans);
    return 0;
}

### 子树按秩合并做法 (merging child subtrees in increasing order of depth)

#include <cmath>
#include <cstdio>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
#define N 65537
#define mod 998244353
using namespace std;
typedef long long ll;

// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before the number are skipped.
// Returns 0 if EOF is reached before a digit; the original looped forever there.
inline int rd() {
    int x = 0;
    int c = getchar();  // int, not char: must hold EOF and be a valid isdigit() argument
    while (c != EOF && !isdigit(c)) c = getchar();
    while (c != EOF && isdigit(c)) {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x;
}

// Writes x in decimal to stdout, followed by a newline.
// Digits are collected into a local buffer, least significant first.
// This avoids the original's power-of-ten counter `int y`, which overflowed for x >= 10^9.
// Negative values are printed with a leading '-'.
inline void print(ll x) {
    char buf[24];
    int len = 0;
    // Unsigned magnitude; well-defined even for the most negative value.
    unsigned long long u = x < 0 ? 0ULL - static_cast<unsigned long long>(x)
                                 : static_cast<unsigned long long>(x);
    do {
        buf[len++] = static_cast<char>('0' + u % 10);
        u /= 10;
    } while (u);
    if (x < 0) putchar('-');
    while (len) putchar(buf[--len]);
    putchar('\n');
}

// Modular exponentiation by repeated squaring: returns x^t mod `mod`.
// With the default exponent mod - 2 the result is the modular inverse of x
// (Fermat's little theorem, since mod is prime).
inline int fpow(int x, int t = mod - 2) {
    long long base = x;
    long long acc = 1;
    for (; t != 0; t >>= 1, base = base * base % mod) {
        if (t & 1) acc = acc * base % mod;
    }
    return (int) acc;
}

// NTT tables.
// mxlen: the largest transform length supported (2^16).
// w[0][k] / w[1][k]: k-th powers of a primitive mxlen-th root of unity and of its inverse.
// rev: the bit-reversal permutation prepared by Rev().
int mxlen = (1 << 16);
int w[2][N];
int rev[N];

// Brings a value in [0, 2*mod) back into [0, mod) with one conditional subtraction.
inline int mo(int x) {
    if (x >= mod) x -= mod;
    return x;
}

// Fills the twiddle tables: w[0][k] = g^k and w[1][k] = g^(-k) for 0 <= k < mxlen.
// Here g = 3^((mod - 1) / mxlen) is a primitive mxlen-th root of unity modulo mod.
inline void init() {
    const int root = fpow(3, (mod - 1) / mxlen);
    const int iroot = fpow(root);
    int *fwd = w[0];
    int *inv = w[1];
    fwd[0] = inv[0] = 1;
    for (int k = 1; k < mxlen; ++k) {
        fwd[k] = (int) (1ll * fwd[k - 1] * root % mod);
        inv[k] = (int) (1ll * inv[k - 1] * iroot % mod);
    }
}

// Prepares the bit-reversal permutation for an NTT whose length is the
// smallest power of two strictly greater than n, and returns that length.
// Afterwards rev[i] is i with its low `bit` bits reversed, for 0 <= i < len.
inline int Rev(int n) {
    int len = 1, bit = 0;
    while (len <= n) len <<= 1, ++bit;
    // With n == 0 (len == 1) there is nothing to reverse. The original
    // computed `x << (bit - 1)` with bit == 0, a shift by -1, which is UB.
    const int high = bit > 0 ? 1 << (bit - 1) : 0;
    for (int i = 0; i < len; ++i)
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? high : 0);
    return len;
}

// In-place iterative NTT of f[0..len) (bit-reversal reorder, then butterflies).
// Preconditions: len is a power of two no larger than mxlen, and rev[] has
// been prepared for this len by Rev().
// o == 0 uses the forward roots w[0].
// o == 1 uses the inverse roots w[1] and then multiplies by 1/len, which
// gives the inverse transform.
// Assumes the input values are already in [0, mod).
inline void NTT(int *f, int len, int o) {
// Move every element to its bit-reversed position (each pair swapped once).
for (int i = 0; i < len; ++i)
if (i > rev[i]) swap(f[i], f[rev[i]]);
// i is the half-length of the blocks being combined in this pass.
for (int i = 1; i < len; i <<= 1) {
// Stride through the mxlen-entry root table, so that w[o][k * wn] is the
// k-th power of a primitive (2i)-th root of unity.
int wn = mxlen / (i << 1);
for (int j = 0; j < len; j += (i << 1)) {
int nw = 0, x, y;
for (int k = 0; k < i; ++k, nw += wn) {
x = f[j + k];
y = 1ll * w[o][nw] * f[i + j + k] % mod;
// Butterfly (x, y) -> (x + y, x - y) mod p.
// x - y + mod stays below 2*mod, which fits in int.
f[j + k] = mo(x + y);
f[i + j + k] = mo(x - y + mod);
}
}
}
// Inverse transform: scale every coefficient by 1/len.
if (o == 1) {
int invl = fpow(len);
for (int i = 0; i < len; ++i) f[i] = 1ll * f[i] * invl % mod;
}
}

// vis[u]: u has already been used as a centroid (removed from the tree).
bool vis[N];

// Accumulated sum of 2 / (d + 1) over merged pairs; self-pairs are added in main.
double ans = 0.0;

// Scalar state.
// n: vertex count; m: unused; tot: edge counter.
// totn / mx / rt: centroid search state.
int n, m;
int tot;
int totn, mx, rt;

// bkt[d]: vertices at depth d over the child subtrees merged so far.
// sz[u]: subtree size; hd[u]: head of u's adjacency list.
int bkt[N];
int sz[N];
int hd[N];

// Adjacency-list arc: destination vertex and index of the next arc.
struct edge {
    int to;
    int nxt;
};
edge e[N << 1];

inline void add(int u, int v) {
e[++tot].to = v; e[tot].nxt = hd[u]; hd[u] = tot;
e[++tot].to = u; e[tot].nxt = hd[v]; hd[v] = tot;
}

void getrt(int u, int fa) {
sz[u] = 1;
int mxs = 0;
for (int i = hd[u], v; i; i = e[i].nxt)
if ((v = e[i].to) != fa && !vis[v]) {
getrt(v, u);
sz[u] += sz[v];
mxs = max(mxs, sz[v]);
}
mxs = max(mxs, totn - sz[u]);
if (mxs < mx) {mx = mxs; rt = u;}
}

// Recomputes sz[] for the subtree rooted at u (parent fa), skipping removed vertices.
void getsz(int u, int fa) {
    int total = 1;
    for (int i = hd[u]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == fa || vis[v]) continue;
        getsz(v, u);
        total += sz[v];
    }
    sz[u] = total;
}

// Scratch buffers for mul: transformed copies of its two inputs.
int res[N];
int tmp[N];

// Copies src[0..used) into dst and zero-fills dst[used..size).
static inline void copy_padded(int *dst, const int *src, int used, int size) {
    for (int i = 0; i < used; ++i) dst[i] = src[i];
    for (int i = used; i < size; ++i) dst[i] = 0;
}

// Convolves the histogram a[0..lena) with b[0..lenb) via NTT, leaving a and b unchanged.
// Each product term pairs an already-merged vertex with a vertex of the new
// subtree at total depth d; for every such pair, 2 / (d + 1) is added to ans
// (the factor 2 covers both orderings).
// The transform size is derived from lenb alone, so the product degree must
// stay below it. divide() arranges this by merging subtrees in increasing
// depth order. Returns the transform size used.
inline int mul(int *a, int *b, int lena, int lenb) {
    const int size = Rev(lenb << 1);
    copy_padded(res, a, lena, size);
    copy_padded(tmp, b, lenb, size);
    NTT(res, size, 0);
    NTT(tmp, size, 0);
    for (int i = 0; i < size; ++i) res[i] = 1ll * res[i] * tmp[i] % mod;
    NTT(res, size, 1);
    for (int d = 0; d < size; ++d) ans += 2.0 * res[d] / (d + 1);
    return size;
}

// mxd[v]: maximum depth in the subtree of child v, measured with the current centroid at depth 0.
// s[1..s[0]]: the current centroid's live children.
// bkts[d]: depth histogram of the subtree currently being merged.
int mxd[N];
int s[N];
int bkts[N];

// Orders children by increasing maximum depth.
inline bool cmp(int x, int y) {
    return mxd[y] > mxd[x];
}

// Returns the maximum depth reached in u's subtree (parent fa, removed
// vertices skipped), where u itself is at depth dep.
int dfs(int u, int fa, int dep) {
    int deepest = dep;
    for (int i = hd[u]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == fa || vis[v]) continue;
        const int below = dfs(v, u, dep + 1);
        if (below > deepest) deepest = below;
    }
    return deepest;
}

// Adds every vertex of u's subtree (parent fa, removed vertices skipped) to
// the depth histogram bkts, with u at depth dep.
void dfs2(int u, int fa, int dep) {
    bkts[dep] += 1;
    for (int i = hd[u]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (v != fa && !vis[v]) dfs2(v, u, dep + 1);
    }
}

// Centroid step at u. Counts every vertex pair whose path passes through u
// by merging the child subtrees one at a time into bkt. Children are merged
// in increasing order of depth, so each NTT is sized by the subtree being
// added. Then recurses into the centroid of every child component.
void divide(int u) {
    vis[u] = 1;

    // Collect the live children and record their maximum depths.
    s[0] = 0;
    for (int i = hd[u]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (vis[v]) continue;
        s[++s[0]] = v;
        mxd[v] = dfs(v, u, 1);
    }
    sort(s + 1, s + 1 + s[0], cmp);

    // bkt starts as {u at depth 0}; each child is paired with everything merged before it.
    bkt[0] = 1;
    int nowlen = 1;
    for (int k = 1; k <= s[0]; ++k) {
        const int v = s[k];
        dfs2(v, 0, 1);
        nowlen = mul(bkt, bkts, nowlen, mxd[v] + 1);
        for (int d = 0; d <= mxd[v]; ++d) {
            bkt[d] += bkts[d];
            bkts[d] = 0;
        }
    }
    for (int d = 0; d <= nowlen; ++d) bkt[d] = 0;

    // Recurse into each child component via its centroid.
    for (int i = hd[u]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (vis[v]) continue;
        getsz(v, u);
        totn = mx = sz[v];
        rt = v;
        getrt(v, 0);
        divide(rt);
    }
}

// Reads n and the n - 1 edges (vertex ids are shifted by +1 on input) and
// runs the centroid decomposition.
// Prints ans + n with four decimals: each vertex paired with itself contributes 1/(0 + 1).
int main() {
    init();
    n = rd();
    for (int k = 1; k < n; ++k) add(rd() + 1, rd() + 1);
    totn = n;
    mx = n;
    getrt(1, 0);
    divide(rt);
    const double total = ans + n;
    printf("%.4lf", total);
    return 0;
}
posted @ 2019-03-26 07:44  SGCollin  阅读(...)  评论(...编辑  收藏