ICPC 2019 徐州 M - Kill the tree

这题标程是直接找重心；我写的方法是：对当前点到根节点之间的路径进行二分，找出以当前点为重心的那些祖先所构成的区间，复杂度为 $O(n\log^2 n)$。

#include <bits/stdc++.h>
using namespace std;
// Array capacity: room for roughly 2 * 10^5 vertices (1-indexed) plus slack.
const int N = 200000 + 10;
// edge[u]: adjacency list of vertex u (undirected tree).
vector<int> edge[N];
// sz[u]: number of vertices in u's subtree once the tree is rooted at 1.
int sz[N];
// f[u][k]: the 2^k-th ancestor of u; 0 once the jump passes the root.
int f[N][21];
// dep[u]: depth of u; the root gets 1 because dep[0] stays 0.
int dep[N];
// ans[t]: the centroid vertices collected for t's subtree.
vector<int> ans[N];
// Root the tree (called as dfs1(1, 0)) and fill, for every vertex u reached:
//   f[u][0]  its parent (0 for the root), f[u][k] its 2^k-th ancestor
//            (0 past the root, since f[0][*] is never written and stays 0);
//   dep[u]   its depth, root = 1 because dep[0] == 0;
//   sz[u]    the number of vertices in its subtree.
// Recursive: the call depth equals the height of the tree.
void dfs1(int u, int fa) {
    f[u][0] = fa;
    dep[u] = dep[fa] + 1;
    // Ancestors are visited first (pre-order), so their rows are complete.
    for (int k = 1; k <= 20; ++k)
        f[u][k] = f[f[u][k - 1]][k - 1];

    sz[u] = 1;
    for (int child : edge[u]) {
        if (child != fa) {
            dfs1(child, u);
            sz[u] += sz[child];
        }
    }
}

// Let anc be the ancestor of u that is `len` levels up (len == 0 -> u).
// Returns whether sz[anc] <= 2 * sz[u], i.e. the component above u left by
// deleting u from anc's subtree has at most sz[anc] / 2 vertices.
// `maxn` is unused here; check2 takes the same parameter list.
bool check1(int u, int len, int maxn) {
    int anc = u;
    // Jumps of distinct powers of two commute, so scanning bits upward
    // reaches the same ancestor as scanning them downward.
    for (int bit = 0; bit <= 20; ++bit) {
        if ((len >> bit) & 1)
            anc = f[anc][bit];
    }
    return 2 * sz[u] >= sz[anc];
}

// Let anc be the ancestor of u that is `len` levels up (len == 0 -> u).
// Returns whether 2 * maxn <= sz[anc]. dfs passes maxn = the largest
// child-subtree size of u, so this says every child component left by
// deleting u from anc's subtree has at most sz[anc] / 2 vertices.
bool check2(int u, int len, int maxn) {
    int anc = u;
    for (int bit = 0; bit <= 20; ++bit)
        if (len & (1 << bit))
            anc = f[anc][bit];
    return sz[anc] >= 2 * maxn;
}

// For vertex u, find every ancestor t of u (u itself included) such that u
// is a centroid of t's subtree, append u to ans[t], then recurse into u's
// children.
//
// u is a centroid of subtree(t) iff no component left after deleting u has
// more than sz[t] / 2 vertices:
//   * each child component:  2 * maxn <= sz[t]               (check2)
//   * the component above u: sz[t] - sz[u] <= sz[t] / 2,
//                            i.e. sz[t] <= 2 * sz[u]          (check1)
// sz[] strictly increases along the path u -> root. In terms of the jump
// length len (0 = u, dep[u] - 1 = root), check1 therefore holds on a prefix
// [0, hi] and check2 on a suffix [lo, dep[u] - 1]. Both ends are binary
// searched, and u is recorded for every ancestor at distance lo..hi.
//
// NOTE(review): recursion depth equals the tree height; a path-shaped tree
// with ~2e5 vertices needs a generous stack.
void dfs(int u, int fa) {
    // maxn: size of u's largest child subtree (0 for a leaf).
    int maxn = 0;
    for (int v : edge[u])
        if (v != fa)
            maxn = max(maxn, sz[v]);

    // hi = largest len with check1 true. check1(u, 0, ...) always holds
    // (sz[u] <= 2 * sz[u]), so hi is always a valid answer.
    int l = 0, r = dep[u] - 1;
    while (l < r) {
        int mid = (l + r + 1) / 2;
        if (check1(u, mid, maxn))
            l = mid;
        else
            r = mid - 1;
    }
    const int hi = l;

    // lo = smallest len with check2 true. If check2 fails everywhere, the
    // search still stops at dep[u] - 1. So lo is only guaranteed valid when
    // lo < hi (<= dep[u] - 1); the lo == hi case is re-checked below.
    l = 0;
    r = dep[u] - 1;
    while (l < r) {
        int mid = (l + r) / 2;
        if (check2(u, mid, maxn))
            r = mid;
        else
            l = mid + 1;
    }
    const int lo = l;

    if (lo <= hi) {
        // Lift u by lo levels (bottom of the range) and by hi levels (top).
        int bottom = u, top = u;
        for (int i = 20; i >= 0; i--) {
            if ((1 << i) & lo)
                bottom = f[bottom][i];
            if ((1 << i) & hi)
                top = f[top][i];
        }
        if (lo == hi) {
            // Single candidate, possibly the unverified fallback of the
            // second search: test both conditions directly.
            if (2 * maxn <= sz[bottom] && sz[bottom] <= 2 * sz[u])
                ans[bottom].push_back(u);
        } else {
            // Every ancestor from bottom up to and including top qualifies.
            for (int t = bottom; t != f[top][0]; t = f[t][0])
                ans[t].push_back(u);
        }
    }

    for (int v : edge[u])
        if (v != fa)
            dfs(v, u);
}

int main() {
    int N;
    scanf("%d", &N);
    for (int i = 1; i < N; i++) {
        int a, b;
        scanf("%d %d", &a, &b);
        edge[a].push_back(b);
        edge[b].push_back(a);
    }
    dfs1(1, 0);
    dfs(1, 0);
    for (int i = 1; i <= N; i++, puts("")) {
        if (ans[i].size() == 2) {
            if (ans[i][0] < ans[i][1])
                printf("%d %d", ans[i][0], ans[i][1]);
            else    printf("%d %d", ans[i][1], ans[i][0]);
        }
        else
            printf("%d", ans[i][0]);
        //for (int j = 0; j < ans[i].size(); j++)
        //    printf("%d ", ans[i][j]);
    }
    return 0;
}

 

posted @ 2020-09-26 20:21  cminus  阅读(263)  评论(0)    收藏  举报