洛谷P4074 [WC2013] 糖果公园 题解 树上带修莫队

题目链接:https://www.luogu.com.cn/problem/P4074

解题思路完全来自 OI Wiki（oi.wiki）中“树上带修莫队”的讲解。

示例程序:

#include <bits/stdc++.h>
using namespace std;
constexpr int maxn = 100000 + 5;

int n, m, Q;            // node count, candy-type count, operation count (read in main)
int blo;                // block size used by the Mo's-algorithm sort key
int V[maxn];            // V[j]: value of candy type j
int W[maxn];            // W[i]: multiplier for the i-th tasting of the same type
int c[maxn];            // c[u]: candy type currently at node u
int fa[maxn][17];       // fa[u][i]: 2^i-th ancestor of u (0 above the root)
int dep[maxn];          // dep[u]: depth of u (root has depth 1, sentinel 0 has depth 0)
int id[maxn * 2];       // id[k]: node whose entry or exit stamp equals k
int idx;                // running Euler-tour timestamp
int dfn[maxn];          // dfn[u]: entry stamp of u
int nfd[maxn];          // nfd[u]: exit stamp of u
int idx1, idx2;         // number of queries / modifications read so far
long long ans[maxn];    // ans[k]: answer of the k-th query in input order
long long sum;          // happiness value of the nodes currently in the window
std::vector<int> g[maxn];  // adjacency lists of the tree

// Euler tour of the tree hanging from u, whose parent is p.
// For every visited node v it records:
//   dfn[v] - entry stamp, nfd[v] - exit stamp, id[stamp] = v for both stamps,
//   dep[v] = dep[parent] + 1, fa[v][0] = parent.
// Stamps come from the global counter idx.
//
// Uses an explicit stack instead of recursion: a path-shaped tree of up to
// maxn nodes would otherwise recurse that many levels deep, which can overflow
// a small default stack (e.g. 1 MB on Windows). Neighbours are scanned in the
// same order as g[v], and the parent is skipped, so the stamps are identical to
// those of the recursive formulation.
void dfs(int u, int p) {
    struct Frame {
        int node;
        int parent;
        std::size_t next;  // index in g[node] of the next neighbour to examine
    };
    std::vector<Frame> stk;

    // Pre-order work for v: stamp entry time, depth and parent, then push a frame.
    auto enter = [&stk](int v, int par) {
        dfn[v] = ++idx;
        id[dfn[v]] = v;
        dep[v] = dep[par] + 1;
        fa[v][0] = par;
        stk.push_back({v, par, 0});
    };

    enter(u, p);
    while (!stk.empty()) {
        Frame &top = stk.back();
        if (top.next < g[top.node].size()) {
            const int v = g[top.node][top.next++];
            if (v != top.parent)
                enter(v, top.node);  // may reallocate stk; `top` is not used after this
        } else {
            // Post-order work: every neighbour handled, stamp exit time.
            nfd[top.node] = ++idx;
            id[nfd[top.node]] = top.node;
            stk.pop_back();
        }
    }
}

// Lowest common ancestor of x and y via binary lifting over fa[][].
// Jumps that would land above the target depth are rejected. This includes
// jumping past the root to the sentinel node 0, since dep[0] stays 0 while
// real nodes have depth >= 1.
int lca(int x, int y) {
    if (dep[y] > dep[x])
        std::swap(x, y);  // x is now the deeper endpoint

    // Raise x to the depth of y.
    for (int k = 16; k >= 0; --k) {
        const int up = fa[x][k];
        if (dep[up] >= dep[y])
            x = up;
    }
    if (x == y)
        return x;  // y was an ancestor of the original x

    // Raise both while their 2^k-th ancestors differ; they stop just below the LCA.
    for (int k = 16; k >= 0; --k) {
        const int ax = fa[x][k];
        const int ay = fa[y][k];
        if (ax != ay) {
            x = ax;
            y = ay;
        }
    }
    return fa[x][0];
}

