bzoj3351:[ioi2009]Regions

思路:首先如果颜色相同直接利用以前的答案即可,可以离线排序或是在线hash,然后考虑怎么快速统计答案。

首先如果点a是点b的祖先,那么一定有点b在以点a为根的子树的dfs序区间内的,于是先搞出dfs序。

然后如果颜色a的点数很小,颜色b的点数很大,那么可以考虑枚举a的点数,然后对于每一种颜色开个vector记录一下有哪些点是这种颜色,然后按照它们的dfs序排序,就可以用颜色a中的每个点在颜色b中二分出哪些点属于以该点为根的子树对应的dfs序区间了。复杂度O(size(a)*log(size(b))),size(a)表示颜色a的vector的size()。

然后如果颜色b的点数很小,颜色a的点数很大,那么就枚举b的点数,这时要考虑的问题就成了一个点被多少段区间覆盖了,然后离散化差分预处理,再去二分(我写的是vector的离散化)。复杂度O(size(b)*log(size(a)))

但如果a,b的点数差不多且都很大(也就是几乎为sqrt(n)),那么算法复杂度就会变成O(sqrt(n)*log(n))了,再乘以一个q就会GG,于是只能另寻他法,然后可以发现直接两个指针扫过去,一个扫区间端点另一个扫要询问的点,然后如果扫到一个点就直接统计答案,然后这就变成了O(size(a)+size(b))了。

那这个很大是有多大,很小是有多小呢?

对于第一种算法使用条件是size(b)>x,第二种算法使用条件是size(a)>x,其余则用第三种算法。

对于第一、二种情况,时间复杂度最大是O(n^2logn/x),然后对于第三种则是O(n*x),然后根据基本不等式x=sqrt(nlogn),总时间复杂度为O(n*sqrt(nlogn))。

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<cmath>
using namespace std;
#define maxn 200005
#define maxr 30000
 
int n,r,Q,tot,cnt;
int now[maxn],pre[2*maxn],son[2*maxn],color[maxn],dfn[maxn],size[maxn];
long long ans[maxn];
 
inline int read(){
    int x=0,f=1;char ch=getchar();
    for (;ch<'0'||ch>'9';ch=getchar()) if (ch=='-') f=-1;
    for (;ch>='0'&&ch<='9';ch=getchar()) x=x*10+ch-'0';
    return x*f;
}
 
struct node{
    int dfn,bo;
    node(){}
    node(int a,int b){dfn=a,bo=b;}
    bool operator <(const node &a)const{return dfn<a.dfn;}
};
 
struct query{
    int x,y,id;
    bool operator <(const query &a)const{return x<a.x||(x==a.x&&y<a.y);}
}q[maxn];
 
bool cmp(int a,int b){return dfn[a]<dfn[b];}
 
vector<int> col[maxr],val[maxr];
vector<node> v[maxr];
vector<int> fuckpps[maxr];
 
void add(int a,int b){
    son[++tot]=b;
    pre[tot]=now[a];
    now[a]=tot;
}
 
void link(int a,int b){
    add(a,b),add(b,a);
}
 
void dfs(int x,int fa){
    dfn[x]=++cnt;
    for (int p=now[x];p;p=pre[p])
        if (son[p]!=fa) dfs(son[p],x),size[x]+=size[son[p]]+1;
}
 
int binary_search(int l,int r,int b,int pos){
    int ans=-1;
    while (l<=r){
        int mid=(l+r)>>1;
        if (pos>=fuckpps[b][mid]) ans=mid,l=mid+1;
        else r=mid-1;
    }
    return ans+1;
}
 
long long solve1(int a,int b){
    long long ans=0;
    for (unsigned int i=0;i<col[a].size();i++){
        int x=col[a][i],l=binary_search(0,fuckpps[b].size()-1,b,dfn[x]-1),r=binary_search(0,fuckpps[b].size()-1,b,dfn[x]+size[x]);
        ans+=r-l;
    }
    return ans;
}
 
int binary_search2(int l,int r,int b,int pos){
    int ans=-1;
    while (l<=r){
        int mid=(l+r)>>1;
        if (v[b][mid].dfn<=pos) ans=mid,l=mid+1;
        else r=mid-1;
    }
    return ans;
}
 
long long solve2(int a,int b){
    long long ans=0;
    for (unsigned int i=0;i<col[b].size();i++){
        int x=col[b][i],pos=binary_search2(0,v[a].size()-1,a,dfn[x]);
        if (pos!=-1) ans+=val[a][pos];
    }
    return ans;
}
 
long long solve3(int a,int b){
    long long ans=0;unsigned int i=0,j=0,tt=0;
    while (i<v[a].size() && j<col[b].size())
        if (v[a][i].dfn<=dfn[col[b][j]]) tt=val[a][i],i++;else ans+=tt,j++;
    return ans;
}
 
int main(){
    n=read(),r=read(),Q=read();int siz=sqrt(n*log2(n));
    for (int i=1,x;i<=n;i++){
        if (i!=1) x=read(),link(i,x);
        color[i]=read();col[color[i]].push_back(i);
    }
    dfs(1,0);
    for (int i=1;i<=n;i++) fuckpps[color[i]].push_back(dfn[i]);
    for (int i=1;i<=r;i++) sort(col[i].begin(),col[i].end(),cmp),sort(fuckpps[i].begin(),fuckpps[i].end());
    for (int i=1;i<=r;i++){
        for (unsigned int j=0;j<col[i].size();j++)
            v[i].push_back(node(dfn[col[i][j]],1)),v[i].push_back(node(dfn[col[i][j]]+size[col[i][j]]+1,-1));
        sort(v[i].begin(),v[i].end());int sum=0;
        for (unsigned int j=0;j<v[i].size();j++){
            sum+=v[i][j].bo;
            val[i].push_back(sum);
        }
    }
    for (int i=1;i<=Q;i++) q[i].x=read(),q[i].y=read(),q[i].id=i;
    sort(q+1,q+Q+1);
    for (int i=1;i<=Q;i++){
        if (q[i].x==q[i-1].x && q[i].y==q[i-1].y){ans[q[i].id]=ans[q[i-1].id];continue;}
        if (col[q[i].y].size()+1>=siz&&col[q[i].x].size()+1<siz) ans[q[i].id]=solve1(q[i].x,q[i].y);
        else if (col[q[i].x].size()+1>=siz&&col[q[i].y].size()+1<siz) ans[q[i].id]=solve2(q[i].x,q[i].y);
        else ans[q[i].id]=solve3(q[i].x,q[i].y);
    }
    for (int i=1;i<=Q;i++) printf("%lld\n",ans[i]);
    return 0;
}
View Code

 

posted @ 2016-10-14 17:44  DUXT  阅读(545)  评论(0编辑  收藏  举报