欧拉之树

欧拉之树

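给你一棵 \(n\) 个点的树,每个点有权值 \(a_i\)(\(a_1,\dots,a_n\) 是 \(1\sim n\) 的一个排列),求 \(\sum\limits_{1\leq i,j\leq n}\operatorname{dis}(i,j)\varphi(a_ia_j)\)。(下面的代码实际输出的是这个和除以 \(n(n-1)\) 之后模 \(10^9+7\) 的值。由于 \(i=j\) 的项为 \(0\),这就是在有序点对 \((i,j),\ i\neq j\) 上的期望。)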

今天集训模拟赛的题。

首先把 \(\varphi(a_ia_j)\) 写成 \(\dfrac{\varphi(a_i)\varphi(a_j)\gcd(a_i,a_j)}{\varphi[\gcd(a_i,a_j)]}\)

证明:
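设 \(x=a_i,\ y=a_j,\ g=\gcd(x,y)\)。由 \(\varphi(n)=n\prod\limits_{p\mid n}\left(1-\dfrac1p\right)\)(\(p\) 取遍 \(n\) 的质因子)可知:\(xy\) 的质因子集合是 \(x\) 与 \(y\) 的质因子集合之并,两者之交恰为 \(g\) 的质因子集合。于是

$$\varphi(x)\varphi(y)=xy\prod_{p\mid x}\left(1-\frac1p\right)\prod_{p\mid y}\left(1-\frac1p\right)=xy\prod_{p\mid xy}\left(1-\frac1p\right)\prod_{p\mid g}\left(1-\frac1p\right)=\varphi(xy)\cdot\frac{\varphi(g)}{g}$$

移项即得 \(\varphi(xy)=\dfrac{\varphi(x)\varphi(y)\,g}{\varphi(g)}\)。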

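接下来枚举 \(t\),令 \(F(t)=\sum\limits_{t\mid a_i,\ t\mid a_j}\varphi(a_i)\varphi(a_j)\operatorname{dis}(i,j)\)。用 \(\operatorname{dis}(i,j)=dep_i+dep_j-2\,dep_{\operatorname{lca}(i,j)}\) 拆开后,只需对权值为 \(t\) 的倍数的 \(\lfloor n/t\rfloor\) 个点建虚树,做一次 DP 求出 \(\sum dep_{\operatorname{lca}}\cdot\varphi\varphi\)。总点数为 \(\sum_t n/t=O(n\ln n)\)。

再莫比乌斯反演,得到 \(\gcd\) 恰为 \(d\) 的部分 \(f(d)=\sum\limits_{d\mid t}\mu(t/d)F(t)\),答案即 \(\sum\limits_{d}\dfrac{d}{\varphi(d)}f(d)\)。整体还算比较套路。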

#include <cstdio>
#include <algorithm>
#include <cstring>
typedef long long ll;
using namespace std;
// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before it are skipped; a leading '-' is skipped,
// not honoured. Returns 0 if end of input is reached before any digit.
inline int read() {
    int x = 0;
    // Keep getchar()'s result in an int: a char cannot reliably represent EOF,
    // which would make the skip loop below spin forever at end of input.
    int c = getchar();
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    while (c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x;
}
// Prints x in decimal to stdout, with no trailing separator.
// 0 prints "0"; negative values print a leading '-'.
inline void write(int x) {
    // Work on the magnitude as unsigned so negating INT_MIN cannot overflow.
    unsigned int u = static_cast<unsigned int>(x);
    if (x < 0) {
        putchar('-');
        u = 0u - u;
    }
    // Each byte of an int contributes fewer than 3 decimal digits.
    char digits[3 * sizeof(int) + 1];
    int len = 0;
    // do/while so that 0 still emits one digit.
    do {
        digits[len++] = static_cast<char>('0' + u % 10);
        u /= 10;
    } while (u != 0);
    while (len > 0) putchar(digits[--len]);
}
const int maxn = 5e5 + 10;
const int mod = 1000000007;
int primes[maxn], tot, phi[maxn], mu[maxn];
bool v[maxn];
// Linear (Euler) sieve over [1, maxn).
// Appends the primes to primes[1..tot] and fills Euler's totient phi[] and the
// Moebius function mu[]; v[k] is set for every composite k.
// tot is not reset, so this is meant to run once.
void Init() {
    mu[1] = 1;
    phi[1] = 1;
    for (int x = 2; x < maxn; ++x) {
        if (!v[x]) {
            // Never marked by a smaller factor, so x is prime.
            primes[++tot] = x;
            mu[x] = -1;
            phi[x] = x - 1;
        }
        // Mark x * p for primes p up to the smallest prime factor of x, so each
        // composite is marked exactly once (by its smallest prime factor).
        for (int k = 1; x * primes[k] < maxn; ++k) {
            const int p = primes[k];
            const int m = x * p;
            v[m] = true;
            if (x % p == 0) {
                // p^2 divides m: mu[m] keeps its zero initial value,
                // and phi picks up a full factor of p.
                phi[m] = phi[x] * p;
                break;
            }
            // gcd(x, p) == 1: both functions are multiplicative here.
            mu[m] = -mu[x];
            phi[m] = phi[p] * phi[x];
        }
    }
}
// Adjacency lists stored as singly linked lists threaded through edge[]:
// head[u] is the index of u's most recently added edge (0 = none),
// and edge[i].next links to the previous one.
struct Edge {
    int to, next;
} edge[maxn * 2];
int head[maxn], cnt;
// Appends the directed edge u -> v at index ++cnt and makes it u's list head.
inline void addedge(int u, int v) {
    const int id = ++cnt;
    edge[id].next = head[u];
    edge[id].to = v;
    head[u] = id;
}
int n, a[maxn], b[maxn];
int dfn[maxn], idx, deep[maxn], fa[maxn][20];
// Preorder DFS from x over the adjacency lists.
// Assigns dfn[x] (visit order), fills the binary-lifting row fa[x][1..19]
// (2^k-th ancestors, 0 past the root), and sets depth and parent of each
// unvisited neighbour before recursing into it.
// Expects fa[x][0] and deep[x] to be set already; for the root they stay 0.
void dfs(int x) {
    dfn[x] = ++idx;
    for (int k = 1; k < 20; ++k) {
        const int mid = fa[x][k - 1];
        fa[x][k] = fa[mid][k - 1];
    }
    for (int e = head[x]; e != 0; e = edge[e].next) {
        const int y = edge[e].to;
        // dfn != 0 means already visited (i.e. y is x's parent).
        if (!dfn[y]) {
            deep[y] = deep[x] + 1;
            fa[y][0] = x;
            dfs(y);
        }
    }
}
// Lowest common ancestor of x and y via the binary-lifting table fa[][].
// Needs dfs() to have filled fa[][] and deep[] (the root has depth 0).
inline int Lca(int x, int y) {
    if (deep[x] < deep[y])
        swap(x, y);
    // Lift x by exactly deep[x] - deep[y] levels.
    // The test is on the depth difference, not on deep[fa[x][i]] >= deep[y]:
    // the sentinel 0 above the root also has depth 0, so that test would lift
    // x onto 0 whenever y is the root, and Lca(v, root) would return 0.
    for (int i = 19; i >= 0; i--)
        if (deep[x] - (1 << i) >= deep[y])
            x = fa[x][i];
    if (x == y)
        return x;
    // Raise both while their 2^i-th ancestors differ; they end as children of the LCA.
    for (int i = 19; i >= 0; i--)
        if (fa[x][i] != fa[y][i])
            x = fa[x][i], y = fa[y][i];
    return fa[x][0];
}
// Strict ordering of vertices by DFS entry time (dfn), used to sort key vertices.
inline bool cmp(int a, int b) { return dfn[b] > dfn[a]; }
// c[1..n/t]: the key vertices for the current t (filled by main).
// stk/top: stack of vertices on the current root-to-leaf path of the virtual tree.
// qry[x]: true iff x is a key vertex (as opposed to an LCA added only for structure).
int c[maxn], stk[maxn], top;
bool qry[maxn];
// Builds the virtual tree over the key vertices c[1..n/t] into edge[]/head[].
// Edges are added parent -> child. edge[] is reused from index 1 (cnt is reset).
// Relies on head[] having been cleared beforehand: main memsets it once,
// and dp() resets head[] of every node it visits.
// On return top is 0 again; main then starts dp() from stk[1].
void build(int t) {
    cnt = 0;
    // Process key vertices in DFS preorder.
    sort(c + 1, c + n / t + 1, cmp);
    stk[++top] = c[1];
    for (int i = 2; i <= n / t; i++) {
        int lca = Lca(c[i], stk[top]);
        // Pop every vertex strictly deeper than lca, linking it under the one below it.
        // stk[0] is never written (pushes use ++top), so it stays 0 with deep[0] == 0
        // and this loop cannot pop the bottom element.
        while (deep[lca] < deep[stk[top - 1]]) {
            addedge(stk[top - 1], stk[top]);
            top--;
        }
        if (lca != stk[top]) {
            // stk[top] hangs directly below lca; put lca on the stack unless it is already there.
            addedge(lca, stk[top--]);
            if (lca != stk[top])
                stk[++top] = lca;
        }
        stk[++top] = c[i];
    }
    // Link the remaining stack chain, from top down to stk[1].
    while (--top) addedge(stk[top], stk[top + 1]);
}
int tmp, sum[maxn];
void dp(int x) {
    int ans = 0;
    sum[x] = 0;
    for (int i = head[x]; i; i = edge[i].next) {
        int y = edge[i].to;
        dp(y);
        ans = (ans + 2ll * sum[x] * sum[y]) % mod;
        sum[x] = (sum[x] + sum[y]) % mod;
        if (qry[x])
            ans = (ans + 2ll * phi[a[x]] * sum[y]) % mod;
    }
    if (qry[x]) {
        sum[x] = (sum[x] + phi[a[x]]) % mod;
        ans = (ans + (ll)phi[a[x]] * phi[a[x]]) % mod;
    }
    tmp = (tmp + (ll)deep[x] * ans) % mod;
    qry[x] = false, head[x] = 0;
}
int F[maxn], f[maxn];
inline int inv(int a) {
    int ans = 1, b = mod - 2;
    for (; b; b >>= 1) {
        if (b & 1)
            ans = (ll)ans * a % mod;
        a = (ll)a * a % mod;
    }
    return ans;
}
// Reads the tree from sm.in and writes the answer to sm.out.
// For every t, F[t] = sum of phi(a_i) phi(a_j) dis(i, j) over vertices whose
// values are multiples of t, using dis = deep_i + deep_j - 2 deep_lca.
// The deep_lca part comes from a DP on the virtual tree.
// Moebius inversion turns F into f (gcd exactly t); each f[t] is weighted by
// t / phi(t), and the total is divided by n(n-1) modulo mod.
signed main() {
    freopen("sm.in", "r", stdin);
    freopen("sm.out", "w", stdout);

    n = read();
    for (int u = 1; u <= n; ++u) {
        a[u] = read();
        b[a[u]] = u;  // b maps a value back to its vertex
    }
    for (int k = 1; k < n; ++k) {
        const int u = read();
        const int w = read();
        addedge(u, w);
        addedge(w, u);
    }

    Init();
    dfs(1);
    // The original tree's adjacency is no longer needed; edge[] is reused for virtual trees.
    memset(head, 0, sizeof(head));
    cnt = 0;

    for (int t = 1; t <= n; ++t) {
        const int m = n / t;
        for (int k = 1; k <= m; ++k) {
            c[k] = b[k * t];
            qry[c[k]] = true;
        }
        build(t);

        // Sum of phi over all values that are multiples of t.
        int total = 0;
        for (int val = t; val <= n; val += t) total = (total + phi[val]) % mod;

        // sum_i deep_i * phi_i * total (the deep_i + deep_j part, one orientation).
        int cross = 0;
        for (int val = t; val <= n; val += t)
            cross = (cross + (ll)deep[b[val]] * phi[val] % mod * total) % mod;

        tmp = 0;
        dp(stk[1]);
        cnt = 0;
        cross = (cross - tmp + mod) % mod;
        F[t] = cross * 2ll % mod;
    }

    for (int t = 1; t <= n; ++t)
        for (int d = t; d <= n; d += t)
            f[t] = (f[t] + (ll)mu[d / t] * F[d]) % mod;

    // f[t] may be negative (mu < 0); the final "+ mod" normalises the result.
    int ans = 0;
    for (int t = 1; t <= n; ++t) ans = (ans + (ll)t * f[t] % mod * inv(phi[t])) % mod;
    write((((ll)ans * inv((ll)n * (n - 1) % mod) % mod) + mod) % mod);
    return 0;
}
posted @ 2021-02-23 00:41  iMya_nlgau  阅读(54)  评论(0)  编辑  收藏  举报