学习笔记::树上莫队

王室联邦:树分块,参见popoqqq大神的博客,讲得很详细

莫队:小z的袜子

树上莫队:把前面两个东西结合在一起,不要管什么xor,就是写一个solve,走过的路径赋成走过,因为lca没走过,所以没计算过,加进去,计算后再减去,因为lca最终是不需要的

苹果树(不知道对不对)

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
#define N 200010
struct edge
{
    int to,nxt;
}e[N];
struct data
{
    int u,v,a,b,id;
}q[N];
int n,m,tot,cnt=1,Time,ans,top,size;
int dfn[N],belong[N],head[N],dep[N],used[N],c[N],p[N],s[N];
int fa[30][N],answer[N];
void link(int u,int v)
{
    e[++cnt].nxt=head[u];
    head[u]=cnt;
    e[cnt].to=v;
}
bool cp(data x,data y)
{
    if(belong[x.u]!=belong[x.v]) return belong[x.u]<belong[x.v];
    return dfn[x.u]<dfn[x.v];
}
void reverse(int u)
{
    if(!used[u]) 
    {
        used[u]=1; p[c[u]]++; if(p[c[u]]==1) ans++;
    }
    else
    {
        used[u]=0; p[c[u]]--; if(!p[c[u]]) ans--;
    }
}
void dfs(int u,int last)
{
    int bottom=top+1; dfn[u]=++Time;
    for(int i=head[u];i;i=e[i].nxt) if(e[i].to!=last)
    {
        int v=e[i].to;
        dep[v]=dep[u]+1; fa[0][v]=u;
        dfs(v,u);        
        if(top-bottom+1>=size)
        {
            ++tot; ++top;
            while(top>=bottom) belong[s[--top]]=tot;
        }
    }
    s[++top]=u;
}
void solve(int u,int v)
{
    while(u!=v) 
        if(dep[u]>dep[v])
        {
            reverse(u); u=fa[0][u];
        }  
        else
        {
            reverse(v); v=fa[0][v];
        } 
}
void init()
{
    for(int i=1;i<=22;i++)
        for(int j=1;j<=n;j++) if(fa[i-1][j]!=-1) fa[i][j]=fa[i-1][fa[i-1][j]];
}
int lca(int u,int v)
{
    if(dep[u]<dep[v]) swap(u,v);
    int temp=dep[u]-dep[v];
    for(int i=22;i>=0;i--) 
        if(temp&(1<<i)) u=fa[i][u];
    if(u==v) return u;
    for(int i=22;i>=0;i--)
    {
        if(fa[i][u]!=fa[i][v]) 
        {
            u=fa[i][u];
            v=fa[i][v];
        }
    }
    return fa[0][u];
}
int main()
{
    memset(fa,-1,sizeof(fa));
    scanf("%d%d",&n,&m);
    size=(int)(sqrt(n));
    int root=1;
    for(int i=1;i<=n;i++) scanf("%d",&c[i]);
    for(int i=1;i<=n;i++)
    {
        int u,v; scanf("%d%d",&u,&v);
        if(!u) root=v; else if(!v) root=u;
        else 
        {
            link(u,v); link(v,u);
        }
    }
    dfs(root,0);
    init();
    for(int i=1;i<=m;i++)
    {
        scanf("%d%d%d%d",&q[i].u,&q[i].v,&q[i].a,&q[i].b);
        q[i].id=i;
    }    
    sort(q+1,q+m+1,cp);
    solve(q[1].u,q[1].v);
    int x=lca(q[1].u,q[1].v);
    reverse(x);
    answer[q[1].id]=ans;
    if(p[q[1].a]&&p[q[1].b]&&q[1].a!=q[1].b) answer[q[1].id]--;
    reverse(x);
    for(int i=2;i<=m;i++)
    {
        solve(q[i-1].u,q[i].u);
        solve(q[i-1].v,q[i].v);
        int x=lca(q[i].u,q[i].v);
        reverse(x);
        answer[q[i].id]=ans;
        if(p[q[i].a]&&p[q[i].b]&&q[i].a!=q[i].b) answer[q[i].id]--;
        reverse(x);
    }
    for(int i=1;i<=m;i++) printf("%d\n",answer[i]);
    return 0;
}
View Code

 

posted @ 2017-01-25 07:31  19992147  阅读(193)  评论(0编辑  收藏  举报