树上叶子之间点对距离平方和

题目大意:

一棵无根树,定义度数为1点为叶子节点,求所有两个叶子之间距离的平方和,树边上有边权。

样例:

4
1 4 1
4 3 1
2 4 1

12

4
1 2 3
1 4 2
4 3 1

36

5
1 2 1
1 3 4
2 4 3
2 5 2

138

10
1 2 10
10 2 7
3 2 8
3 9 3
9 8 2
7 9 1
6 4 3
4 5 2
3 4 4

4709
View Code

 

(看到类似的题解才有的思路)

 

其实这道题的难点在于平方和,如果去掉和这道题就很简单了,换根dp就行了。

然后发现其实平方和拆开来也可以换根。

 

先不考虑叶子,我们直接求每个点到所有叶子的距离平方和。

假设$dis[x]$表示$x$到当前点处理到的点的距离。

那么当前答案为$\sum dis[i]^2$

往一个子树跑的时候,子树内的$dis$会减去边长,子树外的$dis$会加上边长。  

于是和式变为$\sum_{i\ in\ tree}(dis[i] - edge[i])^2 + \sum _{i\ not\ in\ tree}(dis[i] + edge[i]) ^ 2$

把这个式子拆开,就变成$\sum (dis[i])^2 - (2 * edge[i] * \sum_{i\ in\ tree}dis[i]) + (2 * edge[i] * \sum _{i\ not\ in\ tree}dis[i]) + (edge[i] * edge[i] * totcnt)$

$\sum (dis[i])^2 $是当前节点的答案,$\sum_{i\ in\ tree}dis[i], \sum _{i\ not\ in\ tree}dis[i]$都可以通过预处理预处理出来。

然后碰到当前点是叶子点直接统计答案就可以了。

#include <bits/stdc++.h>
#define int long long
#define Mid ((l + r) / 2)
#define lson (rt << 1)
#define rson (rt << 1 | 1)
using namespace std;
int read() {
    char c; int num, f = 1;
    while(c = getchar(),!isdigit(c)) if(c == '-') f = -1; num = c - '0';
    while(c = getchar(), isdigit(c)) num = num * 10 + c - '0';
    return f * num;
}
const int N = 2e5 + 1009;
int n, in[N], rt, cntl, totdis[N], revtotdis[N], cnt[N], ans;
int head[N], nxt[N], ver[N], edge[N], tot = 1;
void add(int x, int y, int w) {
    ver[++tot] = y; nxt[tot] = head[x]; head[x] = tot; edge[tot] = w;
}
void dfs(int x, int pre, int d) {
    totdis[x] = 0;
    cnt[x] = in[x] == 1;
    if(in[x] == 1) ans += d * d;
    for(int i = head[x]; i; i = nxt[i]) if(ver[i] != pre) {
        dfs(ver[i], x, d + edge[i]);
        totdis[x] += totdis[ver[i]] + cnt[ver[i]] * edge[i];
        cnt[x] += cnt[ver[i]];
    }
}
void dfs1(int x, int pre) {
    for(int i = head[x]; i; i = nxt[i]) if(ver[i] != pre) {
        revtotdis[ver[i]] = revtotdis[x] + (cntl - cnt[x]) * edge[i] + totdis[x] - totdis[ver[i]] - edge[i] * cnt[ver[i]] + edge[i] * (cnt[x] - cnt[ver[i]]);
        dfs1(ver[i], x);
    }
}
void dp(int x, int pre, int now) {
    if(in[x] == 1) ans += now;
    for(int i = head[x]; i; i = nxt[i]) if(ver[i] != pre) {
        dp(ver[i], x, now - 2 * edge[i] * (totdis[ver[i]] + edge[i] * cnt[ver[i]]) + 2 * edge[i] * (revtotdis[ver[i]] - edge[i] * (cntl - cnt[ver[i]])) + edge[i] * edge[i] * cntl);
    }
}
signed main()
{
    n = read();
    for(int i = 1; i < n; i++) {
        int x = read(), y = read(), w = read();
        in[x]++; in[y]++;
        add(x, y, w); add(y, x, w);
    }
    for(int i = 1; i <= n; i++) 
        if(in[i] != 1) rt = i;
        else cntl++;
    dfs(rt, rt, 0);
    dfs1(rt, rt);
    int tmp = ans; 
    ans = 0;
    dp(rt, rt, tmp);
    printf("%lld\n", ans / 2);
    return 0;
}
View Code

 

posted @ 2021-05-05 15:05  _onglu  阅读(134)  评论(0编辑  收藏  举报