【bzoj2588】Spoj 10628. Count on a tree 离散化+主席树

题目描述

给定一棵N个节点的树,每个点有一个权值,对于M个询问(u,v,k),你需要回答u xor lastans和v这两个节点间第K小的点权。其中lastans是上一个询问的答案,初始为0,即第一个询问的u是明文。

输入

第一行两个整数N,M。
第二行有N个整数,其中第i个整数表示点i的权值。
后面N-1行每行两个整数(x,y),表示点x到点y有一条边。
最后M行每行两个整数(u,v,k),表示一组询问。

输出

M行,表示每个询问的答案。

样例输入

8 5
105 2 9 3 8 5 7 7
1 2
1 3
1 4
3 5
3 6
3 7
4 8
2 5 1
0 5 2
10 5 3
11 5 4
110 8 2

样例输出

2
8
9
105
7


题解

主席树+最近公共祖先

需要明确主席树的原理:线段树相加减。

那么A到B的路径就是 A到根的路径+B到根的路径-最近公共祖先到根的路径-最近公共祖先的父亲到根的路径。

可以直接在树上建立主席树,注意每棵树是从它父亲的树推来的。

然后查询即可。

注意最后一行千万不能有换行,否则无限PE!

#include <cstdio>
#include <algorithm>
#define N 100001
using namespace std;
struct data
{
    int num , rank;
}a[N];
int root[N] , lp[N << 5] , rp[N << 5] , sum[N << 5] , val[N] , top , tot;
int head[N] , to[N << 1] , next[N << 1] , cnt , fa[N] , bl[N] , deep[N] , si[N] , q[N] , tail;
bool cmp1(data a , data b)
{
    return a.num < b.num;
}
bool cmp2(data a , data b)
{
    return a.rank < b.rank;
}
void add(int x , int y)
{
    to[++cnt] = y;
    next[cnt] = head[x];
    head[x] = cnt;
}
void dfs1(int x)
{
    int i;
    si[x] = 1;
    for(i = head[x] ; i ; i = next[i])
    {
        if(to[i] != fa[x])
        {
            fa[to[i]] = x;
            deep[to[i]] = deep[x] + 1;
            dfs1(to[i]);
            si[x] += si[to[i]];
        }
    }
}
void dfs2(int x , int c)
{
    int i , k = 0;
    bl[x] = c;
    q[++tail] = x;
    for(i = head[x] ; i ; i = next[i])
        if(to[i] != fa[x] && si[to[i]] > si[k])
            k = to[i];
    if(k)
    {
        dfs2(k , c);
        for(i = head[x] ; i ; i = next[i])
            if(to[i] != fa[x] && to[i] != k)
                dfs2(to[i] , to[i]);
    }
}
int getlca(int x , int y)
{
    while(bl[x] != bl[y])
    {
        if(deep[bl[x]] < deep[bl[y]])
            swap(x , y);
        x = fa[bl[x]];
    }
    if(deep[x] < deep[y]) return x;
    return y;
}
void pushup(int x)
{
    sum[x] = sum[lp[x]] + sum[rp[x]];
}
void ins(int x , int &y , int l , int r , int p)
{
    y = ++tot;
    if(l == r)
    {
        sum[y] = sum[x] + 1;
        return;
    }
    int mid = (l + r) >> 1;
    if(p <= mid) rp[y] = rp[x] , ins(lp[x] , lp[y] , l , mid , p);
    else lp[y] = lp[x] , ins(rp[x] , rp[y] , mid + 1 , r , p);
    pushup(y);
}
int query(int a , int b , int c , int d , int l , int r , int p)
{
    if(l == r) return val[l];
    int mid = (l + r) >> 1;
    if(sum[lp[a]] + sum[lp[b]] - sum[lp[c]] - sum[lp[d]] >= p) return query(lp[a] , lp[b] , lp[c] , lp[d] , l , mid , p);
    else return query(rp[a] , rp[b] , rp[c] , rp[d] , mid + 1 , r , p - sum[lp[a]] - sum[lp[b]] + sum[lp[c]] + sum[lp[d]]);
}
int main()
{
    int n , m , i , x , y , z , f , last = 0;
    scanf("%d%d" , &n , &m);
    for(i = 1 ; i <= n ; i ++ )
    {
        scanf("%d" , &a[i].num);
        a[i].rank = i;
    }
    sort(a + 1 , a + n + 1 , cmp1);
    val[0] = 0x80000000;
    for(i = 1 ; i <= n ; i ++ )
    {
        if(a[i].num != val[top]) val[++top] = a[i].num;
        a[i].num = top;
    }
    sort(a + 1 , a + n + 1 , cmp2);
    for(i = 1 ; i < n ; i ++ )
    {
        scanf("%d%d" , &x , &y);
        add(x , y);
        add(y , x);
    }
    dfs1(1);
    dfs2(1 , 1);
    for(i = 1 ; i <= tail ; i ++ )
        ins(root[fa[q[i]]] , root[q[i]] , 1 , top , a[q[i]].num);
    while(m -- )
    {
        scanf("%d%d%d" , &x , &y , &z);
        x ^= last;
        f = getlca(x , y);
        last = query(root[x] , root[y] , root[f] , root[fa[f]] , 1 , top , z);
        printf("%d" , last);
        if(m) printf("\n");
    }
    return 0;
}
posted @ 2017-01-17 11:33  GXZlegend  阅读(319)  评论(0编辑  收藏  举报