P2486 [SDOI2011]染色

P2486 [SDOI2011]染色

分析

我们来根据操作来讨论一下,需要维护的值有什么。

将节点 a 到节点 b 的路径上的所有点(包括 a 和 b)都染成颜色 c。

很明显,我们需要维护一下tag,来保存该区间是否发生了整体被某种颜色覆盖

这并不困难,我们把眼光放到第二个操作上

询问节点 a 到节点 b 的路径上的颜色段数量。

此时,我们很明显需要维护一个sum,表示该段上不同颜色段的数量。

同时为了维护合并后区间的sum,我们需要维护两个值lcrc分别表示左端点的颜色和右端点的颜色。

若合并区间时,左区间的右端点颜色 = 右区间的左端点颜色,则该区间的颜色段数量减1

同时需要注意的是,在从一个条链跳到另外一条链时,可能会发生颜色连续的事情,从而使答案减1

具体解决方案就是,可以在全局开一个LcRc变量,用来记此时该条链的左端点颜色和右端点颜色。

同时维护两个变量ans1ans2,用来分别统计u的上一条链的左端点颜色和v的上一条链的左端点颜色。

还需要注意的是,当top[u]==top[v],即u,v在同一条链时,因为此时区间的两个端点分别为u,v,需要分别对u,v的上一条链的左端点颜色进行对比,若相同则减1

话不多说,直接看代码

AC_code

#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10,M = N*2;
struct Node
{
    int l,r,lc,rc,sum,tag;
}tr[N<<2];
int h[N],e[M],ne[M],w[N],idx;
int sz[N],fa[N],son[N],dep[N];
int id[N],top[N],nw[N],ts;
int n,m,Lc,Rc;

void add(int a,int b)
{
    e[idx] = b,ne[idx] = h[a],h[a] = idx++;
}

void dfs1(int u,int pa,int depth)
{
    fa[u] = pa,sz[u] = 1,dep[u] = depth;
    for(int i=h[u];~i;i=ne[i])
    {
        int j = e[i];
        if(j==pa) continue;
        dfs1(j,u,depth+1);
        if(sz[son[u]]<sz[j]) son[u] = j;
        sz[u] += sz[j];
    }
}

void dfs2(int u,int tp)
{
    id[u] = ++ts,nw[ts] = w[u],top[u] = tp;
    if(!son[u]) return ;
    dfs2(son[u],tp);
    for(int i=h[u];~i;i=ne[i])
    {
        int j = e[i];
        if(j==son[u]||j==fa[u]) continue;
        dfs2(j,j);
    }
}

void pushup(int u)
{
    tr[u].lc = tr[u<<1].lc,tr[u].rc = tr[u<<1|1].rc;
    tr[u].sum = tr[u<<1].sum + tr[u<<1|1].sum;
    if(tr[u<<1].rc==tr[u<<1|1].lc) tr[u].sum--;
}


void change(Node &u,int k)
{
    u.sum = u.tag = 1;
    u.lc = u.rc = k;
}

void pushdown(int u)
{
    if(tr[u].tag)
    {
        change(tr[u<<1],tr[u].lc);
        change(tr[u<<1|1],tr[u].lc);
        tr[u].tag = 0;
    }
}

void build(int u,int l,int r)
{
    if(l==r)
    {
        tr[u] = {l,r,nw[l],nw[l],1,0};
        return ;
    }
    tr[u] = {l,r,nw[l],nw[r],0,0};
    int mid = l + r >> 1;
    build(u<<1,l,mid),build(u<<1|1,mid+1,r);
    pushup(u);
}

void modify(int u,int l,int r,int k)
{
    if(l<=tr[u].l&&tr[u].r<=r)
    {
        change(tr[u],k);
        return ;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    if(l<=mid) modify(u<<1,l,r,k);
    if(r>mid) modify(u<<1|1,l,r,k);
    pushup(u);
}

int query(int u,int l,int r)
{
    if(l<=tr[u].l&&tr[u].r<=r) 
    {
        if(tr[u].l==l) Lc = tr[u].lc;
        if(tr[u].r==r) Rc = tr[u].rc;
        return tr[u].sum;
    }
    pushdown(u);
    int mid = tr[u].l + tr[u].r >> 1;
    int res = 0,lc = -1,rc = -1;
    if(l<=mid) res += query(u<<1,l,r),lc = tr[u<<1].rc;
    if(r>mid) res += query(u<<1|1,l,r),rc = tr[u<<1|1].lc;
    if(lc!=-1&&rc!=-1&&lc==rc) res--;
    return res;
}

int main()
{
    cin>>n>>m;
    memset(h,-1,sizeof h);
    for(int i=1;i<=n;i++) cin>>w[i];
    for(int i=0;i<n-1;i++)
    {
        int a,b;cin>>a>>b;
        add(a,b),add(b,a);
    }
    dfs1(1,-1,1);
    dfs2(1,1);
    build(1,1,n);
    while(m--)
    {
        string op;int a,b,c;
        cin>>op>>a>>b;
        if(op=="C")
        {
            cin>>c;
            while(top[a]!=top[b])
            {
                if(dep[top[a]]<dep[top[b]]) swap(a,b);
                modify(1,id[top[a]],id[a],c);
                a = fa[top[a]];
            }
            if(dep[a]<dep[b]) swap(a,b);
            modify(1,id[b],id[a],c);
        }
        else 
        {
            int res = 0,ans1 = -1,ans2 = -1;
            while(top[a]!=top[b])
            {
                if(dep[top[a]]<dep[top[b]]) swap(a,b),swap(ans1,ans2);
                res += query(1,id[top[a]],id[a]);
                if(Rc==ans1) res--;
                ans1 = Lc;
                a = fa[top[a]];
            }
            if(dep[a]<dep[b]) swap(a,b),swap(ans1,ans2);
            res += query(1,id[b],id[a]);
            if(Lc==ans2) res--;
            if(Rc==ans1) res--;
            cout<<res<<endl;
        }
    }
    return 0;
}
posted @ 2022-03-20 17:32  艾特玖  阅读(57)  评论(0)    收藏  举报