题解:P8339 [AHOI2022] 钥匙

题目传送门

题意

给出一棵树,树上每个节点有一把钥匙或宝箱,钥匙和宝箱都有颜色,颜色要一一对应才能打开宝箱。现在给出 \(m\) 趟旅程,每次给定 \(s\)\(t\) 求一路上最多能开多少宝箱。

思路

注意到每种颜色之间互不干扰,所以可以分开考虑。用单调栈构建不同颜色的虚树,然后对于每种颜色的虚树,我们考虑求出可能成为贡献的几对节点,我们假设钥匙是 \(k\),宝箱是 \(b\),那么用 \(k_i\) 开的 \(b_j\) 就可以作为贡献记录下来,记为 \((b_i,k_j)\)

这里简单提一下怎么建虚树(实际是笔者才刚学会),对于每种颜色的钥匙和宝箱,我们用单调栈维护,将相同颜色的点与栈顶判断 \(lca\),如果 \(lca\) 是栈顶,那么直接入栈,否则将栈顶弹出,直到栈顶的深度与 \(lca\) 的相同,如果这时栈顶不是 \(lca\) 的话,要让 \(lca\) 入栈。注意弹出的时候不要忘记连边。

具体情况见代码:

void inline ins(ll x){
    vir.push_back(x);
    if(!top){
        s[++top]=x;
        return ;
    }
    ll p=lca(s[top],x);
    while(top>1&&dep[s[top-1]]>=dep[p]){
        e[s[top-1]].push_back(s[top]);
        e[s[top]].push_back(s[top-1]);
        top--;
    }
    if(s[top]!=p){
        e[s[top]].push_back(p);
        e[p].push_back(s[top]);
        s[top]=p;
        vir.push_back(p);
    }
    s[++top]=x;
}

这里顺便提一嘴 \(lca\)。这里的 \(lca\) 是用重链剖分,不是倍增,倍增会超时。中联剖分的 \(lca\) 每次就判断是否已经在一条链上,如果是就返回深度较小的那个点,否则就继续跳。

代码:

ll lca(ll x,ll y){
    while(tp[x]!=tp[y]){
        if(dep[tp[x]]<dep[tp[y]])swap(x,y);
        x=fa[tp[x]];
    }
    return dep[x]<dep[y]?x:y;
}

现在我们考虑这个有什么用,如果想让一段旅程包含 \((b_i,k_j)\) 这个贡献,那么对 \(s\)\(t\) 便有限制。

这里需要分类讨论一下:

如果 \(b_i\)\(k_j\) 互不是对方的祖先,那么只需要 \(s\)\(b_j\) 的子树内,\(t\)\(k_i\) 的子树内即可。

如果 \(b_i\)\(k_j\) 的祖先,那么只需要 \(s\)\(b_i\) 的子树内,\(t\) 不在 \(k_j\) 的子树内,但是也可以在 \(k_j\) 这个点上。我们令 \(u\)\(k_j\)\(b_i\) 路径上的 \(k_j\) 的儿子,那么 \(s\) 只需要在 \(b_i\) 的子树内,\(t\) 只需要在 \(u\) 的子树外即可。

如果 \(k_j\)\(b_i\) 的祖先同理。

详细情况见代码:

void pushup(ll l,ll r,ll x,ll y){
    if(l>r||x>y)return;
    sum[l].push_back({x,1});
    sum[l].push_back({y+1,-1});
    sum[r+1].push_back({x,-1});
    sum[r+1].push_back({y+1,1});
}
void insert(ll x,ll y){
    if(check(x,y)){
        ll z=js(y,x);
        pushup(1,dfn[z]-1,dfn[y],dfn[y]+siz[y]-1);
        pushup(dfn[z]+siz[z],n,dfn[y],dfn[y]+siz[y]-1);
    }
    else if(check(y,x)){
        ll z=js(x,y);
        pushup(dfn[x],dfn[x]+siz[x]-1,1,dfn[z]-1);
        pushup(dfn[x],dfn[x]+siz[x]-1,dfn[z]+siz[z],n);
    }
    else{
        pushup(dfn[x],dfn[x]+siz[x]-1,dfn[y],dfn[y]+siz[y]-1);
    }
}

这里的 \(sum\) 数组是为了后面计算的差分数组。

那么要怎么判断是否在子树内呢。可以用 \(dfs\) 序,在子树内的话 \(dfs\) 序一定是连续的,所以只需要判断 \(dfs\) 序是否在对应区间内即可。

那么具体该怎么计算呢?我们想到可以用树状数组维护区间差分。由于这里有两个区间的限制,所以我们用下标来维护一个区间,用值来维护另一个区间。

这里直接讲可能不太好懂,所以借助代码一起。

代码:

for(int i=1;i<=m;i++){
    ll u,v;
    cin>>u>>v;
    num[dfn[u]].push_back({dfn[v],i});
}
for(int i=1;i<=n;i++){
    for(auto v:sum[i])add(v.first,v.second);
    for(auto v:num[i])ans[v.second]=query(v.first);
}

在这里面 \(num\) 记录的是起点和终点以及编号,因为要离线处理。\(sum\) 数组中存贮的是要差分的位置和值,因为每段旅程想要包含这段贡献的话,就要起点和终点都在对应的区间内,而我们用树状数组的下标和前缀和来维护第一个区间,保障当到这个起点时,所有能包含该起点的区间都被放入了树状数组,而第二个区间我们用树状数组的值,也就是用差分来维护,只要被包含在这段区间内的话,值就会 \(+1\)

最后给一下完整代码:

