树上散点

一个博客需要一份头图:

强制在线(论一个 ^ 引起的癫狂:


P6177 Count on a tree II/【模板】树分块

题意:

给定一棵树,每个节点有颜色。

每次询问一条路径上不同颜色的个数,强制在线。

数据范围 \(1 \leq n \leq 4\times 10^4,1\leq m \leq 10^5\),其中 \(n\) 为点数,\(m\) 为查询数。

考虑在树上随机选 \(k\) 个点作为关键点,使得树上每个点距离其最近的祖先关键点距离不超过 \(\frac{n}{k}\)

具体地,如果一个点的 \(1 \sim \frac{n}{k}\) 级祖先都不是关键点,则把 \(\frac{n}{k}\) 级祖先标记为关键点。

为方便,最好也取根节点为关键点。

然后预处理出每条路径上,各各关键节点之间的答案,可以用 bitset 维护。

那么最多要处理 \(k^2\) 个点之间的答案,预处理复杂度为 \(O(\frac{nk^2}{w})\)

具体地,只需要用栈维护当前路径的节点即可。

下面考虑如何处理询问。可以把一个询问拆成这样:

其中紫色为关键节点。

那么对于红色点与紫色点之间,则暴力跳,对于紫色点之间则用之前预处理出的答案。

对于从 \(u_0 \rightarrow u_1\),只需要记录下每个关键点在树链上的上一个关键点 \(lst\) 即可。

还要注意空间复杂度,这里取 \(k=80\)

#include<bits/stdc++.h>
using namespace std;

const int N=4e4+5;
int n,m,q,top,ans,a[N],lsh[N],lst[N];
int tot,id[N],dis[N],dep[N],f[N][16],stk[N];
bitset<N> t[82][82],tmp;
vector<int> G[N];

int rd()
{
    int x=0;char c=getchar();
    for(;!isdigit(c);c=getchar());
    for(; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c^48);
    return x;
}

void dfs(int u,int fa)
{
    dis[u]=dep[u]=dep[fa]+1,f[u][0]=fa;
    for(int i=1;i<=15;i++) f[u][i]=f[f[u][i-1]][i-1];
    for(int v:G[u])
    {
        if(v==fa) continue;
        dfs(v,u);
        dis[u]=max(dis[u],dis[v]);
    }
    if(dis[u]-dep[u]>=500) dis[u]=dep[u],id[u]=++tot;
}

void dfs2(int u)
{
    for(int v:G[u])
    {
        if(v==f[u][0]) continue;
        if(id[v])
        {
            int x=id[stk[top]],y=id[v];
            for(int i=v;i!=stk[top];i=f[i][0]) t[x][y].set(a[i]);
            tmp=t[x][y];
            for(int i=1;i<top;i++) t[id[stk[i]]][y]=t[id[stk[i]]][x]|tmp;
            lst[v]=stk[top];
            stk[++top]=v;
        }
        dfs2(v);
        if(id[v]) top--;
    }
}

int getlca(int u,int v)
{
    if(dep[u]<dep[v]) swap(u,v);
    for(int i=15;~i;i--) if(dep[f[u][i]]>=dep[v]) u=f[u][i];
    if(u==v) return u;
    for(int i=15;~i;i--) if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
    return f[u][0];
}

void work(int u,int lca)
{
    int pre=u;
    while(dep[lst[pre]]>=dep[lca]) pre=lst[pre];
    if(pre!=u) tmp|=t[id[pre]][id[u]];
    while(pre!=lca) tmp.set(a[pre]),pre=f[pre][0];
}

int main()
{
    n=rd(),q=rd();
    for(int i=1;i<=n;i++) a[i]=lsh[i]=rd();
    sort(lsh+1,lsh+1+n);
    m=unique(lsh+1,lsh+1+n)-lsh-1;
    for(int i=1;i<=n;i++) a[i]=lower_bound(lsh+1,lsh+1+m,a[i])-lsh;
    for(int i=1;i<n;i++)
    {
        int u=rd(),v=rd();
        G[u].push_back(v),G[v].push_back(u);
    }
    dfs(1,0);
    if(!id[1]) id[1]=++tot;
    stk[++top]=1;
    dfs2(1);
    while(q--)
    {
        int u=rd()^ans,v=rd(),lca=getlca(u,v);
        tmp.reset();
        while(u!=lca&&!id[u]) tmp.set(a[u]),u=f[u][0];
        while(v!=lca&&!id[v]) tmp.set(a[v]),v=f[v][0];
        tmp.set(a[lca]);
        if(u!=lca) work(u,lca);
        if(v!=lca) work(v,lca);
        printf("%d\n",ans=tmp.count());
    }
}

三倍经验:

SP10707

雪辉

雪辉这道题还有求个 mex,只需要把 bitset 取反后求 lowbit 即可,bitset 有个函数叫 _Find_first()

#include<bits/stdc++.h>
using namespace std;

const int N=1e5+5;
int n,q,totE,top,ans,op,a[N],lst[N];
int tot,id[N],dis[N],dep[N],f[N][20],stk[N];
bitset<30005> t[155][155],tmp;
int pre[N],nxt[N<<1],to[N<<1];
void add(int u,int v){to[++totE]=v,nxt[totE]=pre[u],pre[u]=totE;}

int rd()
{
    int x=0;char c=getchar();
    for(;!isdigit(c);c=getchar());
    for(; isdigit(c);c=getchar()) x=(x<<3)+(x<<1)+(c^48);
    return x;
}

void dfs(int u,int fa)
{
    dis[u]=dep[u]=dep[fa]+1,f[u][0]=fa;
    for(int i=1;i<18;i++) f[u][i]=f[f[u][i-1]][i-1];
    for(int i=pre[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==fa) continue;
        dfs(v,u);
        dis[u]=max(dis[u],dis[v]);
    }
    if(dis[u]-dep[u]>=1000) dis[u]=dep[u],id[u]=++tot;
}

void dfs2(int u)
{
    for(int i=pre[u];i;i=nxt[i])
    {
        int v=to[i];
        if(v==f[u][0]) continue;
        if(id[v])
        {
            int x=id[stk[top]],y=id[v];
            for(int i=v;i!=stk[top];i=f[i][0]) t[x][y].set(a[i]);
            tmp=t[x][y];
            for(int i=1;i<top;i++) t[id[stk[i]]][y]=t[id[stk[i]]][x]|tmp;
            lst[v]=stk[top];
            stk[++top]=v;
        }
        dfs2(v);
        if(id[v]) top--;
    }
}

int getlca(int u,int v)
{
    if(dep[u]<dep[v]) swap(u,v);
    for(int i=17;~i;i--) if(dep[f[u][i]]>=dep[v]) u=f[u][i];
    if(u==v) return u;
    for(int i=17;~i;i--) if(f[u][i]!=f[v][i]) u=f[u][i],v=f[v][i];
    return f[u][0];
}

void work(int u,int lca)
{
    int pre=u;
    while(dep[lst[pre]]>=dep[lca]) pre=lst[pre];
    if(pre!=u) tmp|=t[id[pre]][id[u]];
    while(pre!=lca) tmp.set(a[pre]),pre=f[pre][0];
}

void get(int u,int v)
{
    int lca=getlca(u,v);
    while(u!=lca&&!id[u]) tmp.set(a[u]),u=f[u][0];
    while(v!=lca&&!id[v]) tmp.set(a[v]),v=f[v][0];
    tmp.set(a[lca]);
    if(u!=lca) work(u,lca);
    if(v!=lca) work(v,lca);
}

int main()
{
    n=rd(),q=rd(),op=rd();
    for(int i=1;i<=n;i++) a[i]=rd();
    for(int i=1;i<n;i++)
    {
        int u=rd(),v=rd();
        add(u,v),add(v,u);
    }
    dfs(1,0);
    if(!id[1]) id[1]=++tot;
    stk[++top]=1;
    dfs2(1);
    while(q--)
    {
        int cnt=rd();tmp.reset();
        for(int i=1;i<=cnt;i++) get(rd()^ans,rd()^ans);
        int a1=tmp.count();tmp=~tmp;
        int a2=tmp._Find_first();
        printf("%d %d\n",a1,a2);
        if(op) ans=(a1+a2);
    }
}
posted @ 2023-09-19 16:31  spider_oyster  阅读(10)  评论(0编辑  收藏  举报