2020牛客暑期多校(七) C - A National Pandemic(树链剖分)

2020牛客暑期多校(七) C - A National Pandemic(树链剖分)

参考博客

题意:

一棵树支持3种操作:

  • 1 x w, 给x点加w,其它点y加 \(w-dist(x, y)\).
  • 2 x, 将x权值变为$min(0, f(x)) $;
  • 3 x, 查询x的权值\(f(x)\)

分析:

先推荐一个题单: 树链剖分练习题 如果没有学过树链剖分可以做一下。

首先2, 3操作用树链剖分处理都很直接,主要看1操作。给一个点x加w也还好处理,但给其他点加\(w - dist(x, y)\) 怎么加,难道要枚举点吗?显然点那么多会T。所以要处理这个操作可以观察这个式子可以写成 \(w - dist(1, x) - dist(1, y) + 2*dist(1, lca(x,y))\) 理解见下图,紫色是dist(1,x),绿色是dist(1, y) ,黄色是dist(1, lca(x,y))。

\(w - dist(1, x) - dist(1, y) + 2*dist(1, lca(x,y))​\)

观察式子可以看到\(w-dist(1,x)\)\(dist(1,y)\) 都可以用变量去累计,因为对一个查询3操作,它前面的1操作时的\(w-dist(1,x)\) 你可以累计下来,然后减\(dist(1, y)\) 的个数就是前面1操作的个数,也可以用一个变量allnum记录树量。

所以重点在处理\(dist(1,lca(x,y))\) 我们发现当查询一个点y时只要找到1到 y 路径上所以以前1操作标记的$lca(x,y) $点 ,求和这些点到 1 的距离即可,但这很麻烦不好处理。但是它是lca点到1的距离,所以我们可以在1处理时对1到x每个点权值+1,比如上图中处理x时,我把紫线上所有点+1,那么当处理2时我想要加的是1到lca的距离,可以发现此时1到lca的权值和就是1到lca的距离,这里用了差分的一个思想。当我们有很多x时,它们会在1到y条路径上1到某个点之间权值都加1,其实这个点就是lca,这个很好理解。所以我们只要用线段树维护权值和即可。但我们观察式子要2*dist(1,lca(x,y)).这只需要对每个1操作的x给线段树1到x之间的点+2即可。

代码:

#include<bits/stdc++.h>
using namespace std;
#define rep(i, a, n) for(int i = a; i <= n; ++ i);
#define per(i, a, n) for(int i = n; i >= a; -- i);
typedef long long ll;
const int N = 50010;
const ll mod = 1e9 + 7;
const double Pi = acos(- 1.0);
const int INF = 0x3f3f3f3f;
const int G = 3, Gi = 332748118;
ll qpow(ll a, ll b) { ll res = 1; while(b){ if(b & 1) res = (res * a) % mod; a = (a * a) % mod; b >>= 1;} return res; }
ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }
ll lcm(ll a, ll b) { return a * b / gcd(a, b);}
bool cmp(int a, int b){ return a > b;}
//

int T, n, m;
int head[N << 1], cnt = 0;
struct node{
    int to, nxt;
}edge[N * 4];

struct Tree{
    int l, r; int val, lz;
}tree[N * 4];
int del[N]; 
int son[N], dfn[N], dep[N], top[N], fa[N], siz[N];
int tot;

void add(int u, int v){
    edge[cnt].to = v, edge[cnt].nxt = head[u], head[u] = cnt ++;
    edge[cnt].to = u, edge[cnt].nxt = head[v], head[v] = cnt ++;
}

void pushdown(int index){
    if(tree[index].lz){
        int temp = tree[index].lz;
        tree[index].lz = 0;
        tree[index << 1].val += (tree[index << 1].r - tree[index << 1].l + 1) * temp;
        tree[index << 1 | 1].val += (tree[index << 1 | 1].r - tree[index << 1 | 1].l + 1) * temp;
        tree[index << 1].lz += temp;
        tree[index << 1 | 1].lz += temp;
    }
}

