洛谷P4556 【模板】线段树合并 / [Vani 有约会] 雨天的尾巴 线段树合并模板题

题目链接:https://www.luogu.com.cn/problem/P4556

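解题思路:来自 oi.wiki。对每次操作 (x, y, c) 做树上差分:在 x 和 y 处给颜色 c 的计数加 1,在 lca(x, y) 及其父结点(若存在)处各减 1。每个结点维护一棵动态开点的权值线段树,记录区间内出现次数最多的颜色(次数相同时取编号较小者);最后自底向上把子结点的线段树合并到父结点,合并完成后根结点记录的颜色即为该结点的答案(最大次数为 0 时答案为 0)。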

示例程序:

#include <bits/stdc++.h>
using namespace std;
// Capacity of every per-vertex array. Vertices are numbered from 1; index 0 means "no vertex".
constexpr int maxn = 100000 + 5;
// Upper end of the key range: every segment tree in this file spans [1, M].
constexpr int M = 100000;

int n;                     // number of vertices (read in main)
int m;                     // number of path operations (read in main)
int dep[maxn];             // dep[u]: depth assigned by dfs1, i.e. depth of the parent plus one
int fa[maxn][17];          // fa[u][i]: the 2^i-th ancestor of u, or 0 if there is none
std::vector<int> g[maxn];  // adjacency lists of the undirected tree

// Records the parent (fa[v][0]) and the depth (dep[v]) of u and of every vertex
// reachable from u without walking back over the edge to p.
//
//   u - vertex to start from
//   p - value stored as u's parent; dep[u] becomes dep[p] + 1
//
// The walk uses an explicit stack rather than recursion: the recursion depth
// would equal the height of the tree, which on a path-shaped tree is the
// number of vertices and can exhaust a small call stack.
void dfs1(int u, int p) {
    vector<pair<int, int>> pending; // (vertex, parent) pairs waiting to be visited
    pending.emplace_back(u, p);
    while (!pending.empty()) {
        const pair<int, int> cur = pending.back();
        pending.pop_back();
        const int node = cur.first;
        const int parent = cur.second;

        fa[node][0] = parent;
        // An entry is pushed only while its parent is being visited, so
        // dep[parent] has already been written when we get here.
        dep[node] = dep[parent] + 1;

        for (int v : g[node])
            if (v != parent)
                pending.emplace_back(v, node);
    }
}

// Returns the lowest common ancestor of x and y using the binary-lifting
// table fa and the depths dep.
//
// Jumps that would leave the tree land on vertex 0. fa[0][*] and dep[0] are
// never written and therefore stay 0, while real vertices have depth >= 1
// when dfs1 is started as dfs1(1, 0); such jumps are rejected by the depth
// test in the first loop and compare equal in the second one.
int lca(int x, int y) {
    const int top = 16; // highest column of fa

    // Make x the deeper (or equally deep) vertex.
    if (dep[x] < dep[y]) {
        std::swap(x, y);
    }

    // Raise x until it is level with y, trying the longest jumps first.
    for (int k = top; k >= 0; --k) {
        const int up = fa[x][k];
        if (dep[up] >= dep[y]) {
            x = up;
        }
    }
    if (x == y) {
        return x; // y was an ancestor of (or equal to) x
    }

    // Raise both while their ancestors differ; they stop just below the LCA.
    for (int k = top; k >= 0; --k) {
        const int ux = fa[x][k];
        const int uy = fa[y][k];
        if (ux != uy) {
            x = ux;
            y = uy;
        }
    }
    return fa[x][0];
}

// Segment tree section

// One node of a dynamically built segment tree.
struct Node {
    int l;   // index in tr of the left child, 0 if that child does not exist
    int r;   // index in tr of the right child, 0 if that child does not exist
    int id;  // color id with the highest occurrence count in this node's range
    int cnt; // that highest occurrence count
};

Node tr[maxn * 200]; // node pool; entry 0 is never handed out and stands for "no node"
int rt[maxn];        // rt[v]: index of the root node of vertex v's tree, 0 while it is empty
int idx;             // index of the most recently allocated node; add() takes the next one with ++idx

void push_up(int u) {
    assert(tr[0].id == 0 && tr[0].cnt == 0);
    int l = tr[u].l, r = tr[u].r;
    if (tr[l].cnt >= tr[r].cnt) {
        tr[u].id = tr[l].id;
        tr[u].cnt = tr[l].cnt;
    }
    else {
        tr[u].id = tr[r].id;
        tr[u].cnt = tr[r].cnt;
    }
    if (!tr[u].cnt) tr[u].id = 0; // 注意加上这句话,不然会WA第2组测试点
}

void add(int p, int v, int l, int r, int &u) {
    if (!u)
        u = ++idx;
    if (l == r) {
        tr[u].id = p;
        tr[u].cnt += v;
        return;
    }
    int mid = l + r >> 1;
    (p <= mid) ? add(p, v, l, mid, tr[u].l) : add(p, v, mid+1, r, tr[u].r);
    push_up(u);
}

int Merge(int l, int r, int a, int b) {
    if (!a || !b) return a + b;
    if (l == r) {
        tr[a].cnt += tr[b].cnt;
        return a;
    }
    int mid = l + r >> 1;
    tr[a].l = Merge(l, mid, tr[a].l, tr[b].l);
    tr[a].r = Merge(mid+1, r, tr[a].r, tr[b].r);
    push_up(a);
    return a;
}

// ans[v]: the id stored at the root of v's tree once all of v's children have
// been merged into it (0 when that tree is empty or its best count is 0).
int ans[maxn];

// Computes ans[] for u and for every vertex reachable from u without walking
// back over the edge to p, merging each child's tree into its parent's tree.
//
//   u - vertex to start from
//   p - neighbour of u that must not be entered (0 for the root)
//
// The traversal is iterative: a recursive walk would nest as deep as the tree
// is high, which on a path-shaped tree is the number of vertices and can
// exhaust a small call stack.
void dfs2(int u, int p) {
    // Step 1: list the vertices in pre-order together with their parents.
    vector<pair<int, int>> order;
    vector<pair<int, int>> pending;
    pending.emplace_back(u, p);
    while (!pending.empty()) {
        const pair<int, int> cur = pending.back();
        pending.pop_back();
        order.push_back(cur);
        for (int v : g[cur.first])
            if (v != cur.second)
                pending.emplace_back(v, cur.first);
    }

    // Step 2: walk the list backwards. Every vertex appears after its parent
    // in pre-order, so going backwards all children of a vertex have been
    // merged into it before the vertex itself is reached.
    for (size_t i = order.size(); i-- > 0;) {
        const int node = order[i].first;
        const int parent = order[i].second;
        // Record the answer first: the merge below may reuse node's tree
        // nodes inside the parent's tree, where later merges modify them.
        ans[node] = tr[rt[node]].id;
        // order[0] is the start vertex u, which is not merged into p.
        if (i > 0)
            rt[parent] = Merge(1, M, rt[parent], rt[node]);
    }
}

int main() {
    scanf("%d%d", &n, &m);
    for (int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs1(1, 0);
    for (int i = 1; i < 17; i++)
        for (int u = 1; u <= n; u++)
            fa[u][i] = fa[ fa[u][i-1] ][i-1];
    for (int i = 0, x, y, c; i < m; i++) {
        scanf("%d%d%d", &x, &y, &c);
        int z = lca(x, y);
        add(c, 1, 1, M, rt[x]);
        add(c, 1, 1, M, rt[y]);
        add(c, -1, 1, M, rt[z]);
        if (fa[z][0])
            add(c, -1, 1, M, rt[ fa[z][0] ]);
    }
    dfs2(1, 0);
    for (int i = 1; i <= n; i++)
        printf("%d\n", ans[i]);

    return 0;
}
posted @ 2026-03-19 14:45  quanjun  阅读(2)  评论(0)    收藏  举报