CF161D Distance in Tree 题解

Description

洛谷传送门

Solution

似乎各种做法都可以过,这里提供一篇 \(dsu\ on\ tree\) (树上启发式合并)的题解。

不会的同学可以看我的博客 浅谈 dsu on tree

题目要求我们求出长度为 \(k\) 的路径有多少条。

那么我们可以开一个桶 \(cnt_x\),表示深度为 \(x\) 的点有多少个,统计答案时 \(ans += cnt_{k - dep[x] + 2 * dep[topx]}\) (类似于树上差分的思想)。

然后修改就比较板子了,加入一个点的话就 \(cnt_{dep[x]}++\),删除的话就 \(cnt_{dep[x]}--\)

其他的就没有什么了。

具体看代码吧。

Code

#include <iostream>
#include <cstdio>
#include <algorithm>
#define ll long long

using namespace std;

const int N = 6e4 + 10;
int n, k;
ll ans;
struct node{
    int v, nxt;
}edge[N << 1];
int head[N], tot;
int siz[N], son[N], fa[N], dep[N];
int cnt[N];

inline void add(int x, int y){
    edge[++tot] = (node){y, head[x]};
    head[x] = tot;
}

inline void dfs(int x, int p){
    dep[x] = dep[p] + 1, siz[x] = 1, fa[x] = p;
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y == p) continue;
        dfs(y, x);
        siz[x] += siz[y];
        if(!son[x] || siz[y] > siz[son[x]])
            son[x] = y;
    }
}

inline void update(int x, int topfa, int type){//type = 0: 加入   1: 删除   2: 统计答案
    if(!type) cnt[dep[x]]++;
    else if(type == 1) cnt[dep[x]]--;
    else if(k - dep[x] + (dep[topfa] << 1) >= 0) ans += (ll)cnt[k - dep[x] + (dep[topfa] << 1)];//这里前面要判一下,不然会 RE
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y != fa[x])
            update(y, topfa, type);
    }
}

inline void solve(int x, int type){//type = 1 表示是重儿子,type = 0 表示是轻儿子
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y != son[x] && y != fa[x]) solve(y, 0);
    }
    if(son[x]) solve(son[x], 1);//加入重儿子
    ans += (ll)cnt[dep[x] + k];//从当前点向下 k 个单位
    cnt[dep[x]]++;//加入根节点
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y == fa[x] || y == son[x]) continue;
        update(y, x, 2);//统计答案
        update(y, x, 0);//加入轻儿子
    }
    if(!type) update(x, x, 1);//删除轻儿子
}

int main(){
    scanf("%d%d", &n, &k);
    for(int i = 1; i < n; ++i){
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v), add(v, u);
    }
    dfs(1, 0);
    solve(1, 0);
    printf("%lld\n", ans);
    return 0;
}

End

posted @ 2021-10-27 21:59  xixike  阅读(73)  评论(0编辑  收藏  举报