树上散点
一个博客需要一份头图:
强制在线(论一个 ^ 引起的癫狂:
P6177 Count on a tree II/【模板】树分块
题意:
给定一棵树,每个节点有颜色。
每次询问一条路径上不同颜色的个数,强制在线。
数据范围 \(1 \leq n \leq 4\times 10^4,1\leq m \leq 10^5\),其中 \(n\) 为点数,\(m\) 为查询数。
考虑在树上随机选 \(k\) 个点作为关键点,使得树上每个点距离其最近的祖先关键点距离不超过 \(\frac{n}{k}\)。
具体地,如果一个点的 \(1 \sim \frac{n}{k}\) 级祖先都不是关键点,则把 \(\frac{n}{k}\) 级祖先标记为关键点。
为方便,最好也取根节点为关键点。
然后预处理出每条路径上,各各关键节点之间的答案,可以用 bitset 维护。
那么最多要处理 \(k^2\) 个点之间的答案,预处理复杂度为 \(O(\frac{nk^2}{w})\)。
具体地,只需要用栈维护当前路径的节点即可。
下面考虑如何处理询问。可以把一个询问拆成这样:
其中紫色为关键节点。
那么对于红色点与紫色点之间,则暴力跳,对于紫色点之间则用之前预处理出的答案。
对于从 \(u_0 \rightarrow u_1\),只需要记录下每个关键点在树链上的上一个关键点 \(lst\) 即可。
还要注意空间复杂度,这里取 \(k=80\)。
#include<bits/stdc++.h>
using namespace std;
const int N=4e4+5;
int n,m,q,top,ans,a[N],lsh[N],lst[N];
int tot,id[N],dis[N],dep[N],f[N][16],stk[N];
bitset<N> t[82][82],tmp;
vector<int> G[N];
int rd()
{
int x=0;char c=getchar();
for(;!isdigit(c);c=getchar());
for(; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c^48);
return x;
}
void dfs(int u,int fa)
{
dis[u]=dep[u]=dep[fa]+1,f[u][0]=fa;
for(int i=1;i<=15;i++) f[u][i]=f[f[u][i-1]][i-1];
for(int v:G[u])
{
if(v==fa) continue;
dfs(v,u);
dis[u]=max(dis[u],dis[v]);
}
if(dis[u]-dep[u]>=500) dis[u]=dep[u],id[u]=++tot;
}
void dfs2(int u)
{
for(int v:G[u])
{
if(v==f[u][0]) continue;
if(id[v])
{
int x=id[stk[top]],y=id[v];
for(int i=v;i!=stk[top];i=f[i][0]) t[x][y].set(a[i]);
tmp=t[x][y];
for(int i=1;i<top;i++) t[id[stk[i]]][y]=t[id[stk[i]]][x]|tmp;
lst[v]=stk[top];
stk[++top]=v;
}
dfs2(v);
if(id[v]) top--;
}
}
int getlca(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
for(int i=15;~i;i--) if(dep[f[u][i]]>=dep[v]) u=f[u][i];
if(u==v) return u;
for(int i=15;~i;i--) if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
return f[u][0];
}
void work(int u,int lca)
{
int pre=u;
while(dep[lst[pre]]>=dep[lca]) pre=lst[pre];
if(pre!=u) tmp|=t[id[pre]][id[u]];
while(pre!=lca) tmp.set(a[pre]),pre=f[pre][0];
}
int main()
{
n=rd(),q=rd();
for(int i=1;i<=n;i++) a[i]=lsh[i]=rd();
sort(lsh+1,lsh+1+n);
m=unique(lsh+1,lsh+1+n)-lsh-1;
for(int i=1;i<=n;i++) a[i]=lower_bound(lsh+1,lsh+1+m,a[i])-lsh;
for(int i=1;i<n;i++)
{
int u=rd(),v=rd();
G[u].push_back(v),G[v].push_back(u);
}
dfs(1,0);
if(!id[1]) id[1]=++tot;
stk[++top]=1;
dfs2(1);
while(q--)
{
int u=rd()^ans,v=rd(),lca=getlca(u,v);
tmp.reset();
while(u!=lca&&!id[u]) tmp.set(a[u]),u=f[u][0];
while(v!=lca&&!id[v]) tmp.set(a[v]),v=f[v][0];
tmp.set(a[lca]);
if(u!=lca) work(u,lca);
if(v!=lca) work(v,lca);
printf("%d\n",ans=tmp.count());
}
}
三倍经验:
雪辉这道题还有求个 mex,只需要把 bitset 取反后求 lowbit 即可,bitset 有个函数叫 _Find_first()
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+5;
int n,q,totE,top,ans,op,a[N],lst[N];
int tot,id[N],dis[N],dep[N],f[N][20],stk[N];
bitset<30005> t[155][155],tmp;
int pre[N],nxt[N<<1],to[N<<1];
void add(int u,int v){to[++totE]=v,nxt[totE]=pre[u],pre[u]=totE;}
int rd()
{
int x=0;char c=getchar();
for(;!isdigit(c);c=getchar());
for(; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c^48);
return x;
}
void dfs(int u,int fa)
{
dis[u]=dep[u]=dep[fa]+1,f[u][0]=fa;
for(int i=1;i<18;i++) f[u][i]=f[f[u][i-1]][i-1];
for(int i=pre[u];i;i=nxt[i])
{
int v=to[i];
if(v==fa) continue;
dfs(v,u);
dis[u]=max(dis[u],dis[v]);
}
if(dis[u]-dep[u]>=1000) dis[u]=dep[u],id[u]=++tot;
}
void dfs2(int u)
{
for(int i=pre[u];i;i=nxt[i])
{
int v=to[i];
if(v==f[u][0]) continue;
if(id[v])
{
int x=id[stk[top]],y=id[v];
for(int i=v;i!=stk[top];i=f[i][0]) t[x][y].set(a[i]);
tmp=t[x][y];
for(int i=1;i<top;i++) t[id[stk[i]]][y]=t[id[stk[i]]][x]|tmp;
lst[v]=stk[top];
stk[++top]=v;
}
dfs2(v);
if(id[v]) top--;
}
}
int getlca(int u,int v)
{
if(dep[u]<dep[v]) swap(u,v);
for(int i=17;~i;i--) if(dep[f[u][i]]>=dep[v]) u=f[u][i];
if(u==v) return u;
for(int i=17;~i;i--) if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
return f[u][0];
}
void work(int u,int lca)
{
int pre=u;
while(dep[lst[pre]]>=dep[lca]) pre=lst[pre];
if(pre!=u) tmp|=t[id[pre]][id[u]];
while(pre!=lca) tmp.set(a[pre]),pre=f[pre][0];
}
void get(int u,int v)
{
int lca=getlca(u,v);
while(u!=lca&&!id[u]) tmp.set(a[u]),u=f[u][0];
while(v!=lca&&!id[v]) tmp.set(a[v]),v=f[v][0];
tmp.set(a[lca]);
if(u!=lca) work(u,lca);
if(v!=lca) work(v,lca);
}
int main()
{
n=rd(),q=rd(),op=rd();
for(int i=1;i<=n;i++) a[i]=rd();
for(int i=1;i<n;i++)
{
int u=rd(),v=rd();
add(u,v),add(v,u);
}
dfs(1,0);
if(!id[1]) id[1]=++tot;
stk[++top]=1;
dfs2(1);
while(q--)
{
int cnt=rd();tmp.reset();
for(int i=1;i<=cnt;i++) get(rd()^ans,rd()^ans);
int a1=tmp.count();tmp=~tmp;
int a2=tmp._Find_first();
printf("%d %d\n",a1,a2);
if(op) ans=(a1+a2);
}
}