#include<bits/stdc++.h>
#define ll int
using namespace std;
const int M=1010101;
const int N=501010;
ll n,m,ans[M],siz[N],dep[N],fa[N],son[N],tp[N],dfn[N],tot,cnt,pre[N],top,S,col;
ll c[N],t[N],s[N],tree[N];
vector<ll>g[N];
void dfs1(ll u){
    siz[u]=1;
    for(auto v:g[u]){
        if(v==fa[u])continue;
        fa[v]=u;
        dep[v]=dep[u]+1;
        dfs1(v);
        siz[u]+=siz[v];
        if(siz[v]>siz[son[u]])son[u]=v;
    }
}
void dfs2(ll u,ll top){
    tp[u]=top,dfn[u]=++cnt;
    pre[dfn[u]]=u;
    if(son[u])dfs2(son[u],top);
    for(auto v:g[u]){
        if(v==fa[u]||v==son[u])continue;
        dfs2(v,v);
    }
}
ll lca(ll x,ll y){
    while(tp[x]!=tp[y]){
        if(dep[tp[x]]<dep[tp[y]])swap(x,y);
        x=fa[tp[x]];
    }
    return dep[x]<dep[y]?x:y;
}
vector<ll>vir,e[N];
void inline ins(ll x){
    vir.push_back(x);
    if(!top){
        s[++top]=x;
        return ;
    }
    ll p=lca(s[top],x);
    while(top>1&&dep[s[top-1]]>=dep[p]){
        e[s[top-1]].push_back(s[top]);
        e[s[top]].push_back(s[top-1]);
        top--;
    }
    if(s[top]!=p){
        e[s[top]].push_back(p);
        e[p].push_back(s[top]);
        s[top]=p;
        vir.push_back(p);
    }
    s[++top]=x;
}
bool check(ll x,ll y){return dfn[x]<=dfn[y]&&dfn[x]+siz[x]>dfn[y];}
ll js(ll x,ll y){
    ll res=0;
    while(tp[x]!=tp[y]){
        res=tp[x];
        x=fa[tp[x]];
    }
    if(x==y)return res;
    return pre[dfn[y]+1];
}
vector<pair<ll,ll> >sum[N];
void pushup(ll l,ll r,ll x,ll y){
    if(l>r||x>y)return;
    sum[l].push_back({x,1});
    sum[l].push_back({y+1,-1});
    sum[r+1].push_back({x,-1});
    sum[r+1].push_back({y+1,1});
}
void insert(ll x,ll y){
    if(check(x,y)){
        ll z=js(y,x);
        pushup(1,dfn[z]-1,dfn[y],dfn[y]+siz[y]-1);
        pushup(dfn[z]+siz[z],n,dfn[y],dfn[y]+siz[y]-1);
    }
    else if(check(y,x)){
        ll z=js(x,y);
        pushup(dfn[x],dfn[x]+siz[x]-1,1,dfn[z]-1);
        pushup(dfn[x],dfn[x]+siz[x]-1,dfn[z]+siz[z],n);
    }
    else{
        pushup(dfn[x],dfn[x]+siz[x]-1,dfn[y],dfn[y]+siz[y]-1);
    }
}
void dfs(ll u,ll fa,ll w){
    if(w<0)return;
    for(auto v:e[u]){
        if(v==fa)continue;
        if(!w&&c[v]==col&&t[v]==2){
            insert(S,v);
            continue;
        }
        ll k=0;
        if(c[v]==col&&t[v]==1)k++;
        if(c[v]==col&&t[v]==2)k--;
        dfs(v,u,w+k);
    }
}
ll lowbit(ll x){return x&(-x);}
void add(ll x,ll y){
    while(x<=n){
        tree[x]+=y;
        x+=lowbit(x);
    }
}
ll query(ll x){
    ll res=0;
    while(x){
        res+=tree[x];
        x-=lowbit(x);
    }
    return res;
}
vector<ll>b[N];
vector<pair<ll,ll> >num[N];
bool cmp(ll a,ll b){return dfn[a]<dfn[b];}
int main(){
    ios::sync_with_stdio(false);
    cin.tie(0);cout.tie(0);
    cin>>n>>m; 
    for(int i=1;i<=n;i++)cin>>t[i]>>c[i],b[c[i]].push_back(i);
    for(int i=1;i<n;i++){
        ll u,v;
        cin>>u>>v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    dfs1(1);
    dfs2(1,1);
    top=0;
    for(int i=1;i<=n;i++){
    	col=i;
        if(!b[i].size())continue;
        sort(b[i].begin(),b[i].end(),cmp);
        for(auto v:b[i])ins(v);
        while(top>1){
            e[s[top-1]].push_back(s[top]);
            e[s[top]].push_back(s[top-1]);
            top--;
        }
        for(auto v:b[i]){
            if(t[v]==1&&c[v]==i){
            	S=v;
                dfs(v,0,0);
            }
        }
        for(auto v:vir)e[v].clear();
        vir.clear();
        top=0;
    }
    for(int i=1;i<=m;i++){
        ll u,v;
        cin>>u>>v;
        num[dfn[u]].push_back({dfn[v],i});
    }
    for(int i=1;i<=n;i++){
        for(auto v:sum[i])add(v.first,v.second);
        for(auto v:num[i])ans[v.second]=query(v.first);
    }
    for(int i=1;i<=m;i++)cout<<ans[i]<<'\n';
    return 0;
}

完结撒花!

posted @ 2025-09-07 21:44  一班的hoko  阅读(6)  评论(0)    收藏  举报