void Build(int l, int r, int index){
    tree[index].l = l, tree[index].r = r;
    tree[index].lz = 0;
    if(l == r){
        tree[index].val = 0;
        return;
    }
    int mid = (tree[index].l + tree[index].r) >> 1;
    Build(l, mid, index << 1);
    Build(mid + 1, r, index << 1 | 1);
    tree[index].val = tree[index << 1].val + tree[index << 1 | 1].val;
}

void updata(int l, int r, int index, int val){
    if(tree[index].l >= l &&  tree[index].r <= r){
        tree[index].lz += val;
        tree[index].val += val * (tree[index].r - tree[index].l + 1);
        return;
    }
    if(tree[index].lz)  pushdown(index);
    int mid = (tree[index].l + tree[index].r) >> 1;
    if(l <= mid) updata(l, r, index << 1, val);
    if(r > mid) updata(l, r, index << 1 | 1, val);
    tree[index].val = tree[index << 1].val + tree[index << 1 | 1].val;
}

int query(int l, int r, int index){
    if(l <= tree[index].l && tree[index].r <= r){
        return tree[index].val;
    }
    if(tree[index].lz) pushdown(index);
    int mid = (tree[index].l + tree[index].r) >> 1;
    int ans = 0;
    if(l <= mid) ans += query(l, r, index << 1);
    if(r > mid) ans += query(l, r, index << 1 | 1);
    return ans;
}
// -------------------------------------

void Csol(int x, int y){
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        updata(dfn[top[x]], dfn[x], 1, 2);
        x = fa[top[x]];
    }
    if(dep[x] > dep[y]) swap(x, y);
    updata(dfn[x], dfn[y], 1, 2);
}

int Qsol(int x, int y){
    int ans = 0;
    while(top[x] != top[y]){
        if(dep[top[x]] < dep[top[y]]) swap(x, y);
        ans += query(dfn[top[x]], dfn[x], 1);
        x = fa[top[x]];
    }
    if(dep[x] > dep[y]) swap(x ,y);
    ans += query(dfn[x], dfn[y], 1);
    return ans;
}

void dfs1(int u, int pre){
    dep[u] = dep[pre] + 1;
    fa[u] = pre;
    siz[u] = 1;
    int maxx = -1;
    for(int i = head[u]; i != -1; i = edge[i].nxt){
        int v = edge[i].to;
        if(v == pre) continue;
        dfs1(v, u);
        siz[u] += siz[v];
        if(siz[v] > maxx){
            maxx = siz[v];
            son[u] = v;
        }
    }
}

void dfs2(int u, int topu){ //topu当前链的最顶端的节点
    dfn[u] = ++ tot;
    top[u] = topu;
    if(!son[u]) return;
    dfs2(son[u], topu);
    for(int i = head[u]; i != -1; i = edge[i].nxt){
        int v = edge[i].to;
        if(v == son[u] || v == fa[u]) continue;
        dfs2(v, v);
    }
}

int main()
{
    scanf("%d",&T);
    while(T --){
        scanf("%d%d",&n,&m);
        cnt = 0; tot = 0;
        for(int i = 1; i <= n; ++ i){
            head[i] = -1; del[i] = 0;
            son[i] = 0;
        }
        int x, y; 
        for(int i = 1; i < n; ++ i){
            scanf("%d%d",&x,&y);
            add(x, y);
        }
        dep[0] = 0;
        dfs1(1, 0);
        dfs2(1, 1);
        Build(1, n, 1);
        
        int op;
        int wval = 0, allnum = 0;
        while(m --){
            scanf("%d",&op);
            if(op == 1){
                scanf("%d%d",&x,&y);
                Csol(x, 1);
                wval = wval + y - dep[x];
                allnum ++;
            }
            else if(op == 2){
                scanf("%d",&x);
                int res = Qsol(x, 1) + wval - allnum * dep[x];
                if(res > del[x]) del[x] = res;
            }
            else{
                scanf("%d",&x);
                int res = Qsol(x, 1) + wval - allnum * dep[x] - del[x];
                printf("%d\n",res);
            }
        }
    }
    return 0;
}
posted @ 2020-08-19 20:37  A_sc  阅读(113)  评论(0编辑  收藏  举报