// Heavy-light decomposition (重链剖分): point update, path sum and path max queries.

#include <bits/stdc++.h>
// Upper bound on node ids: nodes are numbered 1..n and every per-node array has N slots,
// so n must stay below N.
const int N = 100005;
// Large sentinel; used negated (-inf) as the starting value of max queries.
const int inf = 1000000000;
using namespace std;

// One directed arc of the adjacency lists (head[] / next chains).
struct Edge
{
    int v, next; // target node; index of the next arc with the same source (0 ends the list)
    int u;       // source node (written by add(); not read elsewhere in this file)
};
// Each undirected tree edge is stored as two arcs, so a tree on n nodes needs 2(n-1) arc
// slots (indices start at 1 because 0 terminates a list). N slots alone overflow once n > N/2.
Edge edge[2 * N];

int tot, head[N];  // tot: number of arcs stored so far; head[x]: first arc leaving x (0 = none)
int n, q, num[N];  // node count, query count, initial value of each node (1-based ids)

// Heavy-light decomposition data, indexed by node id:
//   siz = subtree size, son = heavy child (0 if none), fa = parent,
//   d = depth (the root gets depth 1), top = head node of the node's heavy chain.
int siz[N], son[N], fa[N], d[N], top[N];
// seg[x] = position of node x in the decomposition order (seg[0] is the running counter);
// rev[p] = node stored at position p.
int seg[N], rev[N];
// Segment tree over positions 1..n: range sum and range maximum (4x leaves is enough).
int sum[4 * N], maxx[4 * N];
void add(int u, int v)
{
    edge[++tot].u = u;
    edge[tot].v = v;
    edge[tot].next = head[u];
    head[u] = tot;
    edge[++tot].u = v;
    edge[tot].v = u;
    edge[tot].next = head[v];
    head[v] = tot;
}
// First HLD pass over u's subtree, where f is u's parent (the root is passed as its own
// parent). It records parent (fa), depth (d), subtree size (siz) and heavy child (son).
// The heavy child is the first child whose subtree is strictly larger than all earlier ones.
void dfs1(int u, int f)
{
    fa[u] = f;
    d[u] = d[f] + 1;
    siz[u] = 1;
    for (int e = head[u]; e != 0; e = edge[e].next)
    {
        const int child = edge[e].v;
        if (child == f)
            continue; // skip the arc leading back to the parent
        dfs1(child, u);
        siz[u] += siz[child];
        if (siz[son[u]] < siz[child])
            son[u] = child;
    }
}
// Gives node x the next free position (seg[0] is the running counter), records the inverse
// mapping in rev, and makes chain_head the head of x's heavy chain.
static void place_on_chain(int x, int chain_head)
{
    seg[x] = ++seg[0];
    rev[seg[0]] = x;
    top[x] = chain_head;
}

// Second HLD pass: numbers u's subtree so that every heavy chain occupies consecutive
// positions. The caller must already have placed u itself (seg/rev/top set).
// The heavy child is placed first and continues u's chain; each light child then
// starts a new chain headed by itself.
void dfs2(int u)
{
    const int heavy = son[u];
    if (heavy != 0)
    {
        place_on_chain(heavy, top[u]);
        dfs2(heavy);
    }
    for (int e = head[u]; e != 0; e = edge[e].next)
    {
        const int w = edge[e].v;
        if (top[w] != 0)
            continue; // already placed (the parent, or the heavy child)
        place_on_chain(w, w);
        dfs2(w);
    }
}

