BZOJ 3572 [HNOI2014]世界树 (虚树+DP)

题面:BZOJ传送门 洛谷传送门

题目大意:略

 

细节贼多的虚树$DP$

 

先考虑只有一次询问的情况

一个节点$x$可能被它子树内的一个到x距离最小的特殊点管辖,还可能被管辖fa[x]的特殊点管辖

跑两次$dfs$即可,时间$O(n)$

 

再考虑一条链的情况

一条链上有很多个特殊点,相邻两个特殊点$x,y$之间可能有很多连续的非特殊点,那么在这些连续的非特殊点上会有一个分界,前面一部分被$x$管辖,后面一部分被$y$管辖

在链上二分即可,时间$O(mlogm)$

 

正解就是把上面两种情况结合起来..用虚树维护一下

首先根据套路对特殊点建出虚树,虚树上会出现所有的特殊点以及一些作为$LCA$的非特殊点。

用第一种情况的方法在虚树上搜索一遍,求出虚树上的每个节点被哪些节点管辖

再考虑剩余节点的贡献

对于虚树上相邻的两个节点$x,y$,假设$dep[x]<dep[y]$,我们取出原树上端点为$x,y$的链$F$,然后把$F$的两个端点$x,y$去掉,贡献分为两种情况

$x,y$被同一个的节点管辖,那么链F上的节点以及链F上挂着的子树也都会被这个节点管辖

$x,y$被不同的节点管辖,借用第二种情况的方法,链F上一定存在一个分界点,上面一部分被管辖x的节点管辖,下面一部分被管辖y的节点管辖,倍增跳一下即可

