树链剖分入门

hdu2586 How far away ?

题目链接

题目大意

  询问树上两点距离

解题思路

  以结点 1 为根 \(rt\),用树链剖分求 \(\mathrm{lca}(a,b)\),并在第一次 dfs 中预处理每个点到根的距离,则 \(dis(a,b) = dis(rt, a)+dis(rt, b)-2\times dis(rt, \mathrm{lca}(a, b))\)。

代码

// Capacity of the per-vertex arrays (vertex ids are used directly as indices).
const int maxn = 1e5+10;
// One directed edge of the adjacency list: target vertex, edge weight, and the
// index in e[] of the next edge leaving the same source (0 ends the list).
struct E {
    int to, w, nxt;
} e[maxn<<2];
// h[u] is the index in e[] of the edge most recently added for u (0 = no edge).
// tot is the number of edge slots in use; slot 0 is never used, so 0 can act
// as the list terminator.
int h[maxn], tot;
// Prepends the directed edge u -> v with weight w to u's edge list.
void add(int u, int v, int w) {
    const int idx = ++tot;
    e[idx].to = v;
    e[idx].w = w;
    e[idx].nxt = h[u]; // link to the previous head before replacing it
    h[u] = idx;
}
// Per-vertex data filled in by dfs1: depth, parent, subtree size, heavy child
// (0 = none) and the sum of edge weights on the path from the root.
int dep[maxn], fa[maxn], sz[maxn], son[maxn], dis[maxn];
// First pass of the decomposition.
//   u: vertex being visited
//   p: the vertex u was reached from; edges back to it are skipped
// For every child v it records dep, dis and fa, recurses, then adds sz[v] into
// sz[u] and keeps the child with the strictly largest subtree in son[u].
// son[u] is read before it is written, so it must be 0 when the pass starts.
void dfs1(int u, int p) {
    sz[u] = 1;
    int i = h[u];
    while (i) {
        const int v = e[i].to;
        if (v != p) {
            dep[v] = dep[u]+1;
            dis[v] = dis[u]+e[i].w;
            fa[v] = u;
            dfs1(v, u);
            sz[u] += sz[v];
            if (sz[son[u]] < sz[v]) son[u] = v;
        }
        i = e[i].nxt;
    }
}
// top[u] is the topmost vertex of the heavy chain that contains u.
int top[maxn];
// Second pass: assigns top[] for u and then for every vertex below it.
// A vertex that is its parent's heavy child inherits the parent's chain top;
// any other vertex starts a new chain at itself.
void dfs2(int u) {
    const int parent = fa[u];
    top[u] = (son[parent] == u) ? top[parent] : u;
    for (int i = h[u]; i != 0; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == parent) continue; // do not walk back up the tree
        dfs2(v);
    }
}
// Returns the lowest common ancestor of u and v, using top[], dep[] and fa[].
int lca(int u, int v) {
    for (;;) {
        const int tu = top[u];
        const int tv = top[v];
        if (tu == tv) break; // both vertices are now on the same heavy chain
        // Lift the vertex whose chain top is deeper to the parent of that top.
        if (dep[tu] > dep[tv]) u = fa[tu];
        else v = fa[tv];
    }
    // On a shared chain the shallower vertex is the answer.
    return dep[u] <= dep[v] ? u : v;
}
// n and m are read per test case in main: n bounds the edge-reading loop
// (n-1 edges) and m is the number of queries.
int n, m;
// Resets the state that must be clean before a test case is loaded:
// the edge counter, and the adjacency head and heavy child of indices 0..n.
void init() {
    tot = 0;
    for (int v = 0; v <= n; ++v) {
        h[v] = 0;
        son[v] = 0;
    }
}
int main(void) {
    int __; cin >> __;
    while(__--) {
        cin >> n >> m; 
        init();
        for (int i = 1, a, b, c; i<n; ++i) {
            scanf("%d%d%d", &a, &b, &c);
            add(a, b, c);
            add(b, a, c);
        }
        dfs1(1, 0);
        dfs2(1);
        while(m--) {
            int a, b; scanf("%d%d", &a, &b);
            printf("%d\n", dis[a]+dis[b]-dis[lca(a, b)]*2);
        }
    }   
    return 0;
}   

bzoj 1036 树的统计Count

题目链接

题目大意

  给定一棵带点权的树,支持三种操作:CHANGE u t 把结点 u 的权值改为 t;QMAX u v 询问 u 到 v 路径上的最大点权;QSUM u v 询问 u 到 v 路径上的点权和。

解题思路

  树链剖分之后,按第二次 dfs 得到的时间戳建线段树,维护区间最大值与区间和;路径询问拆成若干条重链上的区间查询,修改则是线段树单点修改。

代码

