点分树

重工业（代码量大、细节多）

当询问的答案与树的具体形态无关，只依赖点与点之间的距离等子树信息，
并且题目还要求支持修改时，
就可以使用点分树（动态点分治）。
点分树上每个点需要维护两个数据结构（本代码中是两个树状数组）：一个按到该点的距离统计本连通块内的点，另一个按到其点分树父亲的距离统计同一批点，用于向上跳时容斥去重。

#include<bits/stdc++.h>
#define int long long
#define F(i,i0,n) for(int i=(i0);i<=(n);i++)
#define pii pair<int,int>
#define fr first
#define sc second
using namespace std;
// Fast reader: returns the next (optionally negative) decimal integer on stdin.
// Non-digit characters before the number are skipped; a '-' among them makes
// the result negative. If EOF is reached before any digit, returns 0 instead
// of spinning forever (a plain `char` could not distinguish EOF from data).
inline int rd() {
    bool neg = false;
    int x = 0;
    int ch = getchar();  // keep getchar()'s int result so EOF stays detectable
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') neg = true;
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch)) {
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
    return neg ? -x : x;
}
// Capacity of every per-vertex array (presumably sized for n up to 2e5 —
// confirm against the problem limits).
const int N = 2e5 + 7;
// "Infinity" sentinel used as maxp[0] during the centroid search. LLONG_MAX is
// the standard <climits> name; LONG_LONG_MAX is a GNU extension that other
// toolchains (e.g. clang + libc++) may not provide.
const long long inf = LLONG_MAX;
// Adjacency lists as linked edge records: e[i].v is the edge's endpoint,
// e[i].nt the next edge leaving the same vertex; index 0 terminates a list.
struct Id { int v, nt; }e[N << 1];
// p[x]: head edge of vertex x. id starts at 1, so the first edge is stored at 2.
int p[N], id = 1;
// Prepend the directed edge x -> y to x's list.
void add(int x, int y) { e[++id] = { y,p[x] }; p[x] = id; }
// n: vertex count, w: current vertex weights, dep: depth from vertex 1 (filled by Tr.dfs1).
int n, w[N], dep[N];
// Heavy-light decomposition of the input tree rooted at vertex 1, used only
// for LCA queries. Reads the global adjacency lists (p, e) and fills the
// global depth array dep[] with dep[1] = 1. `Tim` is unused.
struct Lca {
    int top[N], fa[N], siz[N], son[N], Tim;
    // Pass 1: parent, depth, subtree size and heavy (largest) child.
    void dfs1(int x, int ffa) {
        fa[x] = ffa;
        dep[x] = dep[ffa] + 1;
        siz[x] = 1;
        for (int edge = p[x]; edge; edge = e[edge].nt) {
            int to = e[edge].v;
            if (to == ffa) continue;
            dfs1(to, x);
            siz[x] += siz[to];
            if (siz[to] > siz[son[x]]) son[x] = to;
        }
    }
    // Pass 2: chain heads. The heavy child extends the current chain; each
    // light child starts a chain of its own.
    void dfs2(int x, int tp) {
        top[x] = tp;
        if (son[x]) dfs2(son[x], tp);
        for (int edge = p[x]; edge; edge = e[edge].nt) {
            int to = e[edge].v;
            if (to == fa[x] || to == son[x]) continue;
            dfs2(to, to);
        }
    }
    // Lift whichever vertex has the deeper chain head until both lie on the
    // same chain; the shallower of the two is then the LCA.
    int lca(int x, int y) {
        for (; top[x] != top[y]; x = fa[top[x]])
            if (dep[top[x]] < dep[top[y]]) swap(x, y);
        return dep[x] < dep[y] ? x : y;
    }
    // Run both passes from the root, vertex 1.
    void build() {
        dfs1(1, 0);
        dfs2(1, 1);
    }
}Tr;
// Edge count of the tree path x..y: both depths measured down from their LCA.
int dist(int x, int y) {
    const int meet = Tr.lca(x, y);
    return (dep[x] - dep[meet]) + (dep[y] - dep[meet]);
}
// Fenwick tree keyed by distance 0..lim-1. Callers pass 0-based distances;
// they are shifted by one internally because a Fenwick tree is 1-based.
struct BIT {
    int lim;
    std::vector<int> c;
    // Allocate (zeroed) room for distances 0..x-1.
    void build(int x) { c.resize(x + 1); lim = x; }
    // Add k at distance x. Distances >= lim are ignored; negative distances
    // are ignored too (upd(-1) used to loop forever since x & -x == 0, and
    // smaller values indexed c out of bounds).
    void upd(int x, int k) {
        if (x < 0) return;
        for (++x; x <= lim; x += x & -x) c[x] += k;
    }
    // Sum of values at distances 0..x, clamped to the stored range. A negative
    // x yields 0 (x <= -2 used to walk off the front of c).
    int sum(int x) {
        if (x < 0) return 0;
        int ans = 0;
        for (x = std::min(x + 1, lim); x > 0; x -= x & -x) ans += c[x];
        return ans;
    }
}tr1[N], tr2[N];
// Centroid-decomposition state:
//   rt   - best centroid candidate found by the current getrt run,
//   siz  - subtree size within the current DFS,
//   sum  - size assumed for the component being searched,
//   maxp - largest piece left if that vertex were removed,
//   vis  - vertex has already been used as a centroid,
//   fa   - parent in the centroid tree (0 for the top centroid).
int rt, siz[N], sum, maxp[N], vis[N], fa[N];
// DFS the unvisited component containing x (ffa = DFS parent). Fills siz[] and
// maxp[]. Moves rt to any vertex whose maxp is strictly smaller than the
// current best, so on ties the earlier candidate is kept.
void getrt(int x, int ffa) {
    int heaviest = 0;
    siz[x] = 1;
    for (int edge = p[x]; edge; edge = e[edge].nt) {
        const int to = e[edge].v;
        if (to == ffa || vis[to]) continue;
        getrt(to, x);
        siz[x] += siz[to];
        heaviest = max(heaviest, siz[to]);
    }
    // The part "above" x in this DFS is everything not in x's subtree.
    maxp[x] = max(heaviest, sum - siz[x]);
    if (maxp[x] < maxp[rt]) rt = x;
}
// Recursively build the centroid tree below centroid x. The global `sum` holds
// the size assumed for x's component; ffa is x's parent in the centroid tree
// (0 for the top centroid). Two Fenwick trees are allocated for x:
//   tr1[x] - weights keyed by distance from x (sum/2 + 1 distance slots);
//   tr2[x] - weights keyed by distance from fa[x] (sum + 1 distance slots).
// NOTE(review): `sum = siz[v]` reuses sizes from the previous getrt run, which
// was rooted elsewhere, so it can differ from the true component size. In the
// cases checked it only overestimates (the BITs are just larger than needed),
// but BIT::upd silently drops out-of-range distances — confirm it never
// underestimates.
void solve(int x, int ffa) {
    // Record the centroid-tree parent and mark x as a separator for later DFSes.
    fa[x] = ffa; vis[x] = 1;
    tr1[x].build(sum/ 2 + 1);
    tr2[x].build(sum + 1);int v;
    // For each component left around x: reset the search (rt = 0 with
    // maxp[0] = inf so any vertex beats it), find its centroid, attach it under x.
    for (int i = p[x]; i; i = e[i].nt) 
        if (!vis[v= e[i].v]) { sum = siz[v]; maxp[rt = 0] = inf; getrt(v, 0); solve(rt, x); }
}
// Add k to vertex x in every structure that covers it:
//   - tr1[x] at distance 0;
//   - for each centroid-tree ancestor f = fa[i]: tr1[f] at dist(f, x), and
//     tr2[i] at that same distance. que subtracts tr2[i] to remove i's own
//     component from f's count.
void upd(int x, int k) {
    tr1[x].upd(0, k);
    for (int i = x; fa[i]; i = fa[i]) {
        // Both updates use the same distance: compute it (one LCA query) once.
        const int d = dist(fa[i], x);
        tr1[fa[i]].upd(d, k);
        tr2[i].upd(d, k);
    }
}
// Total weight of vertices within distance k of x. Start with x's own
// component, then climb the centroid tree. At each ancestor `up` (reached from
// `cur`), add the vertices within k - dist(up, x) of `up` and subtract those in
// cur's component via tr2[cur], which avoids double counting. Ancestors farther
// than k from x are skipped, but the climb continues past them.
int que(int x, int k) {
    int total = tr1[x].sum(k);
    for (int cur = x, up = fa[x]; up; cur = up, up = fa[up]) {
        const int rest = k - dist(up, x);
        if (rest >= 0) total += tr1[up].sum(rest) - tr2[cur].sum(rest);
    }
    return total;
}
int lst = 0;
signed main() {
    n = rd();
    int m = rd();
    F(i, 1, n)w[i] = rd();
    F(i, 1, n - 1) {
        int x = rd(), y = rd();
        add(x, y); add(y, x);
    }
    Tr.build();
    sum = n; maxp[rt = 0] = inf; getrt(1, 0); solve(rt, 0);
    F(i, 1, n)upd(i, w[i]);
    while (m--) {
        int op = rd(), x = rd(), y = rd(); x ^= lst; y ^= lst;
        if (op)upd(x, y - w[x]), w[x] = y;
        else cout << (lst = que(x, y)) << '\n';
    }
    return 0;
}
posted @ 2023-09-09 17:08  ussumer  阅读(18)  评论(0)    收藏  举报