树分块
树分块的方式有好几种,这里介绍一种比较简单易懂的方式——树上撒点。
我们设一个阈值 \(B\),并在树上选出 \(\frac{n}{B}\) 个点作为关键点,满足树上任意一个点到关键点的距离不大于 \(B\)。
具体方法如下:每次选定一个深度最大的的非关键点,若它的 \(1\)~\(B\) 级祖先均不是关键点,则选择它的 \(B\) 级祖先作为关键点,由于我们每次选定一个关键点,都有 \(B\) 个点不会被选,所以可以保证关键点的数量不超过 \(\frac{n}{B}\),并且这样选显然也能保证树上任意一个点到关键点的距离不大于 \(B\)。
接下来,我们用 \(bitset\) 维护颜色,对于一条到根路径上的所有关键点,我们预处理出它们两两之间的 \(bitset\),对于询问 \(x\)、\(y\),我们找到它们的 \(LCA\),将 \(x->y\) 的路径拆为 \(x->LCA\)、\(y->LCA\),然后我们从点 \(x\) 开始,先跳到距它最近的关键点祖先,再沿着关键点向上跳,中间累加上我们预处理的贡献,点 \(y\) 同理,最后将两条路径拼起来即为答案。
复杂度 \(O(\frac{n^2}{B}+qB+\frac{n^3}{B^2w})\),取 \(B=\sqrt{n}\),最终复杂度为 \(O(n\sqrt{n}+q\sqrt{n}+\frac{n^2}{w})\)。
代码:
#include<bits/stdc++.h>
using namespace std;
const int N=1e5+10,len=1000;
inline int read() {
int s=0,x=1;char ch=getchar();
while(!isdigit(ch)) {if(ch=='-') x=-1;ch=getchar();}
while(isdigit(ch)) s=(s<<3)+(s<<1)+(ch^48),ch=getchar();
return s*x;
}
int lstans,n,m,val[N],dep[N],son[N],top[N],stk[N],tot;
int bv[N],siz[N],mx[N],fa[N],ky[N],cnt,F[N];
vector<int>e[N];
bitset<40010>b[41][41],ans;
inline void dfs1(int u,int Fa) {
siz[u]=1;dep[u]=dep[Fa]+1;
fa[u]=Fa;mx[u]=dep[u];
for(auto v:e[u]) {
if(v==Fa) continue;
dfs1(v,u);siz[u]+=siz[v];
mx[u]=max(mx[u],mx[v]);
if(!son[u]||siz[son[u]]<siz[v]) son[u]=v;
}
if(mx[u]-dep[u]>=len) ky[u]=++cnt,mx[u]=dep[u];
}
inline void dfs2(int u,int tp,int Fa) {
top[u]=tp;
if(son[u]) dfs2(son[u],tp,u);
for(auto v:e[u])
if(v!=Fa&&v!=son[u]) dfs2(v,v,u);
}
inline void dfs3(int u,int Fa) {
for(auto v:e[u]) {
if(v==Fa) continue;
if(ky[v]) {
int dw=ky[v],up=ky[stk[tot]];
for(int i=v;i!=stk[tot];i=fa[i]) b[up][dw].set(val[i]);
for(int i=1;i<tot;++i) b[ky[stk[i]]][dw]=b[ky[stk[i]]][up]|b[up][dw];
F[v]=stk[tot];stk[++tot]=v;
}
dfs3(v,u);
if(ky[v]) tot--;
}
}
inline int LCA(int u,int v) {
while(top[u]!=top[v]) {
if(dep[top[u]]>dep[top[v]]) u=fa[top[u]];
else v=fa[top[v]];
}
return dep[u]<=dep[v]?u:v;
}
inline void solve(int u,int lca) {
int nw=u;
while(dep[F[nw]]>dep[lca]) nw=F[nw];
if(nw!=u) ans|=b[ky[nw]][ky[u]];
while(nw!=lca) ans.set(val[nw]),nw=fa[nw];
}
int main() {
n=read(),m=read();
for(int i=1;i<=n;++i) bv[i]=val[i]=read();
sort(bv+1,bv+n+1);
int tt=unique(bv+1,bv+n+1)-bv-1;
for(int i=1;i<=n;++i) val[i]=lower_bound(bv+1,bv+tt+1,val[i])-bv;
for(int i=1;i<n;++i) {
int u=read(),v=read();
e[u].push_back(v);e[v].push_back(u);
}
dfs1(1,0);ky[1]=!ky[1]?1:ky[1];
dfs2(1,1,0);stk[++tot]=1;dfs3(1,0);
for(int i=1;i<=m;++i) {
ans.reset();
int u=read()^lstans,v=read();
int lca=LCA(u,v);ans.set(val[lca]);
while(u!=lca&&!ky[u]) ans.set(val[u]),u=fa[u];
while(v!=lca&&!ky[v]) ans.set(val[v]),v=fa[v];
solve(u,lca);solve(v,lca);
printf("%d\n",lstans=ans.count());
}
return 0;
}