P2486 [SDOI2011]染色
分析
我们来根据操作来讨论一下,需要维护的值有什么。
将节点 a 到节点 b 的路径上的所有点(包括 a 和 b)都染成颜色 c。
很明显,我们需要维护一下tag,来保存该区间是否发生了整体被某种颜色覆盖
这并不困难,我们把眼光放到第二个操作上
询问节点 a 到节点 b 的路径上的颜色段数量。
此时,我们很明显需要维护一个sum,表示该段上不同颜色段的数量。
同时为了维护合并后区间的sum,我们需要维护两个值lc和rc分别表示左端点的颜色和右端点的颜色。
若合并区间时,左区间的右端点颜色 = 右区间的左端点颜色,则该区间的颜色段数量减1
同时需要注意的是,在从一个条链跳到另外一条链时,可能会发生颜色连续的事情,从而使答案减1
具体解决方案就是,可以在全局开一个Lc和Rc变量,用来记此时该条链的左端点颜色和右端点颜色。
同时维护两个变量ans1和ans2,用来分别统计u的上一条链的左端点颜色和v的上一条链的左端点颜色。
还需要注意的是,当top[u]==top[v],即u,v在同一条链时,因为此时区间的两个端点分别为u,v,需要分别对u,v的上一条链的左端点颜色进行对比,若相同则减1
话不多说,直接看代码
AC_code
#include<bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10,M = N*2;
struct Node
{
int l,r,lc,rc,sum,tag;
}tr[N<<2];
int h[N],e[M],ne[M],w[N],idx;
int sz[N],fa[N],son[N],dep[N];
int id[N],top[N],nw[N],ts;
int n,m,Lc,Rc;
void add(int a,int b)
{
e[idx] = b,ne[idx] = h[a],h[a] = idx++;
}
void dfs1(int u,int pa,int depth)
{
fa[u] = pa,sz[u] = 1,dep[u] = depth;
for(int i=h[u];~i;i=ne[i])
{
int j = e[i];
if(j==pa) continue;
dfs1(j,u,depth+1);
if(sz[son[u]]<sz[j]) son[u] = j;
sz[u] += sz[j];
}
}
void dfs2(int u,int tp)
{
id[u] = ++ts,nw[ts] = w[u],top[u] = tp;
if(!son[u]) return ;
dfs2(son[u],tp);
for(int i=h[u];~i;i=ne[i])
{
int j = e[i];
if(j==son[u]||j==fa[u]) continue;
dfs2(j,j);
}
}
void pushup(int u)
{
tr[u].lc = tr[u<<1].lc,tr[u].rc = tr[u<<1|1].rc;
tr[u].sum = tr[u<<1].sum + tr[u<<1|1].sum;
if(tr[u<<1].rc==tr[u<<1|1].lc) tr[u].sum--;
}
void change(Node &u,int k)
{
u.sum = u.tag = 1;
u.lc = u.rc = k;
}
void pushdown(int u)
{
if(tr[u].tag)
{
change(tr[u<<1],tr[u].lc);
change(tr[u<<1|1],tr[u].lc);
tr[u].tag = 0;
}
}
void build(int u,int l,int r)
{
if(l==r)
{
tr[u] = {l,r,nw[l],nw[l],1,0};
return ;
}
tr[u] = {l,r,nw[l],nw[r],0,0};
int mid = l + r >> 1;
build(u<<1,l,mid),build(u<<1|1,mid+1,r);
pushup(u);
}
void modify(int u,int l,int r,int k)
{
if(l<=tr[u].l&&tr[u].r<=r)
{
change(tr[u],k);
return ;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l<=mid) modify(u<<1,l,r,k);
if(r>mid) modify(u<<1|1,l,r,k);
pushup(u);
}
int query(int u,int l,int r)
{
if(l<=tr[u].l&&tr[u].r<=r)
{
if(tr[u].l==l) Lc = tr[u].lc;
if(tr[u].r==r) Rc = tr[u].rc;
return tr[u].sum;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
int res = 0,lc = -1,rc = -1;
if(l<=mid) res += query(u<<1,l,r),lc = tr[u<<1].rc;
if(r>mid) res += query(u<<1|1,l,r),rc = tr[u<<1|1].lc;
if(lc!=-1&&rc!=-1&&lc==rc) res--;
return res;
}
int main()
{
cin>>n>>m;
memset(h,-1,sizeof h);
for(int i=1;i<=n;i++) cin>>w[i];
for(int i=0;i<n-1;i++)
{
int a,b;cin>>a>>b;
add(a,b),add(b,a);
}
dfs1(1,-1,1);
dfs2(1,1);
build(1,1,n);
while(m--)
{
string op;int a,b,c;
cin>>op>>a>>b;
if(op=="C")
{
cin>>c;
while(top[a]!=top[b])
{
if(dep[top[a]]<dep[top[b]]) swap(a,b);
modify(1,id[top[a]],id[a],c);
a = fa[top[a]];
}
if(dep[a]<dep[b]) swap(a,b);
modify(1,id[b],id[a],c);
}
else
{
int res = 0,ans1 = -1,ans2 = -1;
while(top[a]!=top[b])
{
if(dep[top[a]]<dep[top[b]]) swap(a,b),swap(ans1,ans2);
res += query(1,id[top[a]],id[a]);
if(Rc==ans1) res--;
ans1 = Lc;
a = fa[top[a]];
}
if(dep[a]<dep[b]) swap(a,b),swap(ans1,ans2);
res += query(1,id[b],id[a]);
if(Lc==ans2) res--;
if(Rc==ans1) res--;
cout<<res<<endl;
}
}
return 0;
}

浙公网安备 33010602011771号