// Large sentinel; main negates it to seed the running maximum.
const int INF = 0x3f3f3f3f;
// Capacity of the per-vertex arrays (vertex ids are used directly as indices).
const int maxn = 1e5+10;
// One directed edge of the adjacency list: target vertex and the index in e[]
// of the next edge leaving the same source (0 ends the list).
struct E {
    int to, nxt;
} e[maxn<<2];
// h[u] is the index in e[] of the edge most recently added for u (0 = no edge).
// tot is the number of edge slots in use; slot 0 is never used.
int h[maxn], tot;
// Prepends the directed edge u -> v to u's edge list.
void add(int u, int v) {
    const int idx = ++tot;
    e[idx].to = v;
    e[idx].nxt = h[u]; // link to the previous head before replacing it
    h[u] = idx;
}
// Per-vertex data filled in by dfs1: depth, parent, subtree size and heavy
// child (0 = none).
int dep[maxn], fa[maxn], sz[maxn], son[maxn];
// First pass of the decomposition.
//   u: vertex being visited
//   p: the vertex u was reached from; edges back to it are skipped
// Records dep and fa for each child, recurses, accumulates subtree sizes and
// keeps the child with the strictly largest subtree in son[u].
void dfs1(int u, int p) {
    sz[u] = 1;
    for (int i = h[u]; i != 0; i = e[i].nxt) {
        const int child = e[i].to;
        if (child != p) {
            fa[child] = u;
            dep[child] = dep[u]+1;
            dfs1(child, u);
            sz[u] += sz[child];
            if (sz[son[u]] < sz[child]) son[u] = child;
        }
    }
}
// top[u]: topmost vertex of u's heavy chain. tim: running timestamp counter.
// id[u]: timestamp given to u. rev[t]: vertex that received timestamp t.
int top[maxn], tim, id[maxn], rev[maxn];
// Second pass: assigns chain tops and timestamps.
//   u: vertex being visited
//   t: top of the chain u belongs to
// The heavy child is visited first, so each heavy chain gets consecutive
// timestamps; every other child then starts a chain of its own.
void dfs2(int u, int t) {
    const int stamp = ++tim;
    top[u] = t;
    id[u] = stamp;  // timestamp of this vertex
    rev[stamp] = u; // vertex owning this timestamp
    const int heavy = son[u];
    if (heavy == 0) return; // no heavy child means no children at all
    dfs2(heavy, t);         // continue along the heavy chain
    for (int i = h[u]; i != 0; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == fa[u] || v == heavy) continue;
        dfs2(v, v);
    }
}
// Segment tree node: maximum and sum over the timestamp range it covers.
struct Node {
    int mx, sum;
} tree[maxn<<2];
// n: vertex count; w[v]: weight of vertex v; maxx / sum: accumulators that
// query() updates; m: number of operations.
int n, w[maxn], maxx, sum, m;
// Recomputes node rt from its two children (rt*2 and rt*2+1).
inline void push_up(int rt) {
    const Node &lc = tree[rt<<1];
    const Node &rc = tree[rt<<1|1];
    tree[rt].mx = lc.mx < rc.mx ? rc.mx : lc.mx;
    tree[rt].sum = lc.sum + rc.sum;
}
// Builds the segment tree node rt covering timestamps [l, r].
// A leaf at position l takes the weight of the vertex holding timestamp l.
void build(int rt, int l, int r) {
    if (l != r) {
        const int mid = (l+r)>>1;
        build(rt<<1, l, mid);
        build(rt<<1|1, mid+1, r);
        push_up(rt);
        return;
    }
    const int val = w[rev[l]];
    tree[rt].mx = val;
    tree[rt].sum = val;
}
// Folds the range [L, R] into the global accumulators: maxx becomes the larger
// of itself and the range maximum, and sum grows by the range sum.
//   rt, l, r: current node and the timestamp range it covers
// Nothing is returned; the caller is expected to seed maxx and sum beforehand.
void query(int rt, int l, int r, int L, int R) {
    if (L <= l && r <= R) {
        // Node lies fully inside the requested range.
        if (maxx < tree[rt].mx) maxx = tree[rt].mx;
        sum += tree[rt].sum;
        return;
    }
    const int mid = (l+r)>>1;
    if (L <= mid) query(rt<<1, l, mid, L, R);
    if (mid < R) query(rt<<1|1, mid+1, r, L, R);
}
// Accumulates the maximum and the sum of the weights on the tree path between
// u and v into the globals maxx and sum, by issuing one range query per heavy
// chain segment of the path.
void ask(int u, int v) {
    for (;;) {
        const int tu = top[u];
        const int tv = top[v];
        if (tu == tv) break; // both ends are on the same heavy chain
        // Consume the chain whose top is deeper, then jump above that top.
        if (dep[tu] < dep[tv]) {
            query(1, 1, n, id[tv], id[v]);
            v = fa[tv];
        } else {
            query(1, 1, n, id[tu], id[u]);
            u = fa[tu];
        }
    }
    // On a shared chain the shallower vertex has the smaller timestamp.
    const int upper = dep[u] > dep[v] ? v : u;
    const int lower = dep[u] > dep[v] ? u : v;
    query(1, 1, n, id[upper], id[lower]);
}
// Sets the leaf at timestamp pos to val and refreshes every node on the way
// back up to rt.
//   rt, l, r: current node and the timestamp range it covers
void update(int rt, int l, int r, int pos, int val) {
    if (l == r) {
        tree[rt].mx = val;
        tree[rt].sum = val;
        return;
    }
    const int mid = (l+r)>>1;
    if (mid < pos) update(rt<<1|1, mid+1, r, pos, val);
    else update(rt<<1, l, mid, pos, val);
    push_up(rt);
}
char str[11];
int main(void) {
    cin >> n;
    for (int i = 1, a, b; i<n; ++i) {
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }
    for (int i = 1; i<=n; ++i) scanf("%d", &w[i]);
    dep[1] = 1;
    dfs1(1, 0);
    dfs2(1, 1);
    build(1, 1, n);
    cin >> m;
    for (int i = 1, a, b; i<=m; ++i) {
        scanf("%s", str);
        scanf("%d%d", &a, &b);
        if (str[0]=='C') update(1, 1, n, id[a], b);
        else {
            sum = 0;
            maxx = -INF;
            ask(a, b);
            if (str[1]=='M') printf("%d\n", maxx);
            else printf("%d\n", sum);
        }
    }
    return 0;
}   
posted @ 2021-05-12 21:18  shuitiangong  阅读(63)  评论(0)    收藏  举报