BZOJ 1036: [ZJOI2008]树的统计Count

题意:

一棵树上有n个节点,编号分别为1到n,每个节点都有一个权值w。我们将以下面的形式来要求你对这棵树完成一些操作: I. CHANGE u t : 把结点u的权值改为t II. QMAX u v: 询问从点u到点v的路径上的节点的最大权值 III. QSUM u v: 询问从点u到点v的路径上的节点的权值和 注意:从点u到点v的路径上的节点包括u和v本身

题解:

这是树链剖分的模板题,不过我树链剖分写挂了T_T,只有抄网上的版。。。

代码:

来源:http://coraon.com/zjoi-2008/

#include <cstdio>
#include <cstring>
#include <vector>
#include <algorithm>
#define MAXN 30001
#define INF 0x3f3f3f3f
#define lchild rt << 1, l, m
#define rchild rt << 1 | 1, m + 1, r
using namespace std;
int n, w[MAXN], mw[MAXN];
vector<int>e[MAXN];
 
class Segment_Tree{
private:
    int sum[MAXN << 2], upper[MAXN << 2];
    void push_up(int rt){
        sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
        upper[rt] = max(upper[rt << 1], upper[rt << 1 | 1]);
    }
public:
    void build(int rt = 1, int l = 1, int r = n){
        if(l == r){ sum[rt] = upper[rt] = mw[l]; return; }
        sum[rt] = 0; upper[rt] = -INF;
        int m = (l + r) >> 1;
        build(lchild); build(rchild);
        push_up(rt);
    }
    void update(int P, int val, int rt = 1, int l = 1, int r = n){
        if(l == r) { sum[rt] = upper[rt] = val; return; }
        int m = (l + r) >> 1;
        if(P <= m) update(P, val, lchild);
        else update(P, val, rchild);
        push_up(rt);
    }
    int query(int L, int R, bool opt, int rt = 1, int l = 1, int r = n){
        if(L <= l && r <= R){
            if(opt) return upper[rt];
            else return sum[rt];
        }
        int m = (l + r) >> 1;
        if(opt){
            int lans = -INF, rans = -INF;
            if(L <= m) lans = query(L, R, opt, lchild);
            if(R > m) rans = query(L, R, opt, rchild);
            return max(lans, rans);
        }
        else{
            if(L > m) return query(L, R, opt, rchild);
            else if(R <= m) return query(L, R, opt, lchild);
            else return query(L, m, opt, lchild) + query(m + 1, R, opt, rchild);
        }
    }
};
 
class HLD: public Segment_Tree{
public:
    int dep[MAXN], fa[MAXN], sz[MAXN];
    int son[MAXN], top[MAXN], dfn[MAXN], dfs_clock;
 
    void init(){
        memset(dep, 0, sizeof(dep));
        memset(son, 0, sizeof(son));
        dep[1] = 1;
        dfs_clock = 0;
    }
 
    void dfs1(int u){
        sz[u] = 1;
        for(int i = 0; i < e[u].size(); i++){
            int v = e[u][i];
            if(dep[v]) continue;
            dep[v] = dep[u] + 1;
            fa[v] = u;
            dfs1(v);
            sz[u] += sz[v];
            if(sz[son[u]] < sz[v])
                son[u] = v;
        }
    }
 
    void dfs2(int u, int tp){
        top[u] = tp;
        dfn[u] = ++dfs_clock;
        mw[dfn[u]] = w[u];
        if(son[u]) dfs2(son[u], tp); //拉链
        for(int i = 0; i < e[u].size(); i++){
            int v = e[u][i];
            if(v == fa[u] || v == son[u]) continue;
            dfs2(v, v); //建链
        }
    }
 
    int getsum(int u, int v){
        int ans = 0;
        while(top[u] != top[v]){ //一直爬直到在u, v同一条重链
            if(dep[top[u]] > dep[top[v]]) swap(u, v);
            ans += query(dfn[top[v]], dfn[v], 0);
            v = fa[top[v]];
        }
        if(dep[u] > dep[v]) swap(u, v);
        ans += query(dfn[u], dfn[v], 0); //属于同一条重链的时候直接区间询问
        return ans;
    }
 
    int getmax(int u, int v){
        int ans = -INF;
        while(top[u] != top[v]){
            if(dep[top[u]] > dep[top[v]]) swap(u, v);
            ans = max(ans, query(dfn[top[v]], dfn[v], 1));
            v = fa[top[v]];
        }
        if(dep[u] > dep[v]) swap(u, v);
        ans = max(ans, query(dfn[u], dfn[v], 1));
        return ans;
    }
}hld;
 
int main(){
#ifdef _DEBUG
    freopen("d:\\2008.txt", "r", stdin);
#endif
    char opt[10];
    int u, v, m;
    while(scanf("%d", &n) != EOF){
        for(int i = 1; i <= n; i++)
            e[i].clear();
        hld.init();
        for(int i = 1; i < n; i++){
            scanf("%d %d", &u, &v);
            e[u].push_back(v);
            e[v].push_back(u);
        }
        for(int i = 1; i <= n; i++)
            scanf("%d", w + i);
        hld.dfs1(1);
        hld.dfs2(1, 1);
        hld.build();
        scanf("%d", &m);
        for(int i = 0; i < m; i++){
            scanf("%s %d %d", opt, &u, &v);
            if(opt[0] == 'C')
                hld.update(hld.dfn[u], v);
            else if(opt[1] == 'M')
                printf("%d\n", hld.getmax(u, v));
            else
                printf("%d\n", hld.getsum(u, v));
        }
    }
    return 0;
}

 

 

 

posted @ 2015-09-22 22:07  好地方bug  阅读(156)  评论(0编辑  收藏  举报