树上莫队
一点点变化的莫队
一些些需要注意的点:
- 序列是欧拉序,所以说长度有\(2*n\)
- 重复出现两次的点是无效的,通过标记决定是加还是减
- lca要么就是链顶,要么就是在链外(有折),注意记录和修改后还要改回来
#include<bits/stdc++.h>
#define int long long
#define F(i,i0,n) for(int i=(i0);i<=(n);i++)
#define D(i,n,i0) for(int i=(n);i>=i0;--i)
#define pii pair<int,int>
#define fr first
#define sc second
#define pb push_back
using namespace std;
inline int rd(){
int f=0,x=0;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-48;ch=getchar();}
return f?-x:x;
}
const int N=2e5+5,mod=1e9+7;
struct Id{int v,nt;}e[N<<1];
int p[N],id=1;
void add(int x,int y){e[++id]={y,p[x]};p[x]=id;}
int siz[N],son[N],dep[N],fa[N];
void dfs1(int x,int ffa){
siz[x]=1;fa[x]=ffa;
dep[x]=dep[ffa]+1;
for(int i=p[x];i;i=e[i].nt){
int v=e[i].v;if(v==ffa)continue;
dfs1(v,x);
siz[x]+=siz[v];
if(siz[v]>siz[son[x]])son[x]=v;
}
}
int ola[N],Tim,in[N],out[N],top[N];
void dfs2(int x,int tp){
in[x]=++Tim;
ola[Tim]=x;
top[x]=tp;
if(son[x])dfs2(son[x],tp);
for(int i=p[x];i;i=e[i].nt){
int v=e[i].v;if(v==fa[x]||v==son[x])continue;
dfs2(v,v);
}
out[x]=++Tim;
ola[Tim]=x;
}
int Lca(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]])swap(x,y);
x=fa[top[x]];
}
return dep[x]<dep[y]?x:y;
}
int ord[N],orl,len;
int n,m,a[N];
int blo;
struct Que{
int l,r,lca,id;
bool operator<(const Que&_)const{return l/blo==_.l/blo?((l/blo)&1)^(r<_.r):l<_.l;}
}qe[N];
int fl[N],tot=0;
int cnt[N],ans[N];
void upd(int x){
if(!fl[x])tot+=(++cnt[a[x]]==1);
else tot-=(--cnt[a[x]]==0);
fl[x]^=1;
}
void ot(){F(i,1,n)cout<<cnt[i]<<" ";cout<<'\n';}
void solve(){
sort(qe+1,qe+1+m);
int l=1,r=0;
F(i,1,m){
while(r<qe[i].r)upd(ola[++r]);
while(l>qe[i].l)upd(ola[--l]);
while(r>qe[i].r)upd(ola[r--]);
while(l<qe[i].l)upd(ola[l++]);
if(qe[i].lca)upd(qe[i].lca);
ans[qe[i].id]=tot;
if(qe[i].lca)upd(qe[i].lca);
}
}
signed main(){
n=rd(),m=rd();
F(i,1,n){
a[i]=rd();
ord[++orl]=a[i];
}
sort(ord+1,ord+1+orl);
len=unique(ord+1,ord+1+orl)-ord-1;
F(i,1,n)a[i]=lower_bound(ord+1,ord+1+len,a[i])-ord;
F(i,1,n-1){
int x=rd(),y=rd();
add(x,y);add(y,x);
}
dfs1(1,0);dfs2(1,1);
blo=sqrt(n*2);
F(i,1,m){
int u=rd(),v=rd();
if(in[u]>in[v])swap(u,v);
int ffa=Lca(u,v);
if(ffa==u){
qe[i].l=in[u];
qe[i].r=in[v];
}
else {
qe[i].l=out[u];
qe[i].r=in[v];
qe[i].lca=ffa;
}
qe[i].id=i;
}
solve();
F(i,1,m)cout<<ans[i]<<'\n';
return 0;
}

浙公网安备 33010602011771号