bzoj2588: Spoj 10628. Count on a tree(树上第k大)(主席树)

  每个节点继承父节点的树,则答案为query(root[x]+root[y]-root[lca(x,y)]-root[fa[lca(x,y)]])

#include<iostream>
#include<cstring>
#include<cstdlib>
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=100010;
struct poi{int size,lt,rt;}tree[maxn*20];
struct zs{int too,pre;}e[maxn<<1];
int n,m,x,y,z,sz,tot,N;
int d[maxn],f[maxn][20],last[maxn],root[maxn],v[maxn],b[maxn],num[maxn];
inline void read(int &k)
{
    int f=1;k=0;char c=getchar();
    while(c<'0'||c>'9')c=='-'&&(f=-1),c=getchar();
    while(c<='9'&&c>='0')k=k*10+c-'0',c=getchar();
    k*=f;
}
inline void add(int x,int y){e[++tot].too=y;e[tot].pre=last[x];last[x]=tot;}
void insert(int &x,int l,int r,int cx)
{
    tree[++sz]=tree[x];tree[sz].size++;x=sz;
    if(l==r)return;
    int mid=(l+r)>>1;
    if(cx<=mid)insert(tree[x].lt,l,mid,cx);
    else insert(tree[x].rt,mid+1,r,cx);
}
int query(int a,int b,int c,int d,int l,int r,int k)
{
    if(l==r)return l;
    int mid=(l+r)>>1;
    int t1=tree[a].lt,t2=tree[b].lt,t3=tree[c].lt,t4=tree[d].lt;
    int tmp=tree[t1].size+tree[t2].size-tree[t3].size-tree[t4].size;
    if(tmp>=k)return query(t1,t2,t3,t4,l,mid,k);
    return query(tree[a].rt,tree[b].rt,tree[c].rt,tree[d].rt,mid+1,r,k-tmp);
}
void dfs(int x,int fa)
{
    root[x]=root[fa];insert(root[x],1,N,v[x]);
    d[x]=d[fa]+1;f[x][0]=fa;
    for(int i=last[x];i;i=e[i].pre)
    if(e[i].too!=fa)dfs(e[i].too,x);
}
inline int lca(int x,int y)
{
    if(d[x]<d[y])swap(x,y);
    for(int i=19;i>=0;i--)
    if(d[f[x][i]]>=d[y])x=f[x][i];
    if(x==y)return x;
    for(int i=19;i>=0;i--)
    if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];
    return f[x][0];
}
int main()
{
    read(n);read(m);
    for(int i=1;i<=n;i++)read(v[i]),b[i]=v[i];N=n;
    sort(b+1,b+1+N);N=unique(b+1,b+1+N)-b-1;
    for(int i=1;i<=n;i++)x=lower_bound(b+1,b+1+N,v[i])-b,num[x]=v[i],v[i]=x;
    for(int i=1;i<n;i++)
    read(x),read(y),add(x,y),add(y,x);
    dfs(1,0);for(int j=1;j<20;j++)for(int i=1;i<=n;i++)f[i][j]=f[f[i][j-1]][j-1];
    int lan=0;
    for(int i=1;i<=m;i++)
    {
        read(x);read(y);read(z);x^=lan;int fq=lca(x,y);
        printf("%d",lan=num[query(root[x],root[y],root[fq],root[f[fq][0]],1,N,z)]);
        if(i!=m)puts("");
    }
}
View Code

 

posted @ 2017-09-25 19:28  Sakits  阅读(135)  评论(0编辑  收藏  举报