LGP6177 Count on a Tree II // [LG TPLT] 树分块 学习笔记

LGP6177 Count on a Tree II // [LG TPLT] 树分块 学习笔记

Luogu Link

前言

本题解主要借鉴:这篇

题意简述

给定一棵 \(n\) 个结点的树,每个结点有一个颜色。\(m\) 次询问 \(u\to v\) 的路径上的颜色种数。强制在线。

\(n\le 4\times 10^4\)\(m\le 10^5\)

做法解析

你看到数颜色,你就知道这肯定不是 \(\text{polylog}\),而是根号算法了。然后它强制在线,所以莫队out!莫队out了,我们就得用bitset了。

不过,树上怎么用bitset呢?我们考虑树分块之后用bitset。树分块,就是我们设定一个阈值 \(B\),然后在树上面撒上 \(\frac{n}{B}\) 个点,沿这些点把树大卸八块。这样我们希望两个关键点间的距离不超过 \(B\)

其实你直接随机撒点期望上就是对的,但是我们有能严谨保证之的做法。具体地,我们每次选择一个深度最大的非关键点,如果它的 \(1\sim B\) 级祖先都不是关键点,就让它的 \(B\) 级祖先变成关键点。正确性显然。另外我们钦定 \(1\) 为关键点,省的炸一堆边界情况。

接下来我们考虑每条直链上出现的关键点 \(x_0,x_1\dots x_n\),我们预处理如此相邻的每两个关键点间的bitset,然后再求两两关键点的bitset。复杂度 \(O(\frac{nB}{B}+\frac{n^2}{B^2}\times\frac{n}{w})=O(n+\frac{n^3}{wB^2})\)

求答案是容易想到的。我们把路径拆成最多六段:设 \(a=\text{lca}(u,v)\),则有 \(u\sim x_u\)\(x_u\sim x_{al}\)\(x_{al}\sim a\)\(a\sim x_{ar}\)\(x_{ar}\sim x_v\)\(x_v\sim v\)。四个散段暴力更新,两个整段直接调用与处理结果,最后调用下 count() 即可得出答案。询问的总复杂度就是 \(O(m(B+\frac{n}{w}))\)

我们先取 \(B=\sqrt{n}\)。不妨认为 \(n\) 小于 \(m\)。此时时间复杂度 \(O(m\sqrt{n}+\frac{nm}{w})\),空间复杂度 \(O(\frac{n^2}{B^2})\)。我们发现当 \(B\) 线性增长时时间常数线性增长而空间常数平方级下降,所以我们可以通过开大 \(B\) 来卡卡空间。

代码实现

#include <bits/stdc++.h>
using namespace std;
using namespace obasic;
const int MaxN=4e4+5,Ksiz=1e3,CntK=42;
int N,M,A[MaxN],B[MaxN],nln,X,Y;
vector<int> Tr[MaxN];
void addudge(int u,int v){
    Tr[u].push_back(v);
    Tr[v].push_back(u);
}
int dep[MaxN],tfa[MaxN],siz[MaxN],hson[MaxN],epd[MaxN];
int bcnt,bid[MaxN];
void dfs1(int u,int f){
    epd[u]=dep[u]=dep[f]+1,tfa[u]=f,siz[u]=1;
    for(int v : Tr[u]){
        if(v==f)continue;dfs1(v,u);
        siz[u]+=siz[v],maxxer(epd[u],epd[v]);
        if(siz[v]>siz[hson[u]])hson[u]=v;
    }
    if(epd[u]-dep[u]>=Ksiz)bid[u]=++bcnt,epd[u]=dep[u];
}
bitset<MaxN> bts[CntK][CntK],cur;
int ktp,stk[MaxN],ff[MaxN];
void dfs2(int u){
    for(int v : Tr[u]){
        if(v==tfa[u])continue;
        if(bid[v]){
            int pid=bid[stk[ktp]],cid=bid[v];
            for(int x=v;x!=stk[ktp];x=tfa[x])bts[pid][cid].set(A[x]);
            cur=bts[pid][cid];
            for(int i=1;i<ktp;i++){
                int sid=bid[stk[i]];
                bts[sid][cid]=bts[sid][pid]|cur;
            }
            ff[v]=stk[ktp],stk[++ktp]=v;
        }
        dfs2(v);if(bid[v])--ktp;
    }
}
int top[MaxN];
void dfs3(int u,int t){
    top[u]=t;if(hson[u])dfs3(hson[u],t);
    for(int v : Tr[u])if(v!=tfa[u]&&v!=hson[u])dfs3(v,v);
}
int getlca(int u,int v){
    while(top[u]!=top[v])dep[top[u]]>dep[top[v]]?u=tfa[top[u]]:v=tfa[top[v]];
    return dep[u]<dep[v]?u:v;
}
int solve(int u,int v){
    cur.reset();int anc=getlca(u,v);
    for(;u!=anc&&!bid[u];u=tfa[u])cur.set(A[u]);
    for(;v!=anc&&!bid[v];v=tfa[v])cur.set(A[v]);
    if(u!=anc){
        int au=u;
        while(dep[ff[au]]>=dep[anc])au=ff[au];
        if(au!=u)cur|=bts[bid[au]][bid[u]];
        for(;au!=anc;au=tfa[au])cur.set(A[au]);
    }
    if(v!=anc){
        int av=v;
        while(dep[ff[av]]>=dep[anc])av=ff[av];
        if(av!=v)cur|=bts[bid[av]][bid[v]];
        for(;av!=anc;av=tfa[av])cur.set(A[av]);
    }
    cur.set(A[anc]);return cur.count();
}
int main(){
    readis(N,M);
    for(int i=1;i<=N;i++)readi(A[i]),B[i]=A[i];
    sort(B+1,B+N+1),nln=unique(B+1,B+N+1)-(B+1);
    for(int i=1;i<=N;i++)A[i]=lwberi(B,nln,A[i]);
    for(int i=1;i<N;i++)readis(X,Y),addudge(X,Y);
    dfs1(1,0);if(!bid[1])bid[1]=++bcnt;
    stk[ktp=1]=1,dfs2(1),dfs3(1,1);
    for(int i=1,ans=0;i<=M;i++){
        readis(X,Y),X^=ans;
        ans=solve(X,Y),writil(ans);
    }
    return 0;
}
posted @ 2025-05-14 17:56  矞龙OrinLoong  阅读(12)  评论(0)    收藏  举报