(CF161D Distance in Tree) 树分治

//题意:给定一棵树,求树上有多少个距离为k的点对
//思路:(树分治思想,具体看博客图吧,另外这题很明显的可以用DSU on tree做,因为求一个长度为k路径的合并部分非常简单)
//
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 50005, K = 10005;
vector<pair<int, int>> e[N];
namespace CenDec{
    int ctr, n, sz[N];
    bool del[N];
    void dfs(int p, int fa = 0) {
        sz[p] = 1;
        int mss = 0;
        for (auto to : e[p]) {
            if (del[to.first] || to.first == fa) continue;
                dfs(to.first, p);
                if (ctr != -1) return;//在子树递归过程中找到重心就即时退出
                mss = max(mss, sz[to.first]);
                sz[p] += sz[to.first];
        }
        mss = max(mss, n - sz[p]);//与根节点之上的那棵子树进行比较
        if (mss <= n / 2) {
            ctr = p;
            sz[fa] = n - sz[p];//更新sz[fa]的值,目的是把重心相邻的所有子树大小重新更新一遍,因为待会我们要
                               //从这个重心向下分治,向下分治的话我们是需要用到子树大小的
        }
    }
    int k, cnt;

    //注释部分就是统计答案的部分,真正的难点就在于这部分,其他的部分大同小异,相当于板子
    /*
    int temp[N], lens[K], cntt;
    void dfs2(int p, int fa, int w) {
        if (w > k) return;//剪枝,同时防止数组访问越界
        cnt += lens[k - w] + (w == k);//更新答案
        temp[cntt++] = w;
        for (auto to : e[p])
            if (!del[to.first] && to.first != fa)
                dfs2(to.first, p, w + 1);
    }
    */

    void run(int p) {
    /*
        for (auto to : e[p]) {
            if (del[to.first]) continue;
            dfs2(to.first, p, 1);
            for (int i = 0; i < cntt; ++i) lens[temp[i]]++;
            cntt = 0;
        }
        fill(lens, lens + K, 0);//清空该重心情况
     */ 
        del[p] = 1;
        for (auto to : e[p]) {
            if (!del[to.first]) {
                n = sz[to.first];//现在要遍历的树是上个重心的子树
                ctr = -1;
                dfs(to.first);
                run(ctr);
            }
        }
    }
    int count(int n0, int k0) {
        n = n0, k = k0; ctr = -1;
        dfs(1);//找重心
        run(ctr);//计算贡献
        return cnt;
    }
}
int main() {
    int n, k, u, v, w;
    cin >> n >> k;
    for (int i = 1; i <= n - 1; ++i)
    {
        cin >> u >> v;
        w = 1;
        e[u].push_back({ v,w });
        e[v].push_back({ u,w });
    }
    cout << CenDec::count(n, k) << endl;
    return 0;
}

 

posted @ 2023-01-04 16:47  Aacaod  阅读(16)  评论(0)    收藏  举报