【BZOJ1146】【CTSC2008】网络管理

题意:给定一棵带点权的树,支持两种操作:修改某个点的权值;询问两点间路径上第 k 大的点权(若路径上的点不足 k 个,输出 "invalid request!")。

做法:树链剖分 + 树状数组套动态开点权值线段树(常被称作"树状数组套主席树",但这里每棵线段树都是原地修改的,并不可持久化)。TuT 调试了一整天,遇到了各种各样莫名其妙的错误 2333。

传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=1146

#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <algorithm>
// Capacity of every array indexed by node, dfn position, operation, or value:
// da[]/Map[] receive n values plus one per modification, and E[] receives
// 2(n-1) directed edges.
#define MaxN 200010
// Capacity of the shared segment-tree node pool sum[]/ls[]/rs[]. Node 0 is
// never allocated (S is pre-incremented), so it acts as the empty tree with
// count 0.
// NOTE(review): nodes are never freed, so the pool only grows with every
// insertion; confirm this bound holds for worst-case inputs.
#define MaxM 8000010
using namespace std;
// n: number of nodes, m: number of operations, tot: entries used in da[],
// N: number of distinct values in Map[], S: last allocated segment-tree node,
// cnt: dfs timestamp counter for dfn[], len1/len2: number of roots currently
// collected in t1[]/t2[], ss: last used edge index in E[].
int n, m, tot = 0, N = 0, S = 0, cnt = 0, len1, len2, ss = 0;
// num[i]: current value of node i; da[]: every value that ever occurs (sorted
// in Read_Data); Map[1..N]: the sorted distinct values; A/B/C[i]: the three
// integers of operation i (A == 0: set node B's value to C; otherwise ask for
// the A-th largest value on the path B-C).
int num[MaxN], da[MaxN], Map[MaxN], A[MaxN], B[MaxN], C[MaxN];
// Heavy-light decomposition data: subtree size, depth, dfs order, heavy child,
// parent, top of the heavy chain; head[] is the adjacency-list head.
int sz[MaxN], depth[MaxN], dfn[MaxN], son[MaxN], fa[MaxN], top[MaxN], head[MaxN];
// root[j]: segment tree (over value indices 1..N) of Fenwick cell j, which
// covers dfn positions (j - lowbit(j), j]; t1[]/t2[]: roots whose counts are
// subtracted/added during a path query.
int root[MaxN], t1[MaxN], t2[MaxN];
// Segment-tree node pool: value count, left child, right child.
int sum[MaxM], ls[MaxM], rs[MaxM];
// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters in front of it (a leading '-' is skipped like any other
// non-digit; inputs are assumed non-negative).
// Returns 0 if end-of-file is reached before any digit is seen, instead of
// spinning forever waiting for a digit that will never come.
int read(){
    int ret = 0;
    int c = getchar();   // int, not char: EOF must stay distinguishable
    while (c != EOF && (c < '0' || c > '9')) c = getchar();
    while (c >= '0' && c <= '9'){
        ret = ret * 10 + (c - '0');
        c = getchar();
    }
    return ret;
}

// One directed edge of the adjacency list; E[1..ss] stores all of them.
struct rec{
    int nxt;   // index of the next edge leaving the same node (0 ends the list)
    int v;     // node this edge points to
}E[MaxN];

// Maps a raw value x to its 1-based index in the sorted, de-duplicated array
// Map[1..N]: the position of the first element >= x.
// Only Map[1..N-1] is actually searched; when none of those is >= x the answer
// is N, so Map[N] itself never needs to be compared. std::max keeps the range
// well-formed (and the answer 1) when N is 0.
int find(int x){
    return std::lower_bound(Map + 1, Map + std::max(N, 1), x) - Map;
}

// Dynamic segment tree over compressed value indices [1, N]