看起来很好写,实际上细节真的不少啊..上午迷迷糊糊写+调了4h才过

 

  1 #include <cstdio>
  2 #include <cstring>
  3 #include <algorithm>
  4 #define ll long long 
  5 #define N1 300100
  6 using namespace std;
  7 const int inf=0x3f3f3f3f;
  8  
  9 template <typename _T> void read(_T &ret)
 10 {
 11     ret=0; _T fh=1; char c=getchar();
 12     while(c<'0'||c>'9'){ if(c=='-') fh=-1; c=getchar(); }
 13     while(c>='0'&&c<='9'){ ret=ret*10+c-'0'; c=getchar(); }
 14     ret=ret*fh;
 15 }
 16 template <typename _T> _T chkmin(_T &x,_T &y,_T vx,_T vy)
 17 {
 18     if(vx<vy) return x;
 19     if(vx>vy) return y;
 20     return x<y?x:y;
 21 }
 22  
 23 struct Edge{
 24 int to[N1*2],nxt[N1*2],head[N1],cte;
 25 void ae(int u,int v)
 26 { cte++; to[cte]=v; nxt[cte]=head[u]; head[u]=cte; } //val[cte]=w; 
 27 void clr()
 28 { memset(to,0,(cte+1)*4); memset(nxt,0,(cte+1)*4); cte=0; } 
 29 }e,g;
 30  
 31 int n,de;
 32 int lg[N1*2];
 33 int dep[N1],ff[N1][20],eu[N1*2][20],st[N1],sz[N1],cur;
 34  
 35 void dfs(int x)
 36 {
 37     int j,v; ff[x][0]=x; eu[st[x]=++cur][0]=x; 
 38     for(j=e.head[x];j;j=e.nxt[j])
 39     {
 40         v=e.to[j]; if(v==ff[x][1]) continue;
 41         ff[v][1]=x; dep[v]=dep[x]+1; 
 42         dfs(v); 
 43         eu[++cur][0]=x; sz[x]+=sz[v];
 44     }
 45     sz[x]++;
 46 }
 47 void get_st()
 48 {
 49     int i,j;
 50     for(j=2;j<=18;j++)
 51     for(i=1;i<=n;i++)
 52         ff[i][j]=ff[ ff[i][j-1] ][j-1];
 53     for(i=2,lg[1]=0;i<=cur;i++) lg[i]=lg[i>>1]+1;
 54     for(j=1;j<=lg[cur];j++)
 55     for(i=1;i+(1<<j)-1<=cur;i++)
 56         eu[i][j]=dep[eu[i][j-1]] < dep[eu[i+(1<<(j-1))][j-1]] ? eu[i][j-1] : eu[i+(1<<(j-1))][j-1];
 57 }
 58 int LCA(int x,int y)
 59 {
 60     x=st[x], y=st[y]; if(x>y) swap(x,y); int l=y-x+1;
 61     return dep[eu[x][lg[l]]] < dep[eu[y-(1<<lg[l])+1][lg[l]]] ? eu[x][lg[l]] : eu[y-(1<<lg[l])+1][lg[l]];
 62 }
 63 int Dis(int x,int y)
 64 {
 65     if(!x||!y) return inf; int F=LCA(x,y);
 66     return dep[x]+dep[y]-2*dep[F];
 67 }
 68 int jump(int x,int D)
 69 {
 70     int i;
 71     for(i=lg[dep[x]-D]+1;i>=0;i--)
 72         if(ff[x][i] && dep[ff[x][i]]>=D) x=ff[x][i];
 73     return x;
 74 }
 75  
 76 namespace virtree{
 77  
 78 int cmp_dfsorder(int x,int y){ return st[x]<st[y]; }
 79 int vir[N1],num,stk[N1],tp,ctrl[N1],spe[N1],ans[N1],org[N1],dctrl[N1];
 80  
 81 void push(int x)
 82 {
 83     int y=stk[tp], F=LCA(x,y); stk[0]=F;
 84     while(tp>0 && dep[stk[tp-1]]>dep[F])
 85     {
 86         g.ae(stk[tp-1],stk[tp]); 
 87         stk[tp]=0; tp--;
 88     }
 89     if(dep[stk[tp]]>dep[F])
 90     {
 91         g.ae(F,stk[tp]); 
 92         stk[tp]=0; tp--;
 93     }
 94     if(!tp||stk[tp]!=F) stk[++tp]=F;
 95     if(stk[tp]!=x) stk[++tp]=x;
 96 }
 97 int build()
 98 {
 99     int i;
100     for(i=1;i<=num;i++) spe[vir[i]]=1, org[i]=vir[i]; 
101     if(!spe[1]) vir[++num]=1;
102     sort(vir+1,vir+num+1,cmp_dfsorder);
103     stk[++tp]=vir[1];
104     for(i=2;i<=num;i++) push(vir[i]);
105     while(tp>1) g.ae(stk[tp-1],stk[tp]), tp--; 
106     return stk[tp];
107 }
108  
109 int dfs_son(int x)
110 {
111     int j,v,mi=0;
112     for(j=g.head[x];j;j=g.nxt[j])
113     {
114         v=g.to[j]; dfs_son(v);
115         mi=chkmin(ctrl[v],mi,dep[ctrl[v]],dep[mi]);
116     }
117     ctrl[x]=(spe[x]?x:mi); 
118     return ctrl[x];
119 }
120 void dfs_fa(int x)
121 {
122     int j,v;
123     for(j=g.head[x];j;j=g.nxt[j])
124     {
125         v=g.to[j];
126         ctrl[v]=chkmin(ctrl[x],ctrl[v],Dis(ctrl[x],v),Dis(ctrl[v],v));
127         dfs_fa(v);
128     }
129 }
130 void debug(int x)
131 {
132     int j,v;
133     if(spe[x]) printf("%d ",x);
134     for(j=e.head[x];j;j=e.nxt[j])
135     {
136         v=e.to[j]; if(v==ff[x][1]) continue;
137         debug(v);
138     }
139 }
140 void dfs_ans(int x)
141 {
142     int j,v,sum=sz[x],cx=ctrl[x]; dctrl[x]=Dis(x,ctrl[x]);
143     for(j=g.head[x];j;j=g.nxt[j])
144     {
145         v=g.to[j]; dfs_ans(v);
146         int cv=ctrl[v],p,pv,tmp;
147         pv=jump(v,dep[x]+1);
148         sum-=sz[pv];
149         if(dep[v]==dep[x]+1) continue; 
150         if(cx!=cv){
151             tmp=dctrl[x]+dctrl[v]+dep[v]-dep[x]; 
152             if(tmp&1){
153                 p=dep[v]-((tmp-1)/2-dctrl[v]);
154                 p=jump(v, min(dep[v],max(dep[x]+1,p)) );
155             }else{
156                 p=dep[v]-((tmp-1)/2-dctrl[v]);
157                 if(p<=dep[v]){
158                     p=jump(v, min(dep[v],max(dep[x]+1,p)) );
159                     if(cv<cx && dep[p]-1>=dep[x]+1) p=ff[p][1];
160                 }else p=v; 
161             }
162             ans[cx]+=sz[pv]-sz[p], ans[cv]+=sz[p]-sz[v];
163         }else{
164             ans[cx]+=sz[pv]-sz[v];
165         }
166     }
167     ans[cx]+=sum;
168 }
169 void dfs_clear(int x)
170 {
171     int j,v;
172     for(j=g.head[x];j;j=g.nxt[j])
173     {
174         v=g.to[j]; dfs_clear(v);
175     }
176     g.head[x]=ctrl[x]=spe[x]=ans[x]=0;
177 }
178  
179 void solve(int Num)
180 {
181     num=Num;
182     int root=build(),i; 
183     dfs_son(root); 
184     dfs_fa(root);
185     dfs_ans(root);
186     for(i=1;i<=Num;i++) if(i!=Num) printf("%d ",ans[org[i]]); else printf("%d\n",ans[org[i]]);
187     dfs_clear(root);
188     memset(vir,0,(num+1)*4); memset(org,0,(num+1)*4); memset(stk,0,(tp+1)*4); g.clr(); tp=0;
189 }
190  
191 };
192  
193 int main()
194 {
195     int i,j,ans=0,x,y,q,Q;
196     scanf("%d",&n);
197     for(i=1;i<n;i++) read(x), read(y), e.ae(x,y), e.ae(y,x);
198     dfs(1); get_st(); dep[0]=inf;
199     scanf("%d",&Q); 
200     for(q=1;q<=Q;q++)
201     {
202         read(x);
203         for(i=1;i<=x;i++) read(virtree::vir[i]);
204         virtree::solve(x);
205     }
206     return 0;
207 }

 

posted @ 2019-04-01 22:22  guapisolo  阅读(211)  评论(0编辑  收藏  举报