// Builds segment-tree node k, which covers positions [l, r].
// A leaf takes the value of the tree node stored at that position (num[rev[l]]).
// An inner node combines its two children into a sum and a maximum.
void build(int k, int l, int r)
{
    if (l != r)
    {
        const int mid = (l + r) / 2;
        const int lc = k * 2, rc = k * 2 + 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        sum[k] = sum[lc] + sum[rc];
        maxx[k] = max(maxx[lc], maxx[rc]);
    }
    else
    {
        const int leaf = num[rev[l]];
        maxx[k] = leaf;
        sum[k] = leaf;
    }
}
// Point assignment: sets position q to v within segment-tree node k (range [l, r]), then
// recomputes the sum/max of each node on the way back up.
// NOTE(review): the parameter q shadows the global query counter q.
void update(int k, int l, int r, int q, int v)
{
    if (l != r)
    {
        const int mid = (l + r) / 2;
        const int lc = k * 2, rc = lc + 1;
        if (q > mid)
            update(rc, mid + 1, r, q, v);
        else
            update(lc, l, mid, q, v);
        sum[k] = sum[lc] + sum[rc];
        maxx[k] = max(maxx[lc], maxx[rc]);
    }
    else
    {
        maxx[k] = v;
        sum[k] = v;
    }
}
// Returns the sum of the values at the positions of [ql, qr] that fall inside
// segment-tree node k's range [l, r].
int ans_sum(int k, int l, int r, int ql, int qr)
{
    if (ql <= l && r <= qr)
        return sum[k]; // node lies entirely inside the query
    const int mid = (l + r) / 2;
    const int left = (ql <= mid) ? ans_sum(k * 2, l, mid, ql, qr) : 0;
    const int right = (qr > mid) ? ans_sum(k * 2 + 1, mid + 1, r, ql, qr) : 0;
    return left + right;
}
// Returns the maximum value at the positions of [ql, qr] that fall inside segment-tree
// node k's range [l, r]. When the node is only partly covered, the result is never
// below -inf, because -inf is the starting value the child answers are compared against.
int ans_max(int k, int l, int r, int ql, int qr)
{
    if (ql <= l && r <= qr)
        return maxx[k]; // node lies entirely inside the query
    const int mid = (l + r) / 2;
    int best = -inf;
    if (ql <= mid)
    {
        const int left = ans_max(k * 2, l, mid, ql, qr);
        if (left > best)
            best = left;
    }
    if (qr > mid)
    {
        const int right = ans_max(k * 2 + 1, mid + 1, r, ql, qr);
        if (right > best)
            best = right;
    }
    return best;
}
// Sum of node values along the tree path u..v, both endpoints included.
// While u and v are on different chains, take the endpoint whose chain head is deeper,
// add its segment [seg[top], seg[endpoint]], and jump to the parent of that chain's head.
// Once both are on one chain, add the stretch between them.
int find_sum(int u, int v)
{
    int total = 0;
    for (; top[u] != top[v]; u = fa[top[u]])
    {
        if (d[top[u]] < d[top[v]])
            swap(u, v);
        total += ans_sum(1, 1, n, seg[top[u]], seg[u]);
    }
    const bool u_shallower = d[u] < d[v];
    const int lo = u_shallower ? u : v; // shallower endpoint -> smaller position
    const int hi = u_shallower ? v : u;
    total += ans_sum(1, 1, n, seg[lo], seg[hi]);
    return total;
}
// Maximum node value along the tree path u..v, both endpoints included.
// It lifts endpoints chain by chain exactly like the path-sum query, but combines the
// chain segments with max. The running answer starts at -inf.
int find_max(int u, int v)
{
    int best = -inf;
    for (; top[u] != top[v]; u = fa[top[u]])
    {
        if (d[top[u]] < d[top[v]])
            swap(u, v);
        const int chain_best = ans_max(1, 1, n, seg[top[u]], seg[u]);
        if (chain_best > best)
            best = chain_best;
    }
    const bool u_shallower = d[u] < d[v];
    const int lo = u_shallower ? u : v; // shallower endpoint -> smaller position
    const int hi = u_shallower ? v : u;
    const int last = ans_max(1, 1, n, seg[lo], seg[hi]);
    return last > best ? last : best;
}
int main()
{
    scanf("%d", &n);
    for (int i = 1; i < n; i++)
    {
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v);
    }
    for (int i = 1; i <= n; i++)
        scanf("%d", &num[i]);
    dfs1(1, 1);
    seg[1] = ++seg[0];
    rev[seg[0]] = 1;
    top[1] = 1;
    dfs2(1);
    build(1, 1, n);
    scanf("%d", &q);
    while (q--)
    {
        int x, y;
        char s[10];
        scanf("%s%d%d", s, &x, &y);
        if (s[1] == 'H')
            update(1, 1, n, seg[x], y);
        if (s[1] == 'M')
            printf("%d\n", find_max(x, y));
        if (s[1] == 'S')
            printf("%d\n", find_sum(x, y));
    }
    return 0;
}
// Originally posted @ 2025-05-28 19:36 by 流氓兔LMT (blog footer: 9 views, 0 comments).