// Adds c to the count of value index k in the segment tree whose root slot is
// x and which covers value indices [l, r].
// Walks from the root to leaf k. At each level the slot is given a fresh node
// from the pool if it is empty (S is the last used index). The visited node
// then copies its children and count from the matching node of tree fa and
// adds c to the count. Every call site in this file passes fa equal to the
// current value of x, so the update is effectively in place.
void updata(int &x, int fa, int l, int r, int k, int c){
    int *slot = &x;   // child pointer being filled / visited at this level
    int src = fa;     // matching node in the source tree (0 = empty)
    for (;;){
        if (*slot == 0) *slot = ++S;
        const int cur = *slot;
        ls[cur] = ls[src];
        rs[cur] = rs[src];
        sum[cur] = sum[src] + c;
        if (l == r) break;
        const int mid = (l + r) >> 1;
        if (k <= mid){
            slot = &ls[cur];
            src = ls[src];
            r = mid;
        } else {
            slot = &rs[cur];
            src = rs[src];
            l = mid + 1;
        }
    }
}

// Binary indexed tree (Fenwick tree) helper

// Value of the lowest set bit of x (two's-complement x & -x, spelled as
// x & (~x + 1)); used as the Fenwick-tree step size at index x.
int lowbit(int x){
    return x & (~x + 1);
}

// Heavy-light decomposition

// Appends the directed edge a -> b to the front of a's adjacency list.
void adde(int a, int b){
    ++ss;
    E[ss].nxt = head[a];
    E[ss].v = b;
    head[a] = ss;
}

// First heavy-light pass, run recursively on the subtree of u; fa[u] and
// depth[u] must already be set.
// Fills sz[] and, for every descendant, fa[] and depth[]. son[u] becomes the
// child with the largest subtree, keeping the first such child in adjacency
// order on ties. The recursion depth equals the height of the tree.
void dfs1(int u){
    sz[u] = 1;
    int heaviest = 0;   // largest child subtree size seen so far
    for (int e = head[u]; e != 0; e = E[e].nxt){
        const int w = E[e].v;
        if (w != fa[u]){
            fa[w] = u;
            depth[w] = depth[u] + 1;
            dfs1(w);
            sz[u] += sz[w];
            if (heaviest < sz[w]){
                heaviest = sz[w];
                son[u] = w;
            }
        }
    }
}

// Second heavy-light pass: assigns dfs order dfn[] and chain tops top[] over
// the subtree of u, which lies on the heavy chain starting at Top.
// The heavy child is visited first, so every heavy chain occupies consecutive
// dfn positions, and each remaining (light) child starts its own chain.
// Calling with u == 0 (a node without a heavy child) does nothing.
void dfs2(int u, int Top){
    if (u == 0) return;
    dfn[u] = ++cnt;
    top[u] = Top;
    dfs2(son[u], Top);
    for (int e = head[u]; e; e = E[e].nxt){
        const int w = E[e].v;
        const bool light = (w != son[u]) && (w != fa[u]);
        if (light) dfs2(w, w);
    }
}

// Appends to buf[len+1..] the roots of the Fenwick cells whose ranges together
// cover dfn prefix [1, p], advancing len; p <= 0 appends nothing.
static void collect_prefix(int p, int *buf, int &len){
    while (p > 0){
        buf[++len] = root[p];
        p -= lowbit(p);
    }
}

// Gathers the segment trees that describe the values on the tree path u-v.
// Each heavy-chain segment [a, b] of the path becomes prefix(b) minus
// prefix(a - 1). The roots of the subtracted prefixes are appended to t1 and
// those of the added prefixes to t2, after the current len1/len2 entries.
// Returns the number of nodes on the path.
int query(int u, int v){
    int nodes = 0;
    for (; top[u] != top[v]; v = fa[top[v]]){
        // Always lift the endpoint whose chain top is deeper.
        if (depth[top[u]] > depth[top[v]]) std::swap(u, v);
        collect_prefix(dfn[top[v]] - 1, t1, len1);
        collect_prefix(dfn[v], t2, len2);
        nodes += depth[v] - depth[top[v]] + 1;
    }
    // Both endpoints are now on one chain; u becomes the shallower one.
    if (depth[u] > depth[v]) std::swap(u, v);
    nodes += depth[v] - depth[u] + 1;
    collect_prefix(dfn[u] - 1, t1, len1);
    collect_prefix(dfn[v], t2, len2);
    return nodes;
}

// Input handling, k-th value search, and the driver

// Reads the whole input: n and m, the initial value of every node, the n-1
// tree edges (stored in both directions), and the m operations (A, B, C).
// Every value that can ever appear is collected into da[], sorted, and
// de-duplicated into Map[1..N]. That covers the initial values plus the C of
// each A == 0 operation.
void Read_Data(){
    n = read();
    m = read();
    for (int node = 1; node <= n; node++){
        num[node] = read();
        da[++tot] = num[node];
    }
    for (int e = 1; e < n; e++){
        int a = read();
        int b = read();
        adde(a, b);
        adde(b, a);
    }
    for (int q = 1; q <= m; q++){
        A[q] = read();
        B[q] = read();
        C[q] = read();
        if (A[q] == 0) da[++tot] = C[q];   // a modification brings in a new value
    }
    std::sort(da + 1, da + 1 + tot);
    // da is sorted, so a value is new exactly when it differs from the last
    // value kept in Map.
    Map[++N] = da[1];
    for (int i = 2; i <= tot; i++)
        if (da[i] != Map[N]) Map[++N] = da[i];
}

int query(int k){
    int l = 1, r = N;
    while (l < r){
        int mid = (l+r) >> 1, t = 0;
        for (int i = 1; i <= len1; i++) t -= sum[ls[t1[i]]];
        for (int i = 1; i <= len2; i++) t += sum[ls[t2[i]]];
    //    printf("l:%d r:%d sum:%d k:%d\n", l, r, t, k);
        if (k <= t){
            r = mid;
            for (int i = 1; i <= len1; i++) t1[i] = ls[t1[i]];
            for (int i = 1; i <= len2; i++) t2[i] = ls[t2[i]];
        }
        else {
            k -= t; l = mid+1;
            for (int i = 1; i <= len1; i++) t1[i] = rs[t1[i]];
            for (int i = 1; i <= len2; i++) t2[i] = rs[t2[i]];
        }
    }
    return Map[l];
}

// Adds c to the count of value index t in every Fenwick cell that covers dfn
// position pos.
static void bit_add(int pos, int t, int c){
    for (int j = pos; j <= n; j += lowbit(j)) updata(root[j], root[j], 1, N, t, c);
}

// Builds the heavy-light decomposition rooted at node 1, inserts every
// node's initial value, then processes the operations in order:
//   A == 0 : node B gets value C (remove the old value, insert the new one);
//   A  > 0 : print the A-th largest value on the path B-C, or
//            "invalid request!" when the path has fewer than A nodes.
void Solve(){
    dfs1(1);
    dfs2(1, 1);
    for (int i = 1; i <= n; i++) bit_add(dfn[i], find(num[i]), 1);
    for (int i = 1; i <= m; i++){
        if (!A[i]){
            bit_add(dfn[B[i]], find(num[B[i]]), -1);
            num[B[i]] = C[i];
            bit_add(dfn[B[i]], find(C[i]), 1);
        }
        else {
            len1 = len2 = 0;   // start a fresh root collection for this path
            const int k = A[i];
            // Named so it does not shadow the global segment-tree array sum[].
            const int pathSize = query(B[i], C[i]);
            if (pathSize < k) printf("invalid request!\n");
            // k-th largest == (pathSize - k + 1)-th smallest
            else printf("%d\n", query(pathSize - k + 1));
        }
    }
}

// Entry point: read the entire input, then build the data structures and
// answer every operation in order. Falling off the end of main returns 0.
int main(){
    Read_Data();
    Solve();
}
View Code

 

posted @ 2016-03-09 19:45  Lukaluka  阅读(355)  评论(0编辑  收藏  举报