// A path query already translated to an Euler-tour interval [l, r].
struct Query {
    int id;  // position of the query in input order (index into ans)
    int l;   // left end of the Euler-tour interval
    int r;   // right end of the Euler-tour interval
    int t;   // number of modifications read before this query
};
Query query[maxn];

// A modification: node x takes candy type y. While main() moves through time it
// exchanges y with c[x], so y always holds the type that is not currently applied.
struct Event {
    int x;
    int y;
};
Event event[maxn];

// vis[u]: u occurs an odd number of times in the current window (i.e. is counted).
bool vis[maxn];
// cnt[j]: number of counted nodes whose candy type is j.
int cnt[maxn];

// Count node u: its candy type gets one more tasting, and the k-th tasting of
// a type contributes V[type] * W[k] to sum.
void add(int u) {
    const int type = c[u];
    const int k = ++cnt[type];
    sum += static_cast<long long>(V[type]) * W[k];
}

// Exact inverse of add(u): remove the latest tasting of u's candy type.
void del(int u) {
    const int type = c[u];
    sum -= static_cast<long long>(V[type]) * W[cnt[type]];  // the cnt-th tasting leaves
    --cnt[type];
}

// Toggle node u. Each occurrence of u entering or leaving the Euler window
// calls this, so u ends up counted exactly when it appears an odd number of times.
void add_or_del(int u) {
    if (vis[u])
        del(u);
    else
        add(u);
    vis[u] = !vis[u];
}

int main() {
    scanf("%d%d%d", &n, &m, &Q);
    blo = pow(n, 2.0/3);
    for (int i = 1; i <= m; i++)
        scanf("%d", V+i);
    for (int i = 1; i <= n; i++)
        scanf("%d", W+i);
    for (int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        g[u].push_back(v);
        g[v].push_back(u);
    }
    for (int i = 1; i <= n; i++)
        scanf("%d", c+i);
    dfs(1, 0);
    for (int i = 1; i <= 16; i++)
        for (int u = 1; u <= n; u++)
            fa[u][i] = fa[ fa[u][i-1] ][i-1];
    for (int i = 0, op, x, y; i < Q; i++) {
        scanf("%d%d%d", &op, &x, &y);
        if (!op) { // op == 0
            event[++idx2] = {x, y};
        }
        else {  // op == 1
            idx1++;
            if (dfn[x] > dfn[y])
                swap(x, y);
            int z = lca(x, y), l, r;
            if (z == x || z == y)
                l = dfn[x], r = dfn[y];
            else
                l = nfd[x], r = dfn[y];
            query[idx1] = {idx1, l, r, idx2};
        }
    }
    sort(query+1, query+idx1+1, [](auto a, auto b) {
        if (a.l / blo != b.l / blo)
            return a.l < b.l;
        if (a.r / blo != b.r / blo)
            return a.r < b.r;
        return a.t < b.t;
      });
    for (int i = 1, l = 1, r = 0, t = 0; i <= idx1; i++) {
        while (l < query[i].l) add_or_del(id[l++]);
        while (l > query[i].l) add_or_del(id[--l]);
        while (r < query[i].r) add_or_del(id[++r]);
        while (r > query[i].r) add_or_del(id[r--]);
        for (; t < query[i].t; t++) {
            auto [x, y] = event[t+1];
            int z = c[x];
            if (vis[x]) {
                del(x);
                c[x] = y;
                add(x);
            }
            else
                c[x] = y;
            event[t+1].y = z;
        }
        for (; t > query[i].t; t--) {
            auto [x, y] = event[t];
            int z = c[x];
            if (vis[x]) {
                del(x);
                c[x] = y;
                add(x);
            }
            else
                c[x] = y;
            event[t].y = z;
        }
        int x = id[l], y = id[r], z = lca(x, y);
        if (z == x || z == y) {
            ans[ query[i].id ] = sum;
        }
        else {
            add(z);
            ans[ query[i].id ] = sum;
            del(z);
        }
    }
    for (int i = 1; i <= idx1; i++)
        printf("%lld\n", ans[i]);
    return 0;
}
posted @ 2026-03-05 15:08  quanjun  阅读(0)  评论(0)    收藏  举报