// 重链剖分 — heavy-light decomposition (path max / path sum queries with point updates)
#include <bits/stdc++.h>
// N bounds the node ids: nodes are numbered 1..n, so n must be at most N - 1.
constexpr int N = 100005;
// -inf is the starting value for max queries (assumes every node value is
// greater than -1e9 — TODO confirm against the problem constraints).
constexpr int inf = 1000000000;
using namespace std;

// Adjacency-list edge. add() stores every undirected tree edge twice
// (u->v and v->u), so the array must hold 2 * (n - 1) entries; index 0 is
// unused because edge indices start at 1 and 0 terminates a list.
struct Edge
{
    int v, next; // v: target node; next: next edge out of the same source (0 = end)
    int u;       // source node (written by add(), not read elsewhere in this file)
} edge[2 * N];
int tot, head[N]; // tot: edges stored so far; head[u]: first edge out of u (0 = none)
int n, q;         // node count, query count
int num[N];       // initial node values, num[1..n]
// Per-node decomposition data, indexed by node id:
//   siz = subtree size, son = heavy child (0 if none), fa = parent,
//   d = depth, top = first node of the heavy chain containing the node.
int siz[N], son[N], fa[N], d[N], top[N];
// seg[u] = position of node u in segment-tree order (seg[0] is the running
// position counter); rev[p] = node stored at position p.
int seg[N], rev[N];
// Segment tree in 1-indexed heap layout (children of k are 2k and 2k+1):
// range sum and range maximum.
int sum[4 * N], maxx[4 * N];
void add(int u, int v)
{
edge[++tot].u = u;
edge[tot].v = v;
edge[tot].next = head[u];
head[u] = tot;
edge[++tot].u = v;
edge[tot].v = u;
edge[tot].next = head[v];
head[v] = tot;
}
// First decomposition pass: a rooted DFS from u, whose parent is f.
// It fills fa[u], d[u] (the parent's depth + 1), siz[u] (subtree size) and
// son[u]. son[u] is the child with the largest subtree: only a strictly
// larger subtree replaces it, so the earliest such child wins ties.
void dfs1(int u, int f)
{
    fa[u] = f;
    d[u] = d[f] + 1;
    siz[u] = 1;
    for (int e = head[u]; e != 0; e = edge[e].next)
    {
        const int child = edge[e].v;
        if (child == f)
            continue; // edge back to the parent
        dfs1(child, u);
        siz[u] += siz[child];
        if (siz[son[u]] < siz[child])
            son[u] = child;
    }
}
// Second decomposition pass. It numbers u's subtree in segment-tree order; u
// itself must already have been placed by the caller.
// - The heavy child is numbered immediately after u and joins u's chain, so
//   each heavy chain gets a contiguous run of positions.
// - Every other not-yet-placed neighbour (top[] still 0) starts a new chain
//   headed by itself.
// Positions are drawn from the shared counter seg[0].
void dfs2(int u)
{
    // Gives `node` the next position and records it as part of chain `chainHead`.
    auto place = [](int node, int chainHead) {
        seg[node] = ++seg[0];
        rev[seg[0]] = node;
        top[node] = chainHead;
    };

    const int heavy = son[u];
    if (heavy != 0)
    {
        place(heavy, top[u]);
        dfs2(heavy);
    }
    for (int e = head[u]; e != 0; e = edge[e].next)
    {
        const int next = edge[e].v;
        if (top[next] != 0)
            continue; // already placed
        place(next, next);
        dfs2(next);
    }
}
// Builds segment-tree node k, which covers positions [l, r].
// - A leaf holds the value of the tree node laid out at that position
//   (num[rev[l]]).
// - An internal node stores the sum and maximum of its children 2k and 2k+1.
void build(int k, int l, int r)
{
    if (l != r)
    {
        const int mid = (l + r) / 2;
        const int lc = 2 * k;
        const int rc = lc + 1;
        build(lc, l, mid);
        build(rc, mid + 1, r);
        sum[k] = sum[lc] + sum[rc];
        maxx[k] = max(maxx[lc], maxx[rc]);
        return;
    }
    const int leafValue = num[rev[l]];
    maxx[k] = leafValue;
    sum[k] = leafValue;
}
// Point assignment. It descends from node k (range [l, r]) to the leaf for
// position `pos`, sets that leaf to v, and recomputes sum/max on the way back up.
void update(int k, int l, int r, int pos, int v)
{
    if (l == r)
    {
        maxx[k] = v;
        sum[k] = v;
        return;
    }
    const int mid = (l + r) / 2;
    const bool inLeft = pos <= mid;
    update(inLeft ? 2 * k : 2 * k + 1, inLeft ? l : mid + 1, inLeft ? mid : r, pos, v);
    const int a = 2 * k;
    const int b = a + 1;
    sum[k] = sum[a] + sum[b];
    maxx[k] = max(maxx[a], maxx[b]);
}
// Returns the sum of the values at positions [ql, qr] that lie inside node
// k's range [l, r]. Assumes [ql, qr] intersects [l, r].
int ans_sum(int k, int l, int r, int ql, int qr)
{
    if (ql <= l && r <= qr)
        return sum[k]; // node fully covered by the query
    const int mid = (l + r) / 2;
    const int fromLeft = (ql <= mid) ? ans_sum(k * 2, l, mid, ql, qr) : 0;
    const int fromRight = (qr > mid) ? ans_sum(k * 2 + 1, mid + 1, r, ql, qr) : 0;
    return fromLeft + fromRight;
}
// Returns the maximum value at positions [ql, qr] that lie inside node k's
// range [l, r]. Assumes [ql, qr] intersects [l, r]. A partially covered
// node's result is never below -inf.
int ans_max(int k, int l, int r, int ql, int qr)
{
    if (ql <= l && r <= qr)
        return maxx[k]; // node fully covered by the query
    const int mid = (l + r) / 2;
    const int fromLeft = (ql <= mid) ? ans_max(k * 2, l, mid, ql, qr) : -inf;
    const int fromRight = (qr > mid) ? ans_max(k * 2 + 1, mid + 1, r, ql, qr) : -inf;
    return max(max(-inf, fromLeft), fromRight);
}
// Returns the sum of node values on the tree path between u and v, both
// endpoints included.
// While the endpoints are on different heavy chains, take the endpoint whose
// chain head is deeper: add its segment [chain head, endpoint], then jump to
// the head's parent. Once both share a chain, add the segment between them.
int find_sum(int u, int v)
{
    int total = 0;
    while (top[u] != top[v])
    {
        if (d[top[v]] > d[top[u]])
            swap(u, v);
        const int chainTop = top[u];
        total += ans_sum(1, 1, n, seg[chainTop], seg[u]);
        u = fa[chainTop];
    }
    const bool uShallower = d[u] < d[v];
    const int shallow = uShallower ? u : v;
    const int deep = uShallower ? v : u;
    return total + ans_sum(1, 1, n, seg[shallow], seg[deep]);
}
// Returns the maximum node value on the tree path between u and v, both
// endpoints included.
// The path is split into heavy-chain segments: repeatedly lift the endpoint
// whose chain head is deeper. Segments are combined with max, starting from -inf.
int find_max(int u, int v)
{
    int best = -inf;
    while (top[u] != top[v])
    {
        if (d[top[v]] > d[top[u]])
            swap(u, v);
        const int chainTop = top[u];
        best = max(best, ans_max(1, 1, n, seg[chainTop], seg[u]));
        u = fa[chainTop];
    }
    const bool uShallower = d[u] < d[v];
    const int shallow = uShallower ? u : v;
    const int deep = uShallower ? v : u;
    return max(best, ans_max(1, 1, n, seg[shallow], seg[deep]));
}
int main()
{
scanf("%d", &n);
for (int i = 1; i < n; i++)
{
int u, v;
scanf("%d%d", &u, &v);
add(u, v);
}
for (int i = 1; i <= n; i++)
scanf("%d", &num[i]);
dfs1(1, 1);
seg[1] = ++seg[0];
rev[seg[0]] = 1;
top[1] = 1;
dfs2(1);
build(1, 1, n);
scanf("%d", &q);
while (q--)
{
int x, y;
char s[10];
scanf("%s%d%d", s, &x, &y);
if (s[1] == 'H')
update(1, 1, n, seg[x], y);
if (s[1] == 'M')
printf("%d\n", find_max(x, y));
if (s[1] == 'S')
printf("%d\n", find_sum(x, y));
}
return 0;
}
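// 本文来自博客园,作者:流氓兔LMT,转载请注明原文链接:https://www.cnblogs.com/-include-lmt/p/18901140
// (Originally published on cnblogs by 流氓兔LMT; please credit the original link above when reposting.)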

// 浙公网安备 33010602011771号 (hosting site's public-security filing footer, copied along with the post)