SPOJ COT2 Count on a tree II (树上莫队)

题目链接:http://www.spoj.com/problems/COT2/

参考博客:http://www.cnblogs.com/xcw0754/p/4763804.html
上面这个人推导部分写的简洁明了,方便理解,但是最后的分情况讨论有些迷,感觉是不必要的,更简洁的思路看下面的博客

传送门:http://blog.csdn.net/kuribohg/article/details/41458639

题意是这样的,给你一棵无根树,给你m个查询,每次查询输出节点x到节点y路径上不同颜色的节点有多少个(包括xy)

没什么说的,直接上代码吧

  1 #include<cstdio>
  2 #include<cstring>
  3 #include<iostream>
  4 #include<algorithm>
  5 #include<cmath>
  6 #include<vector>
  7 using namespace std;
  8 int const MAX_N=40000+5;
  9 int const MAX_M=100000+5;
 10 int n,m,ans,nowLca,lastL,lastR,tempAns;
 11 int color[MAX_N];
 12 vector<int> allColor;
 13 vector<int>::iterator it;
 14 int lin[MAX_N],e_cnt;
 15 struct Query{
 16     int x;
 17     int y;
 18     int id;
 19     int ans;
 20 }query[MAX_M];
 21 struct Edge{
 22     int next;
 23     int y;
 24 }e[2*MAX_N];
 25 void insert(int u,int v){
 26     e[++e_cnt].next=lin[u];
 27     e[e_cnt].y=v;
 28     lin[u]=e_cnt;
 29 }
 30 int b_len,b_cnt,bk[MAX_N];
 31 int fa[MAX_N][20],deep[MAX_N];
 32 int n_cnt,dfn[MAX_N],top,stack[MAX_N];
 33 bool vis[MAX_N];
 34 void dfs(int u){
 35     dfn[u]=++n_cnt;
 36     int btm=top;
 37     for(int i=lin[u];i;i=e[i].next){
 38         int v=e[i].y;
 39         if(!dfn[v]){
 40             fa[v][0]=u;
 41             deep[v]=deep[u]+1;
 42             dfs(v);
 43             if(top-btm>=b_len){
 44                 b_cnt++;
 45                 while(top>btm){
 46                     bk[stack[top--]]=b_cnt;
 47                 }
 48             }
 49         }
 50     }
 51     stack[++top]=u;
 52 }
 53 bool queryCmp_xkb_ydfn(Query a,Query b){
 54     if(bk[a.x]==bk[b.x])return dfn[a.y]<dfn[b.y];
 55     return bk[a.x]<bk[b.x];
 56 }
 57 bool queryCmp_id(Query a,Query b){
 58     return a.id<b.id;
 59 }
 60 int lca(int u,int v){
 61     if(deep[u]<deep[v])swap(u,v);
 62     for(int i=17;~i;i--)
 63         if(deep[fa[u][i]]>=deep[v])
 64             u=fa[u][i];
 65     if(u == v) return u;
 66     for(int i=17;~i;i--)
 67         if(fa[u][i]!=fa[v][i])
 68             u=fa[u][i],v=fa[v][i];
 69     return fa[u][0];
 70 }
 71 int c_cnt[MAX_N];
 72 void MoveToLca(int u){
 73     for(;u!=nowLca;u=fa[u][0]){
 74         if(vis[u]){
 75             vis[u]=false;
 76             c_cnt[color[u]]--;
 77             if(!c_cnt[color[u]])ans--;
 78         }else{
 79             vis[u]=true;
 80             if(!c_cnt[color[u]])ans++;
 81             c_cnt[color[u]]++;
 82         }
 83     }
 84 }
 85 int GetAns(int x,int u,int v){
 86     int tlca=lca(u,v);
 87     if(c_cnt[color[tlca]])return x;
 88     else return x+1;
 89 }
 90 
 91 void FirstMo(){
 92     int L=query[1].x,R=query[1].y;
 93     nowLca=lca(L,R);
 94     MoveToLca(L);
 95     MoveToLca(R);
 96     query[1].ans=GetAns(ans,L,R);
 97     lastL=L;lastR=R;
 98 }
 99 int main()
100 {
101     while(~scanf("%d%d",&n,&m)){
102         allColor.clear();
103         //read colors
104         for(int i=1;i<=n;i++){
105             scanf("%d",&color[i]);
106             allColor.push_back(color[i]);
107         }
108         //Discretization the color
109         sort(allColor.begin(),allColor.end());
110         it=unique(allColor.begin(),allColor.end());
111         allColor.resize(distance(allColor.begin(),it));
112         for(int i=1;i<=n;i++){
113             color[i]=lower_bound(allColor.begin(),allColor.end(),color[i])-allColor.begin()+1;
114         }
115         //read the tree
116         memset(lin,0,sizeof(lin));
117         e_cnt=0;
118         for(int i=1;i<n;i++){
119             int u,v;
120             scanf("%d%d",&u,&v);
121             insert(u,v);
122             insert(v,u);
123         }
124         //divide in blocks
125         //prepare for lca: get deep, dfs order, get father
126         b_len=sqrt((double)n);
127         b_cnt=0;
128         n_cnt=0;
129         top=0;
130         deep[0]=0,deep[1]=1;
131         memset(fa,0,sizeof(fa));
132         memset(dfn,0,sizeof(dfn));
133         memset(vis,0,sizeof(vis));
134         memset(stack,0,sizeof(stack));
135         dfs(1);
136         b_cnt++;
137         while(top){
138             bk[stack[top--]]=b_cnt;
139         }
140         for(int i = 1; (1<<i) <= n; i++)
141             for(int j = 1; j <= n; j++)
142                 fa[j][i] = fa[fa[j][i-1]][i-1];
143         //read the query
144         for(int i=1;i<=m;i++){
145             scanf("%d%d",&query[i].x,&query[i].y);
146             query[i].id=i;
147             if(dfn[query[i].x]>dfn[query[i].y])
148                 swap(query[i].x,query[i].y);
149         }
150         //mo's algorithm
151         sort(query+1,query+m+1,queryCmp_xkb_ydfn);
152         memset(vis,0,sizeof(vis));
153         memset(c_cnt,0,sizeof(c_cnt));
154         lastL=1,lastR=1,ans=0;
155         for(int i=1;i<=m;i++){
156             int L=query[i].x,R=query[i].y;
157             nowLca=lca(L,lastL);
158             MoveToLca(L);
159             MoveToLca(lastL);
160             nowLca=lca(R,lastR);
161             MoveToLca(R);
162             MoveToLca(lastR);
163             query[i].ans=GetAns(ans,L,R);
164             lastL=L;lastR=R;
165         }
166         sort(query+1,query+m+1,queryCmp_id);
167         for(int i=1;i<=m;i++){
168             printf("%d\n",query[i].ans);
169         }
170     }
171     return 0;
172 }

wa了无数发,最后发现是dfs的时候用vis数组判断点是否访问过,结果忘记初始化vis[1]了。。。。好蠢,后来改成了直接用dfn数组判断是否访问过了,本意是节省资源,结果居然A了。。。人生处处是惊喜。

这里顺便总结下lca写法吧。

这里的是倍增法求LCA

int fa[MAX_N][20],deep[MAX_N];
int dfn[MAX_N];

void dfs(int u){
    dfn[u]=++n_cnt;
    for(int i=lin[u];i;i=e[i].next){
        int v=e[i].y;
        if(!dfn[v]){
            fa[v][0]=u;
            deep[v]=deep[u]+1;
            dfs(v);
        }
    }
}
int lca(int u,int v){
    if(deep[u]<deep[v])swap(u,v);
    for(int i=17;~i;i--)
        if(deep[fa[u][i]]>=deep[v])
            u=fa[u][i];
    if(u == v) return u;
    for(int i=17;~i;i--)
        if(fa[u][i]!=fa[v][i])
            u=fa[u][i],v=fa[v][i];
    return fa[u][0];
}

int main()
{
        n_cnt=0;
        deep[0]=0,deep[1]=1;
        memset(fa,0,sizeof(fa));
        memset(dfn,0,sizeof(dfn));
        dfs(1);

        for(int i = 1; (1<<i) <= n; i++)
            for(int j = 1; j <= n; j++)
                fa[j][i] = fa[fa[j][i-1]][i-1];
            return 0;
}

dfn换成vis也是一样的

首先初始化deep数组表示节点深度,然后是fa[i][j]表示节点i的第2^j的父亲节点是什么,

先dfs求出deep和直系父亲

然后双重循环处理fa数组

lca的时候,先对齐u和v,就是让他们处于同一深度,然后一起向上,直到lca的儿子为止,返回他的父亲,就是lca。

posted @ 2016-09-06 18:54  徐王  阅读(292)  评论(0编辑  收藏  举报