树(点分治)

题面

给定一个有N个点(编号0,1,…,N-1)的树,每条边都有一个权值(不超过1000)。

树上两个节点x与y之间的路径长度就是路径上各条边的权值之和。

求长度不超过K的路径有多少条。

输入格式

输入包含多组测试用例。

每组测试用例的第一行包含两个整数N和K。

接下来N-1行,每行包含三个整数u,v,l,表示节点u与v之间存在一条边,且边的权值为l。

当输入用例N=0,K=0时,表示输入终止,且该用例无需处理。

输出格式

每个测试用例输出一个结果。

每个结果占一行。

数据范围

N≤10000

输入样例:

5 4
0 1 3
0 2 1
0 3 2
2 4 1
0 0

输出样例:

8

题解

板子题, 求完距离, 尺取法

#include <bits/stdc++.h>
#define rep(i,a,b) for(int i=a;i<=b;++i)
using namespace std;

const int N = 1e4 + 5;

int n, m, k;
int h[N], to[N << 1], ne[N << 1], co[N << 1], tot;
int d[N], b[N];
int siz[N], mxcen, center;
int tax[N], cnt[N], ans, v[N];

inline void add(int u, int v, int c)
{
    ne[++tot] = h[u]; h[u] = tot; co[tot] = c; to[tot] = v;
}

void dfscen(int x, int f)
{
    siz[x] = (f != 0);
    int max_center = 0;
    for (int i = h[x]; i; i = ne[i])
    {
        int& y = to[i];
        if (y == f) continue;
        dfscen(y, x);
        siz[x] += siz[y];
        max_center = max(max_center, siz[y]);
    }
    max_center = max(max_center, n - siz[x]);

    if (max_center < mxcen)
    {
        mxcen = max_center;
        center = x;
    }

    d[center] = 0;
}

void dfsd(int u, int f, int t)
{
    b[u] = t;
    for (int i = h[u]; i; i = ne[i])
    {
        int y = to[i];
        if (y == f || v[y]) continue;
        d[y] = d[u] + co[i];
        dfsd(y, u, t == center ? y : t);
    }
}

bool cmp(int a, int b)
{
    return d[a] < d[b];
}

void work(int p, int f)
{
    memset(cnt, 0, sizeof cnt);
    memset(b, 0, sizeof b);
    mxcen = N + 1; dfscen(p, f);  
    v[center] = 1; dfsd(center, f, center);

    int cntx = 0;
    rep(i, 1, n) if (b[i]) tax[++cntx] = i, ++cnt[b[i]];
    sort(tax + 1, tax + 1 + cntx, cmp);

    for (int l = 1, r = cntx; l < r; --cnt[b[tax[l++]]])
    {
        while (r > l && d[tax[l]] + d[tax[r]] > k) --cnt[b[tax[r--]]];
        ans += r - l - cnt[b[tax[l]]] + 1;
    }

    for (int i = h[center], c = center; i; i = ne[i])
        if (ne[h[to[i]]] && v[to[i]] == 0) work(to[i], c);
}

int main()
{
    ios::sync_with_stdio(0); cin.tie(0);
    while (cin >> n >> k, n + k)
    {
        memset(h, 0, sizeof h); tot = ans = 0;
        memset(v, 0, sizeof v);
        rep(i, 2, n)
        {
            int u, v, c; cin >> u >> v >> c;
            add(u + 1, v + 1, c);
            add(v + 1, u + 1, c);
        }
        work(1, 0);
        cout << ans << '\n';
    }
    return 0;
}
posted @ 2020-06-09 10:44  洛绫璃  阅读(146)  评论(0编辑  收藏  举报