BZOJ - 2588 Spoj 10628. Count on a tree (可持久化线段树+LCA/树链剖分)

题目链接

第一种方法,dfs序上建可持久化线段树,然后询问的时候把两点之间的所有树链扒出来做差。

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int N=1e5+10,inf=0x3f3f3f3f;
 5 int hd[N],ne,n,n2,m,fa[N],son[N],siz[N],dep[N],top[N],dfn[N],rnk[N],tot,rt[N],ls[N*20],rs[N*20],val[N*20],tot2,a[N],b[N],ql[100],qr[100],nl,nr;
 6 struct E {int v,nxt;} e[N<<1];
 7 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
 8 void dfs1(int u,int f,int d) {
 9     fa[u]=f,son[u]=0,siz[u]=1,dep[u]=d;
10     for(int i=hd[u]; ~i; i=e[i].nxt) {
11         int v=e[i].v;
12         if(v==fa[u])continue;
13         dfs1(v,u,d+1),siz[u]+=siz[v];
14         if(siz[v]>siz[son[u]])son[u]=v;
15     }
16 }
17 void dfs2(int u,int tp) {
18     top[u]=tp,dfn[u]=++tot,rnk[tot]=u;
19     if(son[u])dfs2(son[u],tp);
20     for(int i=hd[u]; ~i; i=e[i].nxt) {
21         int v=e[i].v;
22         if(v==fa[u]||v==son[u])continue;
23         dfs2(v,v);
24     }
25 }
26 #define mid ((l+r)>>1)
27 void upd(int& u,int v,int x,int l=1,int r=n2) {
28     if(!u)u=++tot2;
29     val[u]=val[v]+1;
30     if(l==r)return;
31     if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v];
32     else upd(rs[u],rs[v],x,mid+1,r),ls[u]=ls[v];
33 }
34 int ask(int u,int v,int k) {
35     for(nl=nr=0; top[u]!=top[v]; u=fa[top[u]]) {
36         if(dep[top[u]]<dep[top[v]])swap(u,v);
37         ql[nl++]=rt[dfn[top[u]]-1],qr[nr++]=rt[dfn[u]];
38     }
39     if(dep[u]<dep[v])swap(u,v);
40     ql[nl++]=rt[dfn[v]-1],qr[nr++]=rt[dfn[u]];
41     int l=1,r=n2;
42     while(l<r) {
43         int cnt=0;
44         for(int i=0; i<nr; ++i)cnt+=val[ls[qr[i]]];
45         for(int i=0; i<nl; ++i)cnt-=val[ls[ql[i]]];
46         if(k<=cnt) {
47             for(int i=0; i<nr; ++i)qr[i]=ls[qr[i]];
48             for(int i=0; i<nl; ++i)ql[i]=ls[ql[i]];
49             r=mid;
50         } else {
51             k-=cnt;
52             for(int i=0; i<nr; ++i)qr[i]=rs[qr[i]];
53             for(int i=0; i<nl; ++i)ql[i]=rs[ql[i]];
54             l=mid+1;
55         }
56     }
57     return l;
58 }
59 int main() {
60     memset(hd,-1,sizeof hd),ne=0;
61     scanf("%d%d",&n,&m);
62     for(int i=1; i<=n; ++i)scanf("%d",&a[i]);
63     for(int i=1; i<=n; ++i)b[i-1]=a[i];
64     sort(b,b+n),n2=unique(b,b+n)-b;
65     for(int i=1; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+1;
66     for(int i=1; i<n; ++i) {
67         int u,v;
68         scanf("%d%d",&u,&v);
69         addedge(u,v),addedge(v,u);
70     }
71     tot=0,dfs1(1,0,1),dfs2(1,1);
72     memset(rt,0,sizeof rt),tot2=0;
73     for(int i=1; i<=n; ++i)upd(rt[i],rt[i-1],a[rnk[i]],1,n2);
74     for(int last=0; m--;) {
75         int u,v,k;
76         scanf("%d%d%d",&u,&v,&k),u^=last;
77         int ans=b[ask(u,v,k)-1];
78         printf("%d\n",ans),last=ans;
79     }
80     return 0;
81 }
View Code

仔细一想这样似乎麻烦了点。因为没有修改操作,我们可以直接用子结点继承父节点的方式来建线段树,然后查询的时候,用u,v的线段树减去lca的线段树再减去lca父节点的线段树即可。

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int N=1e5+10,inf=0x3f3f3f3f;
 5 int hd[N],ne,n,n2,m,fa[N],son[N],siz[N],dep[N],top[N],dfn[N],rnk[N],tot,rt[N],ls[N*20],rs[N*20],val[N*20],tot2,a[N],b[N],ql[100],qr[100],nl,nr;
 6 struct E {int v,nxt;} e[N<<1];
 7 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
 8 void dfs1(int u,int f,int d) {
 9     fa[u]=f,son[u]=0,siz[u]=1,dep[u]=d;
10     for(int i=hd[u]; ~i; i=e[i].nxt) {
11         int v=e[i].v;
12         if(v==fa[u])continue;
13         dfs1(v,u,d+1),siz[u]+=siz[v];
14         if(siz[v]>siz[son[u]])son[u]=v;
15     }
16 }
17 void dfs2(int u,int tp) {
18     top[u]=tp,dfn[u]=++tot,rnk[tot]=u;
19     if(son[u])dfs2(son[u],tp);
20     for(int i=hd[u]; ~i; i=e[i].nxt) {
21         int v=e[i].v;
22         if(v==fa[u]||v==son[u])continue;
23         dfs2(v,v);
24     }
25 }
26 #define mid ((l+r)>>1)
27 void upd(int& u,int v,int x,int l=1,int r=n2) {
28     if(!u)u=++tot2;
29     val[u]=val[v]+1;
30     if(l==r)return;
31     if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v];
32     else upd(rs[u],rs[v],x,mid+1,r),ls[u]=ls[v];
33 }
34 void dfs3(int u) {
35     upd(rt[u],rt[fa[u]],a[u]);
36     for(int i=hd[u]; ~i; i=e[i].nxt) {
37         int v=e[i].v;
38         if(v==fa[u])continue;
39         dfs3(v);
40     }
41 }
42 int lca(int u,int v) {
43     for(; top[u]!=top[v]; u=fa[top[u]])if(dep[top[u]]<dep[top[v]])swap(u,v);
44     return dep[u]<dep[v]?u:v;
45 }
46 int ask(int u,int v,int w1,int w2,int k,int l=1,int r=n2) {
47     if(l==r)return l;
48     int cnt=val[ls[u]]+val[ls[v]]-val[ls[w1]]-val[ls[w2]];
49     return k<=cnt?ask(ls[u],ls[v],ls[w1],ls[w2],k,l,mid):ask(rs[u],rs[v],rs[w1],rs[w2],k-cnt,mid+1,r);
50 }
51 int main() {
52     memset(hd,-1,sizeof hd),ne=0;
53     scanf("%d%d",&n,&m);
54     for(int i=1; i<=n; ++i)scanf("%d",&a[i]);
55     for(int i=1; i<=n; ++i)b[i-1]=a[i];
56     sort(b,b+n),n2=unique(b,b+n)-b;
57     for(int i=1; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+1;
58     for(int i=1; i<n; ++i) {
59         int u,v;
60         scanf("%d%d",&u,&v);
61         addedge(u,v),addedge(v,u);
62     }
63     tot=0,dfs1(1,0,1),dfs2(1,1);
64     memset(rt,0,sizeof rt),tot2=0;
65     dfs3(1);
66     for(int last=0; m--;) {
67         int u,v,k;
68         scanf("%d%d%d",&u,&v,&k),u^=last;
69         int w=lca(u,v);
70         int ans=b[ask(rt[u],rt[v],rt[w],rt[fa[w]],k)-1];
71         printf("%d\n",ans),last=ans;
72     }
73     return 0;
74 }
View Code

然后我又测试了倍增和RMQ求LCA的方法,发现居然还不如dfs序+可持久化线段树的方法快~~毕竟倍增和RMQ预处理的时间和空间复杂度都是$O(nlogn)$,而树剖只需要$O(n)$,而且查询速度也比较快。

倍增:

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int N=1e5+10,inf=0x3f3f3f3f;
 5 int hd[N],ne,n,n2,m,fa[N][20],dep[N],rt[N],ls[N*20],rs[N*20],val[N*20],tot2,a[N],b[N];
 6 struct E {int v,nxt;} e[N<<1];
 7 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
 8 #define mid ((l+r)>>1)
 9 void upd(int& u,int v,int x,int l=1,int r=n2) {
10     if(!u)u=++tot2;
11     val[u]=val[v]+1;
12     if(l==r)return;
13     if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v];
14     else upd(rs[u],rs[v],x,mid+1,r),ls[u]=ls[v];
15 }
16 void dfs(int u,int f,int d) {
17     fa[u][0]=f,dep[u]=d,upd(rt[u],rt[fa[u][0]],a[u]);
18     for(int i=1; i<20; ++i)fa[u][i]=fa[fa[u][i-1]][i-1];
19     for(int i=hd[u]; ~i; i=e[i].nxt) {
20         int v=e[i].v;
21         if(v==fa[u][0])continue;
22         dfs(v,u,d+1);
23     }
24 }
25 int lca(int u,int v) {
26     if(dep[u]<dep[v])swap(u,v);
27     for(int i=19; dep[u]!=dep[v]; --i)if(dep[fa[u][i]]>=dep[v])u=fa[u][i];
28     if(u==v)return u;
29     for(int i=19; i>=0; --i)if(fa[u][i]!=fa[v][i])u=fa[u][i],v=fa[v][i];
30     return fa[u][0];
31 }
32 int ask(int u,int v,int w1,int w2,int k,int l=1,int r=n2) {
33     if(l==r)return l;
34     int cnt=val[ls[u]]+val[ls[v]]-val[ls[w1]]-val[ls[w2]];
35     return k<=cnt?ask(ls[u],ls[v],ls[w1],ls[w2],k,l,mid):ask(rs[u],rs[v],rs[w1],rs[w2],k-cnt,mid+1,r);
36 }
37 int main() {
38     memset(hd,-1,sizeof hd),ne=0;
39     scanf("%d%d",&n,&m);
40     for(int i=1; i<=n; ++i)scanf("%d",&a[i]);
41     for(int i=1; i<=n; ++i)b[i-1]=a[i];
42     sort(b,b+n),n2=unique(b,b+n)-b;
43     for(int i=1; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+1;
44     for(int i=1; i<n; ++i) {
45         int u,v;
46         scanf("%d%d",&u,&v);
47         addedge(u,v),addedge(v,u);
48     }
49     memset(rt,0,sizeof rt),tot2=0;
50     dfs(1,0,1);
51     for(int last=0; m--;) {
52         int u,v,k;
53         scanf("%d%d%d",&u,&v,&k),u^=last;
54         int w=lca(u,v);
55         int ans=b[ask(rt[u],rt[v],rt[w],rt[fa[w][0]],k)-1];
56         printf("%d\n",ans),last=ans;
57     }
58     return 0;
59 }
View Code

RMQ:

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 typedef long long ll;
 4 const int N=1e5+10,inf=0x3f3f3f3f;
 5 int hd[N],ne,n,n2,m,fa[N],dep[N],pos[N],ST[N<<1][20],Log[N<<1],tot,rt[N],ls[N*20],rs[N*20],val[N*20],tot2,a[N],b[N];
 6 struct E {int v,nxt;} e[N<<1];
 7 void addedge(int u,int v) {e[ne]= {v,hd[u]},hd[u]=ne++;}
 8 #define mid ((l+r)>>1)
 9 void upd(int& u,int v,int x,int l=1,int r=n2) {
10     if(!u)u=++tot2;
11     val[u]=val[v]+1;
12     if(l==r)return;
13     if(x<=mid)upd(ls[u],ls[v],x,l,mid),rs[u]=rs[v];
14     else upd(rs[u],rs[v],x,mid+1,r),ls[u]=ls[v];
15 }
16 void dfs(int u,int f,int d) {
17     fa[u]=f,dep[u]=d,ST[++tot][0]=u,pos[u]=tot,upd(rt[u],rt[fa[u]],a[u]);
18     for(int i=hd[u]; ~i; i=e[i].nxt) {
19         int v=e[i].v;
20         if(v==fa[u])continue;
21         dfs(v,u,d+1),ST[++tot][0]=u;
22     }
23 }
24 bool cmp(int a,int b) {return dep[a]<dep[b];}
25 void initST() {
26     for(int j=1; j<20; ++j)
27         for(int i=1; i+(1<<j)-1<=tot; ++i)
28             ST[i][j]=min(ST[i][j-1],ST[i+(1<<(j-1))][j-1],cmp);
29 }
30 int lca(int u,int v) {
31     int l=pos[u],r=pos[v];
32     if(l>r)swap(l,r);
33     int i=Log[r-l+1];
34     return min(ST[l][i],ST[r-(1<<i)+1][i],cmp);
35 }
36 int ask(int u,int v,int w1,int w2,int k,int l=1,int r=n2) {
37     if(l==r)return l;
38     int cnt=val[ls[u]]+val[ls[v]]-val[ls[w1]]-val[ls[w2]];
39     return k<=cnt?ask(ls[u],ls[v],ls[w1],ls[w2],k,l,mid):ask(rs[u],rs[v],rs[w1],rs[w2],k-cnt,mid+1,r);
40 }
41 int main() {
42     for(int i=1; i<(N<<1); ++i)Log[i]=log2(i+0.5);
43     memset(hd,-1,sizeof hd),ne=0;
44     scanf("%d%d",&n,&m);
45     for(int i=1; i<=n; ++i)scanf("%d",&a[i]);
46     for(int i=1; i<=n; ++i)b[i-1]=a[i];
47     sort(b,b+n),n2=unique(b,b+n)-b;
48     for(int i=1; i<=n; ++i)a[i]=(lower_bound(b,b+n2,a[i])-b)+1;
49     for(int i=1; i<n; ++i) {
50         int u,v;
51         scanf("%d%d",&u,&v);
52         addedge(u,v),addedge(v,u);
53     }
54     memset(rt,0,sizeof rt),tot2=0;
55     dfs(1,0,1),initST();
56     for(int last=0; m--;) {
57         int u,v,k;
58         scanf("%d%d%d",&u,&v,&k),u^=last;
59         int w=lca(u,v);
60         int ans=b[ask(rt[u],rt[v],rt[w],rt[fa[w]],k)-1];
61         printf("%d\n",ans),last=ans;
62     }
63     return 0;
64 }
View Code
posted @ 2019-03-26 16:24  jrltx  阅读(111)  评论(0编辑  